mirror of
https://code.eliotberriot.com/funkwhale/funkwhale.git
synced 2025-10-03 21:29:16 +02:00
[Experimental] Added a new "Similar" radio based on users history (suggested by @gordon)
This commit is contained in:
parent
602a4c3b29
commit
5ce4cc8379
4 changed files with 95 additions and 4 deletions
|
@ -1,7 +1,7 @@
|
|||
import random
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db.models import Count
|
||||
from django.db import connection
|
||||
from rest_framework import serializers
|
||||
from taggit.models import Tag
|
||||
|
||||
|
@ -43,8 +43,7 @@ class SessionRadio(SimpleRadio):
|
|||
return self.session
|
||||
|
||||
def get_queryset(self, **kwargs):
|
||||
qs = Track.objects.annotate(uploads_count=Count("uploads"))
|
||||
return qs.filter(uploads_count__gt=0)
|
||||
return Track.objects.all()
|
||||
|
||||
def get_queryset_kwargs(self):
|
||||
return {}
|
||||
|
@ -56,6 +55,10 @@ class SessionRadio(SimpleRadio):
|
|||
queryset = self.filter_from_session(queryset)
|
||||
if kwargs.pop("filter_playable", True):
|
||||
queryset = queryset.playable_by(self.session.user.actor)
|
||||
queryset = self.filter_queryset(queryset)
|
||||
return queryset
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
return queryset
|
||||
|
||||
def filter_from_session(self, queryset):
|
||||
|
@ -153,6 +156,74 @@ class TagRadio(RelatedObjectRadio):
|
|||
return qs.filter(tags__in=[self.session.related_object])
|
||||
|
||||
|
||||
def weighted_choice(choices):
|
||||
total = sum(w for c, w in choices)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for c, w in choices:
|
||||
if upto + w >= r:
|
||||
return c
|
||||
upto += w
|
||||
assert False, "Shouldn't get here"
|
||||
|
||||
|
||||
class NextNotFound(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register(name="similar")
|
||||
class SimilarRadio(RelatedObjectRadio):
|
||||
model = Track
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = super().filter_queryset(queryset)
|
||||
seeds = list(
|
||||
self.session.session_tracks.all()
|
||||
.values_list("track_id", flat=True)
|
||||
.order_by("-id")[:3]
|
||||
) + [self.session.related_object.pk]
|
||||
for seed in seeds:
|
||||
try:
|
||||
return queryset.filter(pk=self.find_next_id(queryset, seed))
|
||||
except NextNotFound:
|
||||
continue
|
||||
|
||||
return queryset.none()
|
||||
|
||||
def find_next_id(self, queryset, seed):
|
||||
with connection.cursor() as cursor:
|
||||
query = """
|
||||
SELECT next, count(next) AS c
|
||||
FROM (
|
||||
SELECT
|
||||
track_id,
|
||||
creation_date,
|
||||
LEAD(track_id) OVER (
|
||||
PARTITION by user_id order by creation_date asc
|
||||
) AS next
|
||||
FROM history_listening
|
||||
INNER JOIN users_user ON (users_user.id = user_id)
|
||||
WHERE users_user.privacy_level = 'instance' OR users_user.privacy_level = 'everyone' OR user_id = %s
|
||||
ORDER BY creation_date ASC
|
||||
) t WHERE track_id = %s AND next != %s GROUP BY next ORDER BY c DESC;
|
||||
"""
|
||||
cursor.execute(query, [self.session.user_id, seed, seed])
|
||||
next_candidates = list(cursor.fetchall())
|
||||
|
||||
if not next_candidates:
|
||||
raise NextNotFound()
|
||||
|
||||
matching_tracks = list(
|
||||
queryset.filter(pk__in=[c[0] for c in next_candidates]).values_list(
|
||||
"id", flat=True
|
||||
)
|
||||
)
|
||||
next_candidates = [n for n in next_candidates if n[0] in matching_tracks]
|
||||
if not next_candidates:
|
||||
raise NextNotFound()
|
||||
return weighted_choice(next_candidates)
|
||||
|
||||
|
||||
@registry.register(name="artist")
|
||||
class ArtistRadio(RelatedObjectRadio):
|
||||
model = Artist
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue