diff --git a/strawberry_django/relay.py b/strawberry_django/relay.py index 5d28de85..cdf5db6f 100644 --- a/strawberry_django/relay.py +++ b/strawberry_django/relay.py @@ -1,8 +1,10 @@ +import functools import inspect import warnings from collections.abc import Iterable, Sized from typing import ( Any, + Callable, Optional, TypeVar, Union, @@ -30,6 +32,7 @@ get_django_definition, ) +_T = TypeVar("_T") _M = TypeVar("_M", bound=models.Model) @@ -174,6 +177,13 @@ def resolve_connection_from_cache( ) +def get_node_caster(origin: Optional[type]) -> Callable[[_T], _T]: + if origin is None: + return lambda node: node + + return functools.partial(strawberry.cast, origin) + + @overload def resolve_model_nodes( source: Union[ @@ -341,7 +351,8 @@ def resolve_model_nodes( return retval def map_results(results: models.QuerySet[_M]) -> list[_M]: - results_map = {str(getattr(obj, id_attr)): obj for obj in results} + node_caster = get_node_caster(origin) + results_map = {str(getattr(obj, id_attr)): node_caster(obj) for obj in results} retval: list[Optional[_M]] = [] for node_id in node_ids: if required: @@ -449,7 +460,8 @@ def resolve_model_node( # If optimizer extension is enabled, optimize this queryset qs = ext.optimize(qs, info=info) - return django_resolver(lambda: qs.get() if required else qs.first())() + node_caster = get_node_caster(origin) + return django_resolver(lambda: node_caster(qs.get() if required else qs.first()))() def resolve_model_id_attr(source: type) -> str: diff --git a/strawberry_django/type.py b/strawberry_django/type.py index d3f7568a..1a58deb8 100644 --- a/strawberry_django/type.py +++ b/strawberry_django/type.py @@ -26,6 +26,7 @@ ) from strawberry.types import get_object_definition from strawberry.types.base import WithStrawberryObjectDefinition +from strawberry.types.cast import is_strawberry_cast_obj from strawberry.types.field import StrawberryField from strawberry.types.private import is_private from strawberry.utils.deprecations import DeprecatedDescriptor @@ -185,7 +186,14 @@ def _process_type( # Make sure model is also considered a "virtual subclass" of cls if "is_type_of" not in cls.__dict__: - cls.is_type_of = lambda obj, info: isinstance(obj, (cls, model)) + + def is_type_of(obj, info): + # XXX: Check if this is required even with the strawberry upstream changes + if is_strawberry_cast_obj(obj): + return obj.__as_strawberry_type__ is cls + return isinstance(obj, (cls, model)) + + cls.is_type_of = is_type_of # Default querying methods for relay if issubclass(cls, relay.Node): diff --git a/tests/relay/test_query.py b/tests/relay/test_query.py new file mode 100644 index 00000000..7c65cb6f --- /dev/null +++ b/tests/relay/test_query.py @@ -0,0 +1,42 @@ +import pytest +import strawberry +from strawberry import relay + +import strawberry_django +from tests.projects.models import Project + + +@pytest.mark.parametrize("type_name", ["ProjectType", "PublicProjectObject"]) +@pytest.mark.django_db(transaction=True) +def test_correct_model_returned(type_name: str): + @strawberry_django.type(Project) + class ProjectType(relay.Node): + name: relay.NodeID[str] + due_date: strawberry.auto + + @strawberry_django.type(Project) + class PublicProjectObject(relay.Node): + name: relay.NodeID[str] + due_date: strawberry.auto + + @strawberry.type + class Query: + node: relay.Node = relay.node() + + schema = strawberry.Schema(query=Query, types=[ProjectType, PublicProjectObject]) + Project.objects.create(name="test") + + node_id = relay.to_base64(type_name, "test") + result = schema.execute_sync( + """ + query NodeQuery($id: GlobalID!) { + node(id: $id) { + __typename + id + } + } + """, + {"id": node_id}, + ) + assert result.errors is None + assert result.data == {"node": {"__typename": type_name, "id": node_id}}