Skip to content

Commit

Permalink
fix: Prevent a possible security issue when resolving a relay model w…
Browse files Browse the repository at this point in the history
…ith multiple possibilities

Follow up to the security fix on Strawberry, which requires changes
to this integration as well.
  • Loading branch information
bellini666 committed Jan 8, 2025
1 parent a30755f commit d1a3e94
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
16 changes: 14 additions & 2 deletions strawberry_django/relay.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import functools
import inspect
import warnings
from collections.abc import Iterable, Sized
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
Expand Down Expand Up @@ -30,6 +32,7 @@
get_django_definition,
)

_T = TypeVar("_T")
_M = TypeVar("_M", bound=models.Model)


Expand Down Expand Up @@ -174,6 +177,13 @@ def resolve_connection_from_cache(
)


def get_node_caster(origin: Optional[type]) -> Callable[[_T], _T]:
if origin is None:
return lambda node: node

return functools.partial(strawberry.cast, origin)


@overload
def resolve_model_nodes(
source: Union[
Expand Down Expand Up @@ -341,7 +351,8 @@ def resolve_model_nodes(
return retval

def map_results(results: models.QuerySet[_M]) -> list[_M]:
results_map = {str(getattr(obj, id_attr)): obj for obj in results}
node_caster = get_node_caster(origin)
results_map = {str(getattr(obj, id_attr)): node_caster(obj) for obj in results}
retval: list[Optional[_M]] = []
for node_id in node_ids:
if required:
Expand Down Expand Up @@ -449,7 +460,8 @@ def resolve_model_node(
# If optimizer extension is enabled, optimize this queryset
qs = ext.optimize(qs, info=info)

return django_resolver(lambda: qs.get() if required else qs.first())()
node_caster = get_node_caster(origin)
return django_resolver(lambda: node_caster(qs.get() if required else qs.first()))()


def resolve_model_id_attr(source: type) -> str:
Expand Down
10 changes: 9 additions & 1 deletion strawberry_django/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from strawberry.types import get_object_definition
from strawberry.types.base import WithStrawberryObjectDefinition
from strawberry.types.cast import is_strawberry_cast_obj
from strawberry.types.field import StrawberryField
from strawberry.types.private import is_private
from strawberry.utils.deprecations import DeprecatedDescriptor
Expand Down Expand Up @@ -185,7 +186,14 @@ def _process_type(

# Make sure model is also considered a "virtual subclass" of cls
if "is_type_of" not in cls.__dict__:
cls.is_type_of = lambda obj, info: isinstance(obj, (cls, model))

def is_type_of(obj, info):
# XXX: Check if this is required even with the strawberry upstream changes
if is_strawberry_cast_obj(obj):
return obj.__as_strawberry_type__ is cls
return isinstance(obj, (cls, model))

cls.is_type_of = is_type_of

# Default querying methods for relay
if issubclass(cls, relay.Node):
Expand Down
42 changes: 42 additions & 0 deletions tests/relay/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import strawberry
from strawberry import relay

import strawberry_django
from tests.projects.models import Project


@pytest.mark.parametrize("type_name", ["ProjectType", "PublicProjectObject"])
@pytest.mark.django_db(transaction=True)
def test_correct_model_returned(type_name: str):
@strawberry_django.type(Project)
class ProjectType(relay.Node):
name: relay.NodeID[str]
due_date: strawberry.auto

@strawberry_django.type(Project)
class PublicProjectObject(relay.Node):
name: relay.NodeID[str]
due_date: strawberry.auto

@strawberry.type
class Query:
node: relay.Node = relay.node()

schema = strawberry.Schema(query=Query, types=[ProjectType, PublicProjectObject])
Project.objects.create(name="test")

node_id = relay.to_base64(type_name, "test")
result = schema.execute_sync(
"""
query NodeQuery($id: GlobalID!) {
node(id: $id) {
__typename
id
}
}
""",
{"id": node_id},
)
assert result.errors is None
assert result.data == {"node": {"__typename": type_name, "id": node_id}}

0 comments on commit d1a3e94

Please sign in to comment.