Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(optimizer): Prevent issuing duplicated queries for certain uses of first() and get() #675

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ReverseOneToOneDescriptor,
)
from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS # type: ignore
from django.db.models.query_utils import DeferredAttribute
from strawberry import UNSET, relay
from strawberry.annotation import StrawberryAnnotation
Expand Down Expand Up @@ -288,6 +289,24 @@ def qs_hook(qs: models.QuerySet): # type: ignore

def qs_hook(qs: models.QuerySet):
diesieben07 marked this conversation as resolved.
Show resolved Hide resolved
qs = self.get_queryset(qs, info, **kwargs)
# Don't use qs.get() if the queryset is optimized by prefetching.
# Calling get in that case would disregard the prefetched results, because get implicitly
# adds a limit to the query
if (result_cache := qs._result_cache) is not None: # type: ignore
# mimic behavior of get()
# the queryset is already prefetched, no issue with just using len()
qs_len = len(result_cache)
if qs_len == 0:
raise qs.model.DoesNotExist(
f"{qs.model._meta.object_name} matching query does not exist."
)
if qs_len != 1:
raise qs.model.MultipleObjectsReturned(
f"get() returned more than one {qs.model._meta.object_name} -- it returned "
f"{qs_len if qs_len < MAX_GET_RESULTS else f'more than {qs_len - 1}'}!"
)
return result_cache[0]

return qs.get()

return qs_hook
Expand Down
15 changes: 15 additions & 0 deletions strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.db import DEFAULT_DB_ALIAS
from django.db.models import Count, QuerySet, Window
from django.db.models.functions import RowNumber
from django.db.models.query import MAX_GET_RESULTS # type: ignore
from strawberry.types import Info
from strawberry.types.arguments import StrawberryArgument
from strawberry.types.unset import UNSET, UnsetType
Expand Down Expand Up @@ -343,6 +344,20 @@ def get_queryset(
if self.is_paginated:
return queryset

# Add implicit pagination if this field is not a list
# that way when first() / get() is called on the QuerySet it does not cause extra queries
# and we don't prefetch more than necessary
if not pagination and not (
self.is_list or self.is_paginated or self.is_connection
):
if self.is_optional:
# first() applies order by pk if not ordered
if not queryset.ordered:
queryset = queryset.order_by("pk")
pagination = OffsetPaginationInput(offset=0, limit=1)
else:
pagination = OffsetPaginationInput(offset=0, limit=MAX_GET_RESULTS)

return self.apply_pagination(
queryset,
pagination,
Expand Down
2 changes: 2 additions & 0 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ class MilestoneType(relay.Node, Named):
order=IssueOrder,
pagination=True,
)
first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues")
first_issue_required: "IssueType" = strawberry_django.field(field_name="issues")
issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.offset_paginated(
field_name="issues",
order=IssueOrder,
Expand Down
2 changes: 2 additions & 0 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ type MilestoneType implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
firstIssue: IssueType
firstIssueRequired: IssueType!
issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated!
issuesWithFilters(
filters: IssueFilter
Expand Down
4 changes: 4 additions & 0 deletions tests/projects/snapshots/schema_with_inheritance.gql
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ type MilestoneType implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
firstIssue: IssueType
firstIssueRequired: IssueType!
issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated!
issuesWithFilters(
filters: IssueFilter
Expand Down Expand Up @@ -201,6 +203,8 @@ type MilestoneTypeSubclass implements Node & Named {
dueDate: Date
project: ProjectType!
issues(filters: IssueFilter, order: IssueOrder, pagination: OffsetPaginationInput): [IssueType!]!
firstIssue: IssueType
firstIssueRequired: IssueType!
issuesPaginated(pagination: OffsetPaginationInput, order: IssueOrder): IssueTypeOffsetPaginated!
issuesWithFilters(
filters: IssueFilter
Expand Down
174 changes: 171 additions & 3 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from django.db.models import Prefetch
from django.utils import timezone
from pytest_mock import MockerFixture
from strawberry.relay import to_base64
from strawberry.types import ExecutionResult
from strawberry.relay import GlobalID, to_base64
from strawberry.types import ExecutionResult, get_object_definition

import strawberry_django
from strawberry_django.optimizer import DjangoOptimizerExtension
from tests.projects.schema import StaffType
from tests.projects.schema import IssueType, MilestoneType, StaffType

from . import utils
from .projects.faker import (
Expand Down Expand Up @@ -1605,3 +1605,171 @@ def test_query_paginated_nested(db, gql_client: GraphQLTestClient):
},
]
}


@pytest.mark.django_db(transaction=True)
def test_prefetch_multi_field_single_optional(db, gql_client: GraphQLTestClient):
milestone1 = MilestoneFactory.create()
milestone2 = MilestoneFactory.create()

issue = IssueFactory.create(name="Foo", milestone=milestone1)
issue_id = str(
GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))
)

milestone_id_1 = str(
GlobalID(
get_object_definition(MilestoneType, strict=True).name, str(milestone1.id)
)
)
milestone_id_2 = str(
GlobalID(
get_object_definition(MilestoneType, strict=True).name, str(milestone2.id)
)
)

query = """\
query TestQuery($id1: GlobalID!, $id2: GlobalID!) {
a: milestone(id: $id1) {
firstIssue {
id
}
}
b: milestone(id: $id2) {
firstIssue {
id
}
}
}
"""

with assert_num_queries(4):
res = gql_client.query(
query, variables={"id1": milestone_id_1, "id2": milestone_id_2}
)

assert res.errors is None
assert res.data == {
"a": {
"firstIssue": {
"id": issue_id,
},
},
"b": {
"firstIssue": None,
},
}


@pytest.mark.django_db(transaction=True)
def test_prefetch_multi_field_single_required(db, gql_client: GraphQLTestClient):
milestone = MilestoneFactory.create()

issue = IssueFactory.create(name="Foo", milestone=milestone)
issue_id = str(
GlobalID(get_object_definition(IssueType, strict=True).name, str(issue.id))
)

milestone_id = str(
GlobalID(
get_object_definition(MilestoneType, strict=True).name, str(milestone.id)
)
)

query = """\
query TestQuery($id: GlobalID!) {
milestone(id: $id) {
firstIssueRequired {
id
}
}
}
"""

with assert_num_queries(2):
res = gql_client.query(query, variables={"id": milestone_id})

assert res.errors is None
assert res.data == {
"milestone": {
"firstIssueRequired": {
"id": issue_id,
},
},
}


@pytest.mark.django_db(transaction=True)
def test_prefetch_multi_field_single_required_missing(
db, gql_client: GraphQLTestClient
):
milestone1 = MilestoneFactory.create()

milestone_id = str(
GlobalID(
get_object_definition(MilestoneType, strict=True).name, str(milestone1.id)
)
)

query = """\
query TestQuery($id: GlobalID!) {
milestone(id: $id) {
firstIssueRequired {
id
}
}
}
"""

with assert_num_queries(2):
res = gql_client.query(
query, variables={"id": milestone_id}, assert_no_errors=False
)

assert res.errors is not None
assert res.errors == [
{
"locations": [{"column": 11, "line": 3}],
"message": "Issue matching query does not exist.",
"path": ["milestone", "firstIssueRequired"],
}
]


@pytest.mark.django_db(transaction=True)
def test_prefetch_multi_field_single_required_multiple_returned(
db, gql_client: GraphQLTestClient
):
milestone = MilestoneFactory.create()

milestone_id = str(
GlobalID(
get_object_definition(MilestoneType, strict=True).name, str(milestone.id)
)
)
IssueFactory.create(name="Foo", milestone=milestone)
IssueFactory.create(name="Bar", milestone=milestone)

query = """\
query TestQuery($id: GlobalID!) {
milestone(id: $id) {
firstIssueRequired {
id
}
}
}
"""

with assert_num_queries(2):
res = gql_client.query(
query, variables={"id": milestone_id}, assert_no_errors=False
)

assert res.errors is not None
assert res.errors == [
{
"locations": [{"column": 11, "line": 3}],
"message": "get() returned more than one Issue -- it returned 2!",
"path": ["milestone", "firstIssueRequired"],
}
]
Loading