From 4a85e1431fead3a8f334869672df7e051ab044c7 Mon Sep 17 00:00:00 2001 From: Dhruv Bhanushali Date: Fri, 20 Dec 2024 13:54:58 +0400 Subject: [PATCH] Fetch all add-on instances in one DB request (#5289) --- api/api/models/audio.py | 13 ------ api/api/serializers/audio_serializers.py | 6 +-- api/api/views/audio_views.py | 5 +++ api/api/views/media_views.py | 42 ++++++++++++++++--- api/test/integration/test_dead_link_filter.py | 8 +++- api/test/unit/models/test_audio.py | 25 ----------- api/test/unit/views/test_audio_views.py | 39 +++++++++++++++++ 7 files changed, 89 insertions(+), 49 deletions(-) create mode 100644 api/test/unit/views/test_audio_views.py diff --git a/api/api/models/audio.py b/api/api/models/audio.py index 2a0e9621a16..02b82263de2 100644 --- a/api/api/models/audio.py +++ b/api/api/models/audio.py @@ -224,19 +224,6 @@ def duration_in_s(self): def audio_set(self): return getattr(self, "audioset") - def get_waveform(self) -> list[float]: - """ - Get the waveform if it exists. Return a blank list otherwise. - - :return: the waveform, if it exists; empty list otherwise - """ - - try: - add_on = AudioAddOn.objects.get(audio_identifier=self.identifier) - return add_on.waveform_peaks or [] - except AudioAddOn.DoesNotExist: - return [] - def get_or_create_waveform(self): add_on, _ = AudioAddOn.objects.get_or_create(audio_identifier=self.identifier) diff --git a/api/api/serializers/audio_serializers.py b/api/api/serializers/audio_serializers.py index b65f56c89ce..18b552ccc37 100644 --- a/api/api/serializers/audio_serializers.py +++ b/api/api/serializers/audio_serializers.py @@ -179,9 +179,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def get_peaks(self, obj) -> list[int]: - if isinstance(obj, Hit): - obj = Audio.objects.get(identifier=obj.identifier) - return obj.get_waveform() + audio_addon = self.context.get("addons", {}).get(obj.identifier) + if audio_addon: + return audio_addon.waveform_peaks def to_representation(self, instance): # Get the original representation diff --git a/api/api/views/audio_views.py b/api/api/views/audio_views.py index 0adb95f1d89..fe795f61908 100644 --- a/api/api/views/audio_views.py +++ b/api/api/views/audio_views.py @@ -16,6 +16,7 @@ ) from api.docs.audio_docs import thumbnail as thumbnail_docs from api.models import Audio +from api.models.audio import AudioAddOn from api.serializers.audio_serializers import ( AudioReportRequestSerializer, AudioSearchRequestSerializer, @@ -38,6 +39,7 @@ class AudioViewSet(MediaViewSet): """Viewset for all endpoints pertaining to audio.""" model_class = Audio + addon_model_class = AudioAddOn media_type = AUDIO_TYPE query_serializer_class = AudioSearchRequestSerializer default_index = settings.MEDIA_INDEX_MAPPING[AUDIO_TYPE] @@ -47,6 +49,9 @@ class AudioViewSet(MediaViewSet): def get_queryset(self): return super().get_queryset().select_related("sensitive_audio", "audioset") + def include_addons(self, serializer): + return serializer.validated_data.get("peaks") + # Extra actions async def get_image_proxy_media_info(self) -> image_proxy.MediaInfo: diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index 50e651ee042..d49f386e206 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -14,6 +14,7 @@ from api.controllers import search_controller from api.controllers.elasticsearch.related import related_media from api.models import ContentSource +from api.models.base import OpenLedgerModel from api.models.media import AbstractMedia from api.serializers import media_serializers from api.serializers.source_serializers import SourceSerializer @@ -51,6 +52,7 @@ class MediaViewSet(AsyncViewSetMixin, AsyncAPIView, ReadOnlyModelViewSet): # Populate these in the corresponding subclass model_class: type[AbstractMedia] = None + addon_model_class: type[OpenLedgerModel] = None media_type: MediaType | None = None query_serializer_class = None default_index = None @@ -97,7 +99,11 @@ def _get_request_serializer(self, request): req_serializer.is_valid(raise_exception=True) return req_serializer - def get_db_results(self, results): + def get_db_results( + self, + results, + include_addons=False, + ) -> tuple[list[AbstractMedia], list[OpenLedgerModel]]: """ Map ES hits to ORM model instances. @@ -107,6 +113,7 @@ def get_db_results(self, results): which is both unique and indexed, so it's quite performant. :param results: the list of ES hits + :param include_addons: whether to include add-ons with results :return: the corresponding list of ORM model instances """ @@ -121,7 +128,12 @@ def get_db_results(self, results): for result, hit in zip(results, hits): result.fields_matched = getattr(hit.meta, "highlight", None) - return results + if include_addons and self.addon_model_class: + addons = list(self.addon_model_class.objects.filter(pk__in=identifiers)) + else: + addons = [] + + return (results, addons) # Standard actions @@ -147,6 +159,20 @@ def _validate_source(self, source): detail=f"Invalid source '{source}'. Valid sources are: {valid_string}.", ) + def include_addons(self, serializer): + """ + Whether to include objects of the addon model when mapping hits to + objects of the media model. + + If the media type has an addon model, this method should be overridden + in the subclass to return ``True`` based on serializer input. + + :param serializer: the validated serializer instance + :return: whether to include addon model objects + """ + + return False + def get_media_results( self, request, @@ -188,9 +214,13 @@ def get_media_results( except ValueError as e: raise APIException(getattr(e, "message", str(e))) - serializer_context = search_context | self.get_serializer_context() - - results = self.get_db_results(results) + include_addons = self.include_addons(params) + results, addons = self.get_db_results(results, include_addons) + serializer_context = ( + search_context + | self.get_serializer_context() + | {"addons": {addon.audio_identifier: addon for addon in addons}} + ) serializer = self.get_serializer(results, many=True, context=serializer_context) return self.get_paginated_response(serializer.data) @@ -231,7 +261,7 @@ def related(self, request, identifier=None, *_, **__): serializer_context = self.get_serializer_context() - results = self.get_db_results(results) + results, _ = self.get_db_results(results) serializer = self.get_serializer(results, many=True, context=serializer_context) return self.get_paginated_response(serializer.data) diff --git a/api/test/integration/test_dead_link_filter.py b/api/test/integration/test_dead_link_filter.py index ec7a36efa9c..cd876a0a531 100644 --- a/api/test/integration/test_dead_link_filter.py +++ b/api/test/integration/test_dead_link_filter.py @@ -31,6 +31,10 @@ def get_empty_cached_statuses(_, image_urls): _MAKE_HEAD_REQUESTS_MODULE_PATH = "api.utils.check_dead_links._make_head_requests" +def _mock_get_db_results(results, include_addons=False): + return (results, []) + + def _patch_make_head_requests(): def _make_head_requests(urls, *args, **kwargs): responses = [] @@ -67,7 +71,7 @@ def test_dead_link_filtering(mocked_map, api_client): with patch( "api.views.image_views.ImageViewSet.get_db_results" ) as mock_get_db_result: - mock_get_db_result.side_effect = lambda value: value + mock_get_db_result.side_effect = _mock_get_db_results res_with_dead_links = api_client.get( path, query_params | {"filter_dead": False}, @@ -121,7 +125,7 @@ def test_dead_link_filtering_all_dead_links( with patch( "api.views.image_views.ImageViewSet.get_db_results" ) as mock_get_db_result: - mock_get_db_result.side_effect = lambda value: value + mock_get_db_result.side_effect = _mock_get_db_results with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO): response = api_client.get( path, diff --git a/api/test/unit/models/test_audio.py b/api/test/unit/models/test_audio.py index 3372c6b0e98..80f30bf1b9d 100644 --- a/api/test/unit/models/test_audio.py +++ b/api/test/unit/models/test_audio.py @@ -41,28 +41,3 @@ def test_audio_waveform_caches(generate_peaks_mock, audio_fixture): audio_fixture.delete() assert AudioAddOn.objects.count() == 1 - - -@pytest.mark.django_db -@mock.patch("api.models.audio.AudioAddOn.objects.get") -def test_audio_waveform_sent_when_present(get_mock, audio_fixture): - # When ``AudioAddOn.waveform_peaks`` exists, waveform is filled - peaks = [0, 0.25, 0.5, 0.25, 0.1] - get_mock.return_value = mock.Mock(waveform_peaks=peaks) - assert audio_fixture.get_waveform() == peaks - - -@pytest.mark.django_db -@mock.patch("api.models.audio.AudioAddOn.objects.get") -def test_audio_waveform_blank_when_absent(get_mock, audio_fixture): - # When ``AudioAddOn`` does not exist, waveform is blank - get_mock.side_effect = AudioAddOn.DoesNotExist() - assert audio_fixture.get_waveform() == [] - - -@pytest.mark.django_db -@mock.patch("api.models.audio.AudioAddOn.objects.get") -def test_audio_waveform_blank_when_none(get_mock, audio_fixture): - # When ``AudioAddOn.waveform_peaks`` is None, waveform is blank - get_mock.return_value = mock.Mock(waveform_peaks=None) - assert audio_fixture.get_waveform() == [] diff --git a/api/test/unit/views/test_audio_views.py b/api/test/unit/views/test_audio_views.py new file mode 100644 index 00000000000..ef689b30dc7 --- /dev/null +++ b/api/test/unit/views/test_audio_views.py @@ -0,0 +1,39 @@ +from unittest.mock import MagicMock, patch + +import pytest +import pytest_django.asserts + +from test.factory.models import AudioFactory + + +@pytest.mark.parametrize("peaks, query_count", [(True, 2), (False, 1)]) +@pytest.mark.django_db +def test_peaks_param_determines_addons(api_client, peaks, query_count): + num_results = 20 + + # Since controller returns a list of ``Hit``s, not model instances, we must + # set the ``meta`` param on each of them to match the shape of ``Hit``. + results = AudioFactory.create_batch(size=num_results) + for result in results: + result.meta = None + + controller_ret = ( + results, + 1, # num_pages + num_results, + {}, # search_context + ) + with ( + patch( + "api.views.media_views.search_controller", + query_media=MagicMock(return_value=controller_ret), + ), + patch( + "api.serializers.media_serializers.search_controller", + get_sources=MagicMock(return_value={}), + ), + pytest_django.asserts.assertNumQueries(query_count), + ): + res = api_client.get(f"/v1/audio/?peaks={peaks}") + + assert res.status_code == 200