diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 56b4cd86..f86c35c5 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -31,6 +31,7 @@ ReverseOneToOneDescriptor, ) from django.db.models.manager import BaseManager +from django.db.models.query import MAX_GET_RESULTS # type: ignore from django.db.models.query_utils import DeferredAttribute from strawberry import UNSET, relay from strawberry.annotation import StrawberryAnnotation @@ -288,6 +289,24 @@ def qs_hook(qs: models.QuerySet): # type: ignore def qs_hook(qs: models.QuerySet): qs = self.get_queryset(qs, info, **kwargs) + # Don't use qs.get() if the queryset is optimized by prefetching. + # Calling get in that case would disregard the prefetched results, because get implicitly + # adds a limit to the query + if (result_cache := qs._result_cache) is not None: # type: ignore + # mimic behavior of get() + # the queryset is already prefetched, no issue with just using len() + qs_len = len(result_cache) + if qs_len == 0: + raise qs.model.DoesNotExist( + f"{qs.model._meta.object_name} matching query does not exist." + ) + if qs_len != 1: + raise qs.model.MultipleObjectsReturned( + f"get() returned more than one {qs.model._meta.object_name} -- it returned " + f"{qs_len if qs_len < MAX_GET_RESULTS else f'more than {qs_len - 1}'}!" + ) + return result_cache[0] + return qs.get() return qs_hook diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 6d4f7a56..ae62d644 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -6,6 +6,7 @@ from django.db import DEFAULT_DB_ALIAS from django.db.models import Count, QuerySet, Window from django.db.models.functions import RowNumber +from django.db.models.query import MAX_GET_RESULTS # type: ignore from strawberry.types import Info from strawberry.types.arguments import StrawberryArgument from strawberry.types.unset import UNSET, UnsetType @@ -343,6 +344,20 @@ def get_queryset( if self.is_paginated: return queryset + # Add implicit pagination if this field is not a list + # that way when first() / get() is called on the QuerySet it does not cause extra queries + # and we don't prefetch more than necessary + if not pagination and not ( + self.is_list or self.is_paginated or self.is_connection + ): + if self.is_optional: + # first() applies order by pk if not ordered + if not queryset.ordered: + queryset = queryset.order_by("pk") + pagination = OffsetPaginationInput(offset=0, limit=1) + else: + pagination = OffsetPaginationInput(offset=0, limit=MAX_GET_RESULTS) + return self.apply_pagination( queryset, pagination, diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 669ab286..3bc46673 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -159,6 +159,8 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) + first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues") + first_issue_required: "IssueType" = strawberry_django.field(field_name="issues") issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.offset_paginated( field_name="issues", order=IssueOrder, diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 26039d01..eb6f1ad0 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -383,6 +383,8 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + firstIssue: IssueType + firstIssueRequired: IssueType! issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 2b11dd7d..b98209fb 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -173,6 +173,8 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + firstIssue: IssueType + firstIssueRequired: IssueType! issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -201,6 +203,8 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + firstIssue: IssueType + firstIssueRequired: IssueType! issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index f4ff56a1..c2274250 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -6,12 +6,12 @@ from django.db.models import Prefetch from django.utils import timezone from pytest_mock import MockerFixture -from strawberry.relay import to_base64 -from strawberry.types import ExecutionResult +from strawberry.relay import GlobalID, to_base64 +from strawberry.types import ExecutionResult, get_object_definition import strawberry_django from strawberry_django.optimizer import DjangoOptimizerExtension -from tests.projects.schema import StaffType +from tests.projects.schema import IssueType, MilestoneType, StaffType from . import utils from .projects.faker import ( @@ -1605,3 +1605,171 @@ def test_query_paginated_nested(db, gql_client: GraphQLTestClient): }, ] } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_optional(db, gql_client: GraphQLTestClient): + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue = IssueFactory.create(name="Foo", milestone=milestone1) + issue_id = str( + GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id)) + ) + + milestone_id_1 = str( + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone1.id) + ) + ) + milestone_id_2 = str( + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone2.id) + ) + ) + + query = """\ + query TestQuery($id1: GlobalID!, $id2: GlobalID!) { + a: milestone(id: $id1) { + firstIssue { + id + } + } + b: milestone(id: $id2) { + firstIssue { + id + } + } + } + """ + + with assert_num_queries(4): + res = gql_client.query( + query, variables={"id1": milestone_id_1, "id2": milestone_id_2} + ) + + assert res.errors is None + assert res.data == { + "a": { + "firstIssue": { + "id": issue_id, + }, + }, + "b": { + "firstIssue": None, + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required(db, gql_client: GraphQLTestClient): + milestone = MilestoneFactory.create() + + issue = IssueFactory.create(name="Foo", milestone=milestone) + issue_id = str( + GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id)) + ) + + milestone_id = str( + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone.id) + ) + ) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query(query, variables={"id": milestone_id}) + + assert res.errors is None + assert res.data == { + "milestone": { + "firstIssueRequired": { + "id": issue_id, + }, + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required_missing( + db, gql_client: GraphQLTestClient +): + milestone1 = MilestoneFactory.create() + + milestone_id = str( + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone1.id) + ) + ) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query( + query, variables={"id": milestone_id}, assert_no_errors=False + ) + + assert res.errors is not None + assert res.errors == [ + { + "locations": [{"column": 11, "line": 3}], + "message": "Issue matching query does not exist.", + "path": ["milestone", "firstIssueRequired"], + } + ] + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required_multiple_returned( + db, gql_client: GraphQLTestClient +): + milestone = MilestoneFactory.create() + + milestone_id = str( + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone.id) + ) + ) + IssueFactory.create(name="Foo", milestone=milestone) + IssueFactory.create(name="Bar", milestone=milestone) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query( + query, variables={"id": milestone_id}, assert_no_errors=False + ) + + assert res.errors is not None + assert res.errors == [ + { + "locations": [{"column": 11, "line": 3}], + "message": "get() returned more than one Issue -- it returned 2!", + "path": ["milestone", "firstIssueRequired"], + } + ]