diff --git a/platformics/api/relay/fields.py b/platformics/api/relay/fields.py index b33c8a3..1b06b6a 100644 --- a/platformics/api/relay/fields.py +++ b/platformics/api/relay/fields.py @@ -83,8 +83,17 @@ def resolver( info: Info, id: Annotated[strawberry.ID, argument(description="The ID of the object.")], ): - return id.resolve_type(info).resolve_node( - id.node_id, + type_resolvers = [] + for selected_type in info.selected_fields[0].selections: + field_type = selected_type.type_condition + type_def = info.schema.get_type_by_name(field_type) + origin = type_def.origin.resolve_type if isinstance(type_def.origin, LazyType) else type_def.origin + assert issubclass(origin, Node) + type_resolvers.append(origin) + # FIXME TODO this only works if we're getting a *single* subclassed `Node` type -- + # if we're getting multiple subclass types, we need to resolve them all somehow + return type_resolvers[0].resolve_node( + id, info=info, required=not is_optional, ) diff --git a/platformics/api/types/entities.py b/platformics/api/types/entities.py index 358c5ba..0880eac 100644 --- a/platformics/api/types/entities.py +++ b/platformics/api/types/entities.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Any, Iterable import strawberry from strawberry.types import Info @@ -29,3 +29,13 @@ async def resolve_nodes(cls, *, info: Info, node_ids: Iterable[str], required: b gql_type: str = cls.__strawberry_definition__.name # type: ignore sql_model = getattr(db_module, gql_type) return await dataloader.resolve_nodes(sql_model, node_ids) + + @classmethod + async def resolve_node(cls, node_id: str, info: Info, required: bool = False) -> Any: + dataloader = info.context["sqlalchemy_loader"] + db_module = info.context["db_module"] + gql_type: str = cls.__strawberry_definition__.name # type: ignore + sql_model = getattr(db_module, gql_type) + res = await dataloader.resolve_nodes(sql_model, [node_id]) + if res: + return res[0] diff --git a/platformics/test_infra/factories/base.py b/platformics/test_infra/factories/base.py index a3ce5ec..58291ef 100644 --- a/platformics/test_infra/factories/base.py +++ b/platformics/test_infra/factories/base.py @@ -100,7 +100,8 @@ def update_file_ids(cls) -> None: session.execute( sa.text( f"""UPDATE {entity_name} SET {entity_field_name}_id = file.id - FROM file WHERE {entity_name}.entity_id = file.entity_id""", + FROM file WHERE {entity_name}.entity_id = file.entity_id and file.entity_field_name = :field_name""", ), + {"field_name": entity_field_name}, ) session.commit() diff --git a/test_app/Makefile b/test_app/Makefile index a13b6ce..656dca9 100644 --- a/test_app/Makefile +++ b/test_app/Makefile @@ -44,10 +44,9 @@ init: $(docker_compose_run) $(APP_CONTAINER) black . # $(docker_compose_run) $(CONTAINER) ruff check --fix . $(docker_compose_run) $(APP_CONTAINER) sh -c 'strawberry export-schema main:schema > /app/api/schema.graphql' - sleep 5 # wait for the app to reload after having files updated. docker compose up -d $(MAKE) seed-moto - sleep 5 + sleep 5 # wait for the app to reload after having files updated. docker compose exec $(APP_CONTAINER) python3 -m sgqlc.introspection --exclude-deprecated --exclude-description http://localhost:9009/graphql api/schema.json .PHONY: seed-moto @@ -62,6 +61,7 @@ clean: ## Remove all codegen'd artifacts. rm -rf support rm -rf database rm -f .moto_recording + rm -rf test_infra $(docker_compose) --profile '*' down .PHONY: start diff --git a/test_app/schema/schema.yaml b/test_app/schema/schema.yaml index b111938..b22188f 100644 --- a/test_app/schema/schema.yaml +++ b/test_app/schema/schema.yaml @@ -201,8 +201,10 @@ classes: primer_file: range: GenomicRange inverse: GenomicRange.sequencing_reads + # Only mutable by system user (needed for upload flow) annotations: - mutable: false + mutable: true + system_writable_only: True contig: range: Contig inverse: Contig.sequencing_reads diff --git a/test_app/tests/test_nested_queries.py b/test_app/tests/test_nested_queries.py index 3a0d8f8..1462ee9 100644 --- a/test_app/tests/test_nested_queries.py +++ b/test_app/tests/test_nested_queries.py @@ -2,21 +2,12 @@ Tests for nested queries + authorization """ -import base64 import pytest from collections import defaultdict from platformics.database.connect import SyncDB from conftest import GQLTestClient, SessionStorage from test_infra.factories.sample import SampleFactory from test_infra.factories.sequencing_read import SequencingReadFactory -from platformics.api.types.entities import Entity - - -def get_id(entity: Entity) -> str: - entity_type = entity.__class__.__name__ - node_id = f"{entity_type}:{entity.id}".encode("ascii") - node_id_b64 = base64.b64encode(node_id).decode("utf-8") - return node_id_b64 @pytest.mark.asyncio @@ -155,9 +146,9 @@ async def test_relay_node_queries( sample1 = SampleFactory(owner_user_id=111, collection_id=888) sample2 = SampleFactory(owner_user_id=111, collection_id=888) sequencing_read = SequencingReadFactory(sample=sample1, owner_user_id=111, collection_id=888) - sample1_id = get_id(sample1) - sample2_id = get_id(sample2) - sequencing_read_id = get_id(sequencing_read) + sample1_id = sample1.id + sample2_id = sample2.id + sequencing_read_id = sequencing_read.id # Fetch one node query = f""" diff --git a/test_app/tests/test_where_clause.py b/test_app/tests/test_where_clause.py index 3963df6..e4613e2 100644 --- a/test_app/tests/test_where_clause.py +++ b/test_app/tests/test_where_clause.py @@ -271,6 +271,7 @@ async def test_soft_deleted_objects(sync_db: SyncDB, gql_client: GQLTestClient) }} }} """ + # Only service identities are allowed to soft delete entities output = await gql_client.query(soft_delete_mutation, member_projects=[project_id], service_identity="workflows") assert len(output["data"]["updateSequencingRead"]) == 3