adding tests

This commit is contained in:
Petitminion 2021-07-13 19:58:53 +02:00
parent 706815115a
commit 9b8df0ee7c
2 changed files with 71 additions and 33 deletions

View file

@ -1,7 +1,5 @@
import logging import logging
import os
import re import re
# /!\ The next import have xml vulnerabilities but this shouldn't have security implication in funkwhale # /!\ The next import have xml vulnerabilities but this shouldn't have security implication in funkwhale
# since there are only used to generate xspf file. # since there are only used to generate xspf file.
from xml.etree.ElementTree import Element, SubElement from xml.etree.ElementTree import Element, SubElement
@ -18,7 +16,7 @@ logger = logging.getLogger(__name__)
def clean_namespace_xspf(xspf_file): def clean_namespace_xspf(xspf_file):
""" """
This will delete any namaespace found in the xspf file. It will also delete any encoding info. This will delete any namaespace found in the xspf file. It will also delete any encoding info.
This way xspf file will be compatible with our get_playlist_metadata_from_xspf function. This way xspf file will be compatible with our get_track_id_from_xspf function.
""" """
file = open(xspf_file) file = open(xspf_file)
with file as f: with file as f:
@ -27,13 +25,7 @@ def clean_namespace_xspf(xspf_file):
# This is needed because lxml error : "ValueError: Unicode strings with encoding declaration are # This is needed because lxml error : "ValueError: Unicode strings with encoding declaration are
# not supported. Please use bytes input or XML fragments without declaration." # not supported. Please use bytes input or XML fragments without declaration."
xspf_data = re.sub("'encoding='.'", "", xspf_data) xspf_data = re.sub("'encoding='.'", "", xspf_data)
xspf_file_clean = xspf_file + "2" return xspf_data
if os.path.exists(xspf_file_clean):
os.remove(xspf_file_clean)
file = open(xspf_file_clean, "x")
file.write(xspf_data)
return xspf_file_clean
def get_track_id_from_xspf(xspf_file): def get_track_id_from_xspf(xspf_file):
@ -41,39 +33,26 @@ def get_track_id_from_xspf(xspf_file):
Return a list of funkwhale tracks id from a xspf file. Tracks not found in database are ignored. Return a list of funkwhale tracks id from a xspf file. Tracks not found in database are ignored.
""" """
track_list = [] track_list = []
xspf_file_clean = clean_namespace_xspf(xspf_file) xspf_data_clean = clean_namespace_xspf(xspf_file)
tree = etree.parse(xspf_file_clean) tree = etree.fromstring(xspf_data_clean)
tracks = tree.findall(".//track") tracks = tree.findall(".//track")
total_track_count = 0 added_track_count = 0
added_file_count = 0
for track in tracks: for track in tracks:
track_id = "" track_id = ""
total_track_count = total_track_count + 1
total = str(total_track_count)
# Getting metadata of the xspf file # Getting metadata of the xspf file
try: try:
artist = track.find(".//creator").text artist = track.find(".//creator").text
except Exception as e:
logger.info("Error while parsing Xml file :%s" % e)
try:
title = track.find(".//title").text title = track.find(".//title").text
except Exception as e:
logger.info("Error while parsing Xml file :%s" % e)
try:
album = track.find(".//album").text album = track.find(".//album").text
except Exception as e: except Exception as e:
logger.info("Error while parsing Xml file :%s" % e) logger.info(f"Error while parsing Xml file : {e!r}")
# Finding track id in the db # Finding track id in the db
try: try:
artist_id = Artist.objects.get(name=artist) artist_id = Artist.objects.get(name=artist)
except Exception as e:
logger.info("Error while quering database : %s" % e)
try:
album_id = Album.objects.get(title=album) album_id = Album.objects.get(title=album)
except Exception as e: except Exception as e:
logger.info("Error while quering database :%s" % e) logger.info(f"Error while quering database : {e!r}")
try: try:
track_id = Track.objects.get( track_id = Track.objects.get(
title=title, artist=artist_id.id, album=album_id.id title=title, artist=artist_id.id, album=album_id.id
@ -83,16 +62,16 @@ def get_track_id_from_xspf(xspf_file):
try: try:
track_id = Track.objects.get(title=title, artist=artist_id.id) track_id = Track.objects.get(title=title, artist=artist_id.id)
except Exception as e: except Exception as e:
logger.info("Error while quering database :%s" % e) logger.info(f"Error while quering database : {e!r}")
if track_id: if track_id:
track_list.append(track_id.id) track_list.append(track_id.id)
added_file_count = added_file_count + 1 added_track_count = added_track_count + 1
logger.info( logger.info(
str(total) str(len(tracks))
+ " tracks where found in xspf file. " + " tracks where found in xspf file. "
+ str(added_file_count) + str(added_track_count)
+ "are gonna be added to playlist." + " are gonna be added to playlist."
) )
return track_list return track_list
@ -130,6 +109,35 @@ def generate_xspf_from_playlist(playlist_id):
return prettify(top) return prettify(top)
def generate_xspf_from_tracks_ids(tracks_ids):
"""
This returns a string containing playlist data in xspf format. It's used for test purposes.
"""
top = Element("playlist")
top.set("version", "1")
# top.append(Element.fromstring('version="1"'))
title_xspf = SubElement(top, "title")
title_xspf.text = "An automated generated playlist"
trackList_xspf = SubElement(top, "trackList")
for track_id in tracks_ids:
track = Track.objects.get(id=track_id)
track_xspf = SubElement(trackList_xspf, "track")
location_xspf = SubElement(track_xspf, "location")
location_xspf.text = "https://" + track.domain_name + track.listen_url
title_xspf = SubElement(track_xspf, "title")
title_xspf.text = str(track.title)
creator_xspf = SubElement(track_xspf, "creator")
creator_xspf.text = str(track.artist)
if str(track.album) == "[non-album tracks]":
continue
else:
album_xspf = SubElement(track_xspf, "album")
album_xspf.text = str(track.album)
return prettify(top)
def prettify(elem): def prettify(elem):
""" """
Return a pretty-printed XML string for the Element. Return a pretty-printed XML string for the Element.

View file

@ -0,0 +1,30 @@
import os
from defusedxml import ElementTree as etree
from funkwhale_api.playlists import models, utils
def test_get_track_id_from_xspf(factories, tmp_path):
track1 = factories["music.Track"]()
track2 = factories["music.Track"]()
tracks_ids = [track1.id, track2.id]
xspf_content = utils.generate_xspf_from_tracks_ids(tracks_ids)
f = open("test.xspf", "w")
f.write(xspf_content)
f.close()
xspf_file = "test.xspf"
expected = [track1.id, track2.id]
assert utils.get_track_id_from_xspf(xspf_file) == expected
os.remove("test.xspf")
def test_generate_xspf_from_playlist(factories):
playlist = factories["playlists.PlaylistTrack"]()
xspf_test = utils.generate_xspf_from_playlist(playlist.id)
tree = etree.fromstring(xspf_test)
playlist_factory = models.Playlist.objects.get()
track1 = playlist_factory.playlist_tracks.get(id=1)
track1_name = track1.track
assert playlist_factory.name == tree.findtext("./title")
assert track1_name.title == tree.findtext("./trackList/track/title")