Skip to content

Commit

Permalink
Improve prefetching for single/optional fields by not prefetching the…
Browse files Browse the repository at this point in the history
… whole table
  • Loading branch information
diesieben07 committed Dec 17, 2024
1 parent d0ce6b6 commit 3f81116
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
18 changes: 8 additions & 10 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ReverseOneToOneDescriptor,
)
from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS

Check failure on line 34 in strawberry_django/fields/field.py

View workflow job for this annotation

GitHub Actions / Typing

"MAX_GET_RESULTS" is unknown import symbol (reportAttributeAccessIssue)
from django.db.models.query_utils import DeferredAttribute
from strawberry import UNSET, relay
from strawberry.annotation import StrawberryAnnotation
Expand Down Expand Up @@ -282,30 +283,27 @@ def qs_hook(qs: models.QuerySet): # type: ignore

def qs_hook(qs: models.QuerySet): # type: ignore
qs = self.get_queryset(qs, info, **kwargs)
# Don't use qs.first() if the queryset is optimized by prefetching.
# Calling first in that case would disregard the prefetched results, because first implicitly
# adds a limit to the query
if is_optimized_by_prefetching(qs):
return next(iter(qs), None)
return qs.first()

else:

def qs_hook(qs: models.QuerySet):
qs = self.get_queryset(qs, info, **kwargs)
# See comment above about qs.first(), the same applies for get()
# Don't use qs.get() if the queryset is optimized by prefetching.
# Calling first in that case would disregard the prefetched results, because first implicitly
# adds a limit to the query
if is_optimized_by_prefetching(qs):
# mimic behavior of get()
qs_len = len(
qs
) # the queryset is already prefetched, no issue with just using len()
# the queryset is already prefetched, no issue with just using len()
qs_len = len(qs)
if qs_len == 0:
raise qs.model.DoesNotExist(
f"{qs.model._meta.object_name} matching query does not exist."
)
if qs_len != 1:
raise qs.model.MultipleObjectsReturned(
f"get() returned more than one {qs.model._meta.object_name} -- it returned {qs_len}!"
f"get() returned more than one {qs.model._meta.object_name} -- it returned "
f"{qs_len if qs_len < MAX_GET_RESULTS else f'more than {qs_len - 1}'}!"
)
return qs[0]
return qs.get()
Expand Down
14 changes: 14 additions & 0 deletions strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.db import DEFAULT_DB_ALIAS
from django.db.models import Count, QuerySet, Window
from django.db.models.functions import RowNumber
from django.db.models.query import MAX_GET_RESULTS

Check failure on line 9 in strawberry_django/pagination.py

View workflow job for this annotation

GitHub Actions / Typing

"MAX_GET_RESULTS" is unknown import symbol (reportAttributeAccessIssue)
from strawberry.types import Info
from strawberry.types.arguments import StrawberryArgument
from strawberry.types.unset import UNSET, UnsetType
Expand Down Expand Up @@ -343,6 +344,19 @@ def get_queryset(
if self.is_paginated:
return queryset

# Add implicit pagination if this field is not a list
# that way when first() / get() is called on the QuerySet it does not cause extra queries
if not pagination and not (
self.is_list or self.is_paginated or self.is_connection
):
if self.is_optional:
# first() applies order by pk if not ordered
if not queryset.ordered:
queryset = queryset.order_by("pk")
pagination = OffsetPaginationInput(offset=0, limit=1)
else:
pagination = OffsetPaginationInput(offset=0, limit=MAX_GET_RESULTS)

return self.apply_pagination(
queryset,
pagination,
Expand Down

0 comments on commit 3f81116

Please sign in to comment.