diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index b62fcd24..8d818e77 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -1421,6 +1421,11 @@ def __init__( self.enable_nested_relations_prefetch = enable_nested_relations_prefetch self.prefetch_custom_queryset = prefetch_custom_queryset + if enable_nested_relations_prefetch: + from strawberry_django.utils.patches import apply_pagination_fix + + apply_pagination_fix() + def on_execute(self) -> Generator[None]: token = optimizer.set(self) try: diff --git a/strawberry_django/utils/patches.py b/strawberry_django/utils/patches.py index e69de29b..345e940c 100644 --- a/strawberry_django/utils/patches.py +++ b/strawberry_django/utils/patches.py @@ -0,0 +1,78 @@ +import django +from django.db import ( + DEFAULT_DB_ALIAS, + NotSupportedError, + connections, +) +from django.db.models import Q, Window +from django.db.models.fields import related_descriptors +from django.db.models.functions import RowNumber +from django.db.models.lookups import GreaterThan, LessThanOrEqual +from django.db.models.sql import Query +from django.db.models.sql.constants import INNER +from django.db.models.sql.where import AND + + +def apply_pagination_fix(): + """Apply pagination fix for Django 5.1 or older. + + This is based on the fix in this patch, which is going to be included in Django 5.2: + https://code.djangoproject.com/ticket/35677#comment:9 + + If can safely be removed when Django 5.2 is the minimum version we support + """ + if django.VERSION >= (5, 2): + return + + # This is a copy of the function, exactly as it exists on Django 4.2, 5.0 and 5.1 + # (there are no differences in this function between these versions) + def _filter_prefetch_queryset(queryset, field_name, instances): + predicate = Q(**{f"{field_name}__in": instances}) + db = queryset._db or DEFAULT_DB_ALIAS + if queryset.query.is_sliced: + if not connections[db].features.supports_over_clause: + raise NotSupportedError( + "Prefetching from a limited queryset is only supported on backends " + "that support window functions." + ) + low_mark, high_mark = queryset.query.low_mark, queryset.query.high_mark + order_by = [ + expr for expr, _ in queryset.query.get_compiler(using=db).get_order_by() + ] + window = Window(RowNumber(), partition_by=field_name, order_by=order_by) + predicate &= GreaterThan(window, low_mark) + if high_mark is not None: + predicate &= LessThanOrEqual(window, high_mark) + queryset.query.clear_limits() + + # >> ORIGINAL CODE + # return queryset.filter(predicate) # noqa: ERA001 + # << ORIGINAL CODE + # >> PATCHED CODE + queryset.query.add_q(predicate, reuse_all_aliases=True) + return queryset + # << PATCHED CODE + + related_descriptors._filter_prefetch_queryset = _filter_prefetch_queryset # type: ignore + + # This is a copy of the function, exactly as it exists on Django 4.2, 5.0 and 5.1 + # (there are no differences in this function between these versions) + def add_q(self, q_object, reuse_all_aliases=False): + existing_inner = { + a for a in self.alias_map if self.alias_map[a].join_type == INNER + } + # >> ORIGINAL CODE + # clause, _ = self._add_q(q_object, self.used_aliases) # noqa: ERA001 + # << ORIGINAL CODE + # >> PATCHED CODE + if reuse_all_aliases: # noqa: SIM108 + can_reuse = set(self.alias_map) + else: + can_reuse = self.used_aliases + clause, _ = self._add_q(q_object, can_reuse) + # << PATCHED CODE + if clause: + self.where.add(clause, AND) + self.demote_joins(existing_inner) + + Query.add_q = add_q diff --git a/tests/projects/models.py b/tests/projects/models.py index 7ecf82d9..205a3b8c 100644 --- a/tests/projects/models.py +++ b/tests/projects/models.py @@ -115,8 +115,8 @@ class Meta: class Issue(NamedModel): - comments: "RelatedManager[Issue]" - issue_assignees: "RelatedManager[Assignee]" + class Meta: # type: ignore + ordering = ("id",) class Kind(models.TextChoices): """Issue kind options.""" @@ -124,6 +124,9 @@ class Kind(models.TextChoices): BUG = "b", "Bug" FEATURE = "f", "Feature" + comments: "RelatedManager[Issue]" + issue_assignees: "RelatedManager[Assignee]" + id = models.BigAutoField( verbose_name="ID", primary_key=True, @@ -203,6 +206,9 @@ class Meta: class Tag(NamedModel): + class Meta: # type: ignore + ordering = ("id",) + issues: "RelatedManager[Issue]" id = models.BigAutoField( diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 2d245d85..5efef3d1 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -7,13 +7,19 @@ from strawberry.types import ExecutionResult import strawberry_django +from strawberry_django.optimizer import DjangoOptimizerExtension from strawberry_django.pagination import ( OffsetPaginationInput, apply, apply_window_pagination, ) from tests import models, utils -from tests.projects.faker import MilestoneFactory, ProjectFactory +from tests.projects.faker import ( + IssueFactory, + MilestoneFactory, + ProjectFactory, + TagFactory, +) @strawberry_django.type(models.Fruit, pagination=True) @@ -145,3 +151,73 @@ def test_apply_window_pagination_with_no_limites(limit): assert first_fruit.name == "fruit2" assert first_fruit._strawberry_row_number == 3 # type: ignore assert first_fruit._strawberry_total_count == 10 # type: ignore + + +@pytest.mark.django_db(transaction=True) +def test_nested_pagination_m2m(gql_client: utils.GraphQLTestClient): + # Create 2 tags and 3 issues + tags = [TagFactory(name=f"Tag {i + 1}") for i in range(2)] + issues = [IssueFactory(name=f"Issue {i + 1}") for i in range(3)] + # Assign issues 1 and 2 to the 1st tag + # Assign issues 2 and 3 to the 2nd tag + # This means that both tags share the 2nd issue + tags[0].issues.set(issues[:2]) + tags[1].issues.set(issues[1:]) + # Query the tags with their issues + # We expect only 2 database queries if the optimizer is enabled, otherwise 3 (N+1) + with utils.assert_num_queries(3 if DjangoOptimizerExtension.enabled.get() else 6): + result = gql_client.query( + """ + query { + tagConn { + totalCount + edges { + node { + name + issues { + totalCount + edges { + node { + name + } + } + } + } + } + } + } + """ + ) + # Check the results + assert not result.errors + assert result.data == { + "tagConn": { + "totalCount": 2, + "edges": [ + { + "node": { + "name": "Tag 1", + "issues": { + "totalCount": 2, + "edges": [ + {"node": {"name": "Issue 1"}}, + {"node": {"name": "Issue 2"}}, + ], + }, + } + }, + { + "node": { + "name": "Tag 2", + "issues": { + "totalCount": 2, + "edges": [ + {"node": {"name": "Issue 2"}}, + {"node": {"name": "Issue 3"}}, + ], + }, + } + }, + ], + } + }