From e6cae73e32375da2e8e97a49083ad31471a70603 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 15 Dec 2024 17:31:58 +0100 Subject: [PATCH 01/12] fix(optimizer): Prevent issuing duplicated queries for certain uses of first() and get() --- strawberry_django/fields/field.py | 23 +++++ tests/projects/schema.py | 2 + tests/test_optimizer.py | 149 +++++++++++++++++++++++++++++- 3 files changed, 171 insertions(+), 3 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 56b4cd86..ba3c9a87 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -282,12 +282,35 @@ def qs_hook(qs: models.QuerySet): # type: ignore def qs_hook(qs: models.QuerySet): # type: ignore qs = self.get_queryset(qs, info, **kwargs) + # Don't use qs.first() if the queryset is optimized by prefetching. + # Calling first in that case would disregard the prefetched results, because first implicitly + # adds a limit to the query + if is_optimized_by_prefetching(qs): + return next(iter(qs), None) return qs.first() else: def qs_hook(qs: models.QuerySet): qs = self.get_queryset(qs, info, **kwargs) + # See comment above about qs.first(), the same applies for get() + if is_optimized_by_prefetching(qs): + # mimic behavior of get() + qs_len = len(qs) # the queryset is already prefetched, no issue with just using len() + if qs_len == 0: + raise qs.model.DoesNotExist( + "%s matching query does not exist." % qs.model._meta.object_name + ) + elif qs_len != 1: + raise qs.model.MultipleObjectsReturned( + "get() returned more than one %s -- it returned %s!" + % ( + qs.model._meta.object_name, + qs_len, + ) + ) + else: + return qs[0] return qs.get() return qs_hook diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 669ab286..b44de907 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -169,6 +169,8 @@ class MilestoneType(relay.Node, Named): filters=IssueFilter, ) ) + first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues") + first_issue_required: "IssueType" = strawberry_django.field(field_name="issues") @strawberry_django.field( prefetch_related=[ diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index f4ff56a1..1142030b 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -6,12 +6,12 @@ from django.db.models import Prefetch from django.utils import timezone from pytest_mock import MockerFixture -from strawberry.relay import to_base64 -from strawberry.types import ExecutionResult +from strawberry.relay import to_base64, GlobalID +from strawberry.types import ExecutionResult, get_object_definition import strawberry_django from strawberry_django.optimizer import DjangoOptimizerExtension -from tests.projects.schema import StaffType +from tests.projects.schema import StaffType, MilestoneType, IssueType from . import utils from .projects.faker import ( @@ -1605,3 +1605,146 @@ def test_query_paginated_nested(db, gql_client: GraphQLTestClient): }, ] } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_optional( + db, gql_client: GraphQLTestClient +): + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue = IssueFactory.create(name="Foo", milestone=milestone1) + issue_id = str(GlobalID(get_object_definition(IssueType).name, str(issue.id))) + + milestone_id_1 = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id))) + milestone_id_2 = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone2.id))) + + query = """\ + query TestQuery($id1: GlobalID!, $id2: GlobalID!) { + a: milestone(id: $id1) { + firstIssue { + id + } + } + b: milestone(id: $id2) { + firstIssue { + id + } + } + } + """ + + with assert_num_queries(4): + res = gql_client.query(query, variables={"id1": milestone_id_1, "id2": milestone_id_2}) + + assert res.errors is None + assert res.data == { + "a": { + "firstIssue": { + "id": issue_id, + }, + }, + "b": { + "firstIssue": None, + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required( + db, gql_client: GraphQLTestClient +): + milestone = MilestoneFactory.create() + + issue = IssueFactory.create(name="Foo", milestone=milestone) + issue_id = str(GlobalID(get_object_definition(IssueType).name, str(issue.id))) + + milestone_id = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone.id))) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query(query, variables={"id": milestone_id}) + + assert res.errors is None + assert res.data == { + "milestone": { + "firstIssueRequired": { + "id": issue_id, + }, + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required_missing( + db, gql_client: GraphQLTestClient +): + milestone1 = MilestoneFactory.create() + + milestone_id = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id))) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query(query, variables={"id": milestone_id}, assert_no_errors=False) + + assert res.errors is not None + assert res.errors == [ + { + "locations": [{"column": 11, "line": 3}], + "message": "Issue matching query does not exist.", + "path": ["milestone", "firstIssueRequired"], + } + ] + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required_multiple_returned( + db, gql_client: GraphQLTestClient +): + milestone = MilestoneFactory.create() + + milestone_id = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone.id))) + IssueFactory.create(name="Foo", milestone=milestone) + IssueFactory.create(name="Bar", milestone=milestone) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query(query, variables={"id": milestone_id}, assert_no_errors=False) + + assert res.errors is not None + assert res.errors == [ + { + "locations": [{"column": 11, "line": 3}], + "message": "get() returned more than one Issue -- it returned 2!", + "path": ["milestone", "firstIssueRequired"], + } + ] From 04b7744a1c68b439642bdda16116942d1df9ace1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Dec 2024 16:39:10 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry_django/fields/field.py | 12 ++++---- tests/test_optimizer.py | 49 +++++++++++++++++++------------ 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index ba3c9a87..6faebbb7 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -296,12 +296,15 @@ def qs_hook(qs: models.QuerySet): # See comment above about qs.first(), the same applies for get() if is_optimized_by_prefetching(qs): # mimic behavior of get() - qs_len = len(qs) # the queryset is already prefetched, no issue with just using len() + qs_len = len( + qs + ) # the queryset is already prefetched, no issue with just using len() if qs_len == 0: raise qs.model.DoesNotExist( - "%s matching query does not exist." % qs.model._meta.object_name + "%s matching query does not exist." + % qs.model._meta.object_name ) - elif qs_len != 1: + if qs_len != 1: raise qs.model.MultipleObjectsReturned( "get() returned more than one %s -- it returned %s!" % ( @@ -309,8 +312,7 @@ def qs_hook(qs: models.QuerySet): qs_len, ) ) - else: - return qs[0] + return qs[0] return qs.get() return qs_hook diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 1142030b..1d525952 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -6,12 +6,12 @@ from django.db.models import Prefetch from django.utils import timezone from pytest_mock import MockerFixture -from strawberry.relay import to_base64, GlobalID +from strawberry.relay import GlobalID, to_base64 from strawberry.types import ExecutionResult, get_object_definition import strawberry_django from strawberry_django.optimizer import DjangoOptimizerExtension -from tests.projects.schema import StaffType, MilestoneType, IssueType +from tests.projects.schema import IssueType, MilestoneType, StaffType from . import utils from .projects.faker import ( @@ -1608,18 +1608,19 @@ def test_query_paginated_nested(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) -def test_prefetch_multi_field_single_optional( - db, gql_client: GraphQLTestClient -): - +def test_prefetch_multi_field_single_optional(db, gql_client: GraphQLTestClient): milestone1 = MilestoneFactory.create() milestone2 = MilestoneFactory.create() issue = IssueFactory.create(name="Foo", milestone=milestone1) issue_id = str(GlobalID(get_object_definition(IssueType).name, str(issue.id))) - milestone_id_1 = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id))) - milestone_id_2 = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone2.id))) + milestone_id_1 = str( + GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id)) + ) + milestone_id_2 = str( + GlobalID(get_object_definition(MilestoneType).name, str(milestone2.id)) + ) query = """\ query TestQuery($id1: GlobalID!, $id2: GlobalID!) { @@ -1637,7 +1638,9 @@ def test_prefetch_multi_field_single_optional( """ with assert_num_queries(4): - res = gql_client.query(query, variables={"id1": milestone_id_1, "id2": milestone_id_2}) + res = gql_client.query( + query, variables={"id1": milestone_id_1, "id2": milestone_id_2} + ) assert res.errors is None assert res.data == { @@ -1653,15 +1656,15 @@ def test_prefetch_multi_field_single_optional( @pytest.mark.django_db(transaction=True) -def test_prefetch_multi_field_single_required( - db, gql_client: GraphQLTestClient -): +def test_prefetch_multi_field_single_required(db, gql_client: GraphQLTestClient): milestone = MilestoneFactory.create() issue = IssueFactory.create(name="Foo", milestone=milestone) issue_id = str(GlobalID(get_object_definition(IssueType).name, str(issue.id))) - milestone_id = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone.id))) + milestone_id = str( + GlobalID(get_object_definition(MilestoneType).name, str(milestone.id)) + ) query = """\ query TestQuery($id: GlobalID!) { @@ -1688,11 +1691,13 @@ def test_prefetch_multi_field_single_required( @pytest.mark.django_db(transaction=True) def test_prefetch_multi_field_single_required_missing( - db, gql_client: GraphQLTestClient + db, gql_client: GraphQLTestClient ): milestone1 = MilestoneFactory.create() - milestone_id = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id))) + milestone_id = str( + GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id)) + ) query = """\ query TestQuery($id: GlobalID!) { @@ -1705,7 +1710,9 @@ def test_prefetch_multi_field_single_required_missing( """ with assert_num_queries(2): - res = gql_client.query(query, variables={"id": milestone_id}, assert_no_errors=False) + res = gql_client.query( + query, variables={"id": milestone_id}, assert_no_errors=False + ) assert res.errors is not None assert res.errors == [ @@ -1719,11 +1726,13 @@ def test_prefetch_multi_field_single_required_missing( @pytest.mark.django_db(transaction=True) def test_prefetch_multi_field_single_required_multiple_returned( - db, gql_client: GraphQLTestClient + db, gql_client: GraphQLTestClient ): milestone = MilestoneFactory.create() - milestone_id = str(GlobalID(get_object_definition(MilestoneType).name, str(milestone.id))) + milestone_id = str( + GlobalID(get_object_definition(MilestoneType).name, str(milestone.id)) + ) IssueFactory.create(name="Foo", milestone=milestone) IssueFactory.create(name="Bar", milestone=milestone) @@ -1738,7 +1747,9 @@ def test_prefetch_multi_field_single_required_multiple_returned( """ with assert_num_queries(2): - res = gql_client.query(query, variables={"id": milestone_id}, assert_no_errors=False) + res = gql_client.query( + query, variables={"id": milestone_id}, assert_no_errors=False + ) assert res.errors is not None assert res.errors == [ From 7cc709fb517597177607045282c2d94f16764a7b Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 15 Dec 2024 17:31:58 +0100 Subject: [PATCH 03/12] fix(optimizer): Prevent issuing duplicated queries for certain uses of first() and get() --- strawberry_django/fields/field.py | 23 +++++ tests/projects/schema.py | 2 + tests/test_optimizer.py | 149 +++++++++++++++++++++++++++++- 3 files changed, 171 insertions(+), 3 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 56b4cd86..ba3c9a87 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -282,12 +282,35 @@ def qs_hook(qs: models.QuerySet): # type: ignore def qs_hook(qs: models.QuerySet): # type: ignore qs = self.get_queryset(qs, info, **kwargs) + # Don't use qs.first() if the queryset is optimized by prefetching. + # Calling first in that case would disregard the prefetched results, because first implicitly + # adds a limit to the query + if is_optimized_by_prefetching(qs): + return next(iter(qs), None) return qs.first() else: def qs_hook(qs: models.QuerySet): qs = self.get_queryset(qs, info, **kwargs) + # See comment above about qs.first(), the same applies for get() + if is_optimized_by_prefetching(qs): + # mimic behavior of get() + qs_len = len(qs) # the queryset is already prefetched, no issue with just using len() + if qs_len == 0: + raise qs.model.DoesNotExist( + "%s matching query does not exist." % qs.model._meta.object_name + ) + elif qs_len != 1: + raise qs.model.MultipleObjectsReturned( + "get() returned more than one %s -- it returned %s!" + % ( + qs.model._meta.object_name, + qs_len, + ) + ) + else: + return qs[0] return qs.get() return qs_hook diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 669ab286..b44de907 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -169,6 +169,8 @@ class MilestoneType(relay.Node, Named): filters=IssueFilter, ) ) + first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues") + first_issue_required: "IssueType" = strawberry_django.field(field_name="issues") @strawberry_django.field( prefetch_related=[ diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index f4ff56a1..21d6c770 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -6,12 +6,12 @@ from django.db.models import Prefetch from django.utils import timezone from pytest_mock import MockerFixture -from strawberry.relay import to_base64 -from strawberry.types import ExecutionResult +from strawberry.relay import to_base64, GlobalID +from strawberry.types import ExecutionResult, get_object_definition import strawberry_django from strawberry_django.optimizer import DjangoOptimizerExtension -from tests.projects.schema import StaffType +from tests.projects.schema import StaffType, MilestoneType, IssueType from . import utils from .projects.faker import ( @@ -1605,3 +1605,146 @@ def test_query_paginated_nested(db, gql_client: GraphQLTestClient): }, ] } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_optional( + db, gql_client: GraphQLTestClient +): + + milestone1 = MilestoneFactory.create() + milestone2 = MilestoneFactory.create() + + issue = IssueFactory.create(name="Foo", milestone=milestone1) + issue_id = str(GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))) + + milestone_id_1 = str(GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone1.id))) + milestone_id_2 = str(GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone2.id))) + + query = """\ + query TestQuery($id1: GlobalID!, $id2: GlobalID!) { + a: milestone(id: $id1) { + firstIssue { + id + } + } + b: milestone(id: $id2) { + firstIssue { + id + } + } + } + """ + + with assert_num_queries(4): + res = gql_client.query(query, variables={"id1": milestone_id_1, "id2": milestone_id_2}) + + assert res.errors is None + assert res.data == { + "a": { + "firstIssue": { + "id": issue_id, + }, + }, + "b": { + "firstIssue": None, + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required( + db, gql_client: GraphQLTestClient +): + milestone = MilestoneFactory.create() + + issue = IssueFactory.create(name="Foo", milestone=milestone) + issue_id = str(GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))) + + milestone_id = str(GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone.id))) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query(query, variables={"id": milestone_id}) + + assert res.errors is None + assert res.data == { + "milestone": { + "firstIssueRequired": { + "id": issue_id, + }, + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required_missing( + db, gql_client: GraphQLTestClient +): + milestone1 = MilestoneFactory.create() + + milestone_id = str(GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone1.id))) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query(query, variables={"id": milestone_id}, assert_no_errors=False) + + assert res.errors is not None + assert res.errors == [ + { + "locations": [{"column": 11, "line": 3}], + "message": "Issue matching query does not exist.", + "path": ["milestone", "firstIssueRequired"], + } + ] + + +@pytest.mark.django_db(transaction=True) +def test_prefetch_multi_field_single_required_multiple_returned( + db, gql_client: GraphQLTestClient +): + milestone = MilestoneFactory.create() + + milestone_id = str(GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone.id))) + IssueFactory.create(name="Foo", milestone=milestone) + IssueFactory.create(name="Bar", milestone=milestone) + + query = """\ + query TestQuery($id: GlobalID!) { + milestone(id: $id) { + firstIssueRequired { + id + } + } + } + """ + + with assert_num_queries(2): + res = gql_client.query(query, variables={"id": milestone_id}, assert_no_errors=False) + + assert res.errors is not None + assert res.errors == [ + { + "locations": [{"column": 11, "line": 3}], + "message": "get() returned more than one Issue -- it returned 2!", + "path": ["milestone", "firstIssueRequired"], + } + ] From ed74cf57ddc856666ffa79f4181b3431fdf0df86 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 15 Dec 2024 17:45:04 +0100 Subject: [PATCH 04/12] fix typing issue --- tests/test_optimizer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 1d525952..dc90ffab 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1613,13 +1613,13 @@ def test_prefetch_multi_field_single_optional(db, gql_client: GraphQLTestClient) milestone2 = MilestoneFactory.create() issue = IssueFactory.create(name="Foo", milestone=milestone1) - issue_id = str(GlobalID(get_object_definition(IssueType).name, str(issue.id))) + issue_id = str(GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))) milestone_id_1 = str( - GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id)) + GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone1.id)) ) milestone_id_2 = str( - GlobalID(get_object_definition(MilestoneType).name, str(milestone2.id)) + GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone2.id)) ) query = """\ @@ -1660,10 +1660,10 @@ def test_prefetch_multi_field_single_required(db, gql_client: GraphQLTestClient) milestone = MilestoneFactory.create() issue = IssueFactory.create(name="Foo", milestone=milestone) - issue_id = str(GlobalID(get_object_definition(IssueType).name, str(issue.id))) + issue_id = str(GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))) milestone_id = str( - GlobalID(get_object_definition(MilestoneType).name, str(milestone.id)) + GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone.id)) ) query = """\ @@ -1696,7 +1696,7 @@ def test_prefetch_multi_field_single_required_missing( milestone1 = MilestoneFactory.create() milestone_id = str( - GlobalID(get_object_definition(MilestoneType).name, str(milestone1.id)) + GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone1.id)) ) query = """\ @@ -1731,7 +1731,7 @@ def test_prefetch_multi_field_single_required_multiple_returned( milestone = MilestoneFactory.create() milestone_id = str( - GlobalID(get_object_definition(MilestoneType).name, str(milestone.id)) + GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone.id)) ) IssueFactory.create(name="Foo", milestone=milestone) IssueFactory.create(name="Bar", milestone=milestone) From 27c92b88ff537ac5a4f3e17177413d3c8bc41314 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Dec 2024 16:45:15 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_optimizer.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index dc90ffab..c2274250 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1613,13 +1613,19 @@ def test_prefetch_multi_field_single_optional(db, gql_client: GraphQLTestClient) milestone2 = MilestoneFactory.create() issue = IssueFactory.create(name="Foo", milestone=milestone1) - issue_id = str(GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))) + issue_id = str( + GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id)) + ) milestone_id_1 = str( - GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone1.id)) + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone1.id) + ) ) milestone_id_2 = str( - GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone2.id)) + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone2.id) + ) ) query = """\ @@ -1660,10 +1666,14 @@ def test_prefetch_multi_field_single_required(db, gql_client: GraphQLTestClient) milestone = MilestoneFactory.create() issue = IssueFactory.create(name="Foo", milestone=milestone) - issue_id = str(GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))) + issue_id = str( + GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id)) + ) milestone_id = str( - GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone.id)) + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone.id) + ) ) query = """\ @@ -1696,7 +1706,9 @@ def test_prefetch_multi_field_single_required_missing( milestone1 = MilestoneFactory.create() milestone_id = str( - GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone1.id)) + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone1.id) + ) ) query = """\ @@ -1731,7 +1743,9 @@ def test_prefetch_multi_field_single_required_multiple_returned( milestone = MilestoneFactory.create() milestone_id = str( - GlobalID(get_object_definition(MilestoneType, strict=True).name, str(milestone.id)) + GlobalID( + get_object_definition(MilestoneType, strict=True).name, str(milestone.id) + ) ) IssueFactory.create(name="Foo", milestone=milestone) IssueFactory.create(name="Bar", milestone=milestone) From d158018cd20b7733b0e9a2a4d518ddd1753eeaf0 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 15 Dec 2024 17:53:31 +0100 Subject: [PATCH 06/12] fix schema test --- tests/projects/snapshots/schema.gql | 4 +++- tests/projects/snapshots/schema_with_inheritance.gql | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 26039d01..fe965768 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -399,6 +399,8 @@ type MilestoneType implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! + firstIssue: IssueType + firstIssueRequired: IssueType! myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! @@ -969,4 +971,4 @@ input PermDefinition { The permission itself. If this is empty that means that we are checking for any permission for the given app. """ permission: String -} \ No newline at end of file +} diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 2b11dd7d..407c1371 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -189,6 +189,8 @@ type MilestoneType implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! + firstIssue: IssueType + firstIssueRequired: IssueType! myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! @@ -429,4 +431,4 @@ type UserType implements Node { isSuperuser: Boolean! isStaff: Boolean! fullName: String! -} \ No newline at end of file +} From 8e69f752bd4961e26bb9111011ba6e262e2f3987 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 15 Dec 2024 18:05:43 +0100 Subject: [PATCH 07/12] fix schema test --- tests/projects/snapshots/schema.gql | 2 +- tests/projects/snapshots/schema_with_inheritance.gql | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index fe965768..61da21ca 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -971,4 +971,4 @@ input PermDefinition { The permission itself. If this is empty that means that we are checking for any permission for the given app. """ permission: String -} +} \ No newline at end of file diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 407c1371..748c1ffd 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -219,6 +219,8 @@ type MilestoneTypeSubclass implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! + firstIssue: IssueType + firstIssueRequired: IssueType! myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! @@ -431,4 +433,4 @@ type UserType implements Node { isSuperuser: Boolean! isStaff: Boolean! fullName: String! -} +} \ No newline at end of file From d0ce6b6b4ca34aaefcd212b7c8752621965a052b Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 15 Dec 2024 18:23:56 +0100 Subject: [PATCH 08/12] fix error message formatting --- strawberry_django/fields/field.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 6faebbb7..847b1611 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -301,16 +301,11 @@ def qs_hook(qs: models.QuerySet): ) # the queryset is already prefetched, no issue with just using len() if qs_len == 0: raise qs.model.DoesNotExist( - "%s matching query does not exist." - % qs.model._meta.object_name + f"{qs.model._meta.object_name} matching query does not exist." ) if qs_len != 1: raise qs.model.MultipleObjectsReturned( - "get() returned more than one %s -- it returned %s!" - % ( - qs.model._meta.object_name, - qs_len, - ) + f"get() returned more than one {qs.model._meta.object_name} -- it returned {qs_len}!" ) return qs[0] return qs.get() From 3f81116f56f622535cd79a58df108196f5661757 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Tue, 17 Dec 2024 16:30:50 +0100 Subject: [PATCH 09/12] Improve prefetching for single/optional fields by not prefetching the whole table --- strawberry_django/fields/field.py | 18 ++++++++---------- strawberry_django/pagination.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 847b1611..bd70120e 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -31,6 +31,7 @@ ReverseOneToOneDescriptor, ) from django.db.models.manager import BaseManager +from django.db.models.query import MAX_GET_RESULTS from django.db.models.query_utils import DeferredAttribute from strawberry import UNSET, relay from strawberry.annotation import StrawberryAnnotation @@ -282,30 +283,27 @@ def qs_hook(qs: models.QuerySet): # type: ignore def qs_hook(qs: models.QuerySet): # type: ignore qs = self.get_queryset(qs, info, **kwargs) - # Don't use qs.first() if the queryset is optimized by prefetching. - # Calling first in that case would disregard the prefetched results, because first implicitly - # adds a limit to the query - if is_optimized_by_prefetching(qs): - return next(iter(qs), None) return qs.first() else: def qs_hook(qs: models.QuerySet): qs = self.get_queryset(qs, info, **kwargs) - # See comment above about qs.first(), the same applies for get() + # Don't use qs.get() if the queryset is optimized by prefetching. + # Calling first in that case would disregard the prefetched results, because first implicitly + # adds a limit to the query if is_optimized_by_prefetching(qs): # mimic behavior of get() - qs_len = len( - qs - ) # the queryset is already prefetched, no issue with just using len() + # the queryset is already prefetched, no issue with just using len() + qs_len = len(qs) if qs_len == 0: raise qs.model.DoesNotExist( f"{qs.model._meta.object_name} matching query does not exist." ) if qs_len != 1: raise qs.model.MultipleObjectsReturned( - f"get() returned more than one {qs.model._meta.object_name} -- it returned {qs_len}!" + f"get() returned more than one {qs.model._meta.object_name} -- it returned " + f"{qs_len if qs_len < MAX_GET_RESULTS else f'more than {qs_len - 1}'}!" ) return qs[0] return qs.get() diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 6d4f7a56..31be9dd2 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -6,6 +6,7 @@ from django.db import DEFAULT_DB_ALIAS from django.db.models import Count, QuerySet, Window from django.db.models.functions import RowNumber +from django.db.models.query import MAX_GET_RESULTS from strawberry.types import Info from strawberry.types.arguments import StrawberryArgument from strawberry.types.unset import UNSET, UnsetType @@ -343,6 +344,19 @@ def get_queryset( if self.is_paginated: return queryset + # Add implicit pagination if this field is not a list + # that way when first() / get() is called on the QuerySet it does not cause extra queries + if not pagination and not ( + self.is_list or self.is_paginated or self.is_connection + ): + if self.is_optional: + # first() applies order by pk if not ordered + if not queryset.ordered: + queryset = queryset.order_by("pk") + pagination = OffsetPaginationInput(offset=0, limit=1) + else: + pagination = OffsetPaginationInput(offset=0, limit=MAX_GET_RESULTS) + return self.apply_pagination( queryset, pagination, From 8783f5b45ce2e8784f64a1b223b71dfc2f387b96 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Tue, 17 Dec 2024 16:32:44 +0100 Subject: [PATCH 10/12] Fix typos in comment --- strawberry_django/fields/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index bd70120e..fc5715f7 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -290,7 +290,7 @@ def qs_hook(qs: models.QuerySet): # type: ignore def qs_hook(qs: models.QuerySet): qs = self.get_queryset(qs, info, **kwargs) # Don't use qs.get() if the queryset is optimized by prefetching. - # Calling first in that case would disregard the prefetched results, because first implicitly + # Calling get in that case would disregard the prefetched results, because get implicitly # adds a limit to the query if is_optimized_by_prefetching(qs): # mimic behavior of get() From 51a629122977abd0a06c393c68a6b4bbbff86bb0 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Tue, 17 Dec 2024 16:34:37 +0100 Subject: [PATCH 11/12] Add comment about why we add implicit pagination --- strawberry_django/pagination.py | 1 + 1 file changed, 1 insertion(+) diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 31be9dd2..631034a2 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -346,6 +346,7 @@ def get_queryset( # Add implicit pagination if this field is not a list # that way when first() / get() is called on the QuerySet it does not cause extra queries + # and we don't prefetch more than necessary if not pagination and not ( self.is_list or self.is_paginated or self.is_connection ): From 913c5f6e32b689476512516d5e0bd88f97e58f5c Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Wed, 18 Dec 2024 19:04:02 +0100 Subject: [PATCH 12/12] refactor: Use qs._result_cache instead of `len()` for safety reasons --- strawberry_django/fields/field.py | 9 +++++---- strawberry_django/pagination.py | 2 +- tests/projects/schema.py | 4 ++-- tests/projects/snapshots/schema.gql | 4 ++-- tests/projects/snapshots/schema_with_inheritance.gql | 8 ++++---- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index fc5715f7..f86c35c5 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -31,7 +31,7 @@ ReverseOneToOneDescriptor, ) from django.db.models.manager import BaseManager -from django.db.models.query import MAX_GET_RESULTS +from django.db.models.query import MAX_GET_RESULTS # type: ignore from django.db.models.query_utils import DeferredAttribute from strawberry import UNSET, relay from strawberry.annotation import StrawberryAnnotation @@ -292,10 +292,10 @@ def qs_hook(qs: models.QuerySet): # Don't use qs.get() if the queryset is optimized by prefetching. # Calling get in that case would disregard the prefetched results, because get implicitly # adds a limit to the query - if is_optimized_by_prefetching(qs): + if (result_cache := qs._result_cache) is not None: # type: ignore # mimic behavior of get() # the queryset is already prefetched, no issue with just using len() - qs_len = len(qs) + qs_len = len(result_cache) if qs_len == 0: raise qs.model.DoesNotExist( f"{qs.model._meta.object_name} matching query does not exist." @@ -305,7 +305,8 @@ def qs_hook(qs: models.QuerySet): f"get() returned more than one {qs.model._meta.object_name} -- it returned " f"{qs_len if qs_len < MAX_GET_RESULTS else f'more than {qs_len - 1}'}!" ) - return qs[0] + return result_cache[0] + return qs.get() return qs_hook diff --git a/strawberry_django/pagination.py b/strawberry_django/pagination.py index 631034a2..ae62d644 100644 --- a/strawberry_django/pagination.py +++ b/strawberry_django/pagination.py @@ -6,7 +6,7 @@ from django.db import DEFAULT_DB_ALIAS from django.db.models import Count, QuerySet, Window from django.db.models.functions import RowNumber -from django.db.models.query import MAX_GET_RESULTS +from django.db.models.query import MAX_GET_RESULTS # type: ignore from strawberry.types import Info from strawberry.types.arguments import StrawberryArgument from strawberry.types.unset import UNSET, UnsetType diff --git a/tests/projects/schema.py b/tests/projects/schema.py index b44de907..3bc46673 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -159,6 +159,8 @@ class MilestoneType(relay.Node, Named): order=IssueOrder, pagination=True, ) + first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues") + first_issue_required: "IssueType" = strawberry_django.field(field_name="issues") issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.offset_paginated( field_name="issues", order=IssueOrder, @@ -169,8 +171,6 @@ class MilestoneType(relay.Node, Named): filters=IssueFilter, ) ) - first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues") - first_issue_required: "IssueType" = strawberry_django.field(field_name="issues") @strawberry_django.field( prefetch_related=[ diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index 61da21ca..eb6f1ad0 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -383,6 +383,8 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + firstIssue: IssueType + firstIssueRequired: IssueType! issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -399,8 +401,6 @@ type MilestoneType implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! - firstIssue: IssueType - firstIssueRequired: IssueType! myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index 748c1ffd..b98209fb 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -173,6 +173,8 @@ type MilestoneType implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + firstIssue: IssueType + firstIssueRequired: IssueType! issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -189,8 +191,6 @@ type MilestoneType implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! - firstIssue: IssueType - firstIssueRequired: IssueType! myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String! @@ -203,6 +203,8 @@ type MilestoneTypeSubclass implements Node & Named { dueDate: Date project: ProjectType! issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]! + firstIssue: IssueType + firstIssueRequired: IssueType! issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated! issuesWithFilters( filters: IssueFilter @@ -219,8 +221,6 @@ type MilestoneTypeSubclass implements Node & Named { """Returns the items in the list that come after the specified cursor.""" last: Int = null ): IssueTypeConnection! - firstIssue: IssueType - firstIssueRequired: IssueType! myIssues: [IssueType!]! myBugsCount: Int! asyncField(value: String!): String!