From f00300d26754f286629915df5ba0ac38131547d7 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Sun, 15 Dec 2024 15:56:03 +0000 Subject: [PATCH] Clean up PrefectDBInterface, use the models consistently --- docs/v3/api-ref/rest-api/server/schema.json | 22 +- src/prefect/cli/server.py | 6 +- src/prefect/client/schemas/objects.py | 9 + src/prefect/events/filters.py | 10 +- src/prefect/server/__init__.py | 2 + src/prefect/server/api/admin.py | 3 +- src/prefect/server/api/artifacts.py | 3 +- src/prefect/server/api/automations.py | 3 +- src/prefect/server/api/block_capabilities.py | 5 +- src/prefect/server/api/block_documents.py | 3 +- src/prefect/server/api/block_schemas.py | 3 +- src/prefect/server/api/block_types.py | 3 +- src/prefect/server/api/concurrency_limits.py | 3 +- .../server/api/concurrency_limits_v2.py | 3 +- src/prefect/server/api/csrf_token.py | 3 +- src/prefect/server/api/deployments.py | 3 +- src/prefect/server/api/events.py | 3 +- .../api/flow_run_notification_policies.py | 3 +- src/prefect/server/api/flow_run_states.py | 3 +- src/prefect/server/api/flow_runs.py | 9 +- src/prefect/server/api/flows.py | 3 +- src/prefect/server/api/logs.py | 3 +- src/prefect/server/api/middleware.py | 2 +- src/prefect/server/api/root.py | 3 +- src/prefect/server/api/run_history.py | 18 +- src/prefect/server/api/saved_searches.py | 3 +- src/prefect/server/api/server.py | 4 +- src/prefect/server/api/task_run_states.py | 3 +- src/prefect/server/api/task_runs.py | 3 +- src/prefect/server/api/ui/flow_runs.py | 18 +- src/prefect/server/api/ui/flows.py | 12 +- src/prefect/server/api/ui/schemas.py | 3 +- src/prefect/server/api/ui/task_runs.py | 6 +- src/prefect/server/api/variables.py | 3 +- src/prefect/server/api/work_queues.py | 7 +- src/prefect/server/api/workers.py | 3 +- src/prefect/server/database/__init__.py | 13 + .../server/database/alembic_commands.py | 26 +- src/prefect/server/database/configurations.py | 126 +- src/prefect/server/database/dependencies.py | 38 +- src/prefect/server/database/interface.py | 239 +-- src/prefect/server/database/orm_models.py | 22 +- .../server/database/query_components.py | 1586 ++++++++--------- src/prefect/server/events/counting.py | 3 +- src/prefect/server/events/filters.py | 201 +-- .../server/events/models/automations.py | 18 +- .../models/composite_trigger_child_firing.py | 3 +- src/prefect/server/events/ordering.py | 30 +- .../server/events/services/event_persister.py | 2 +- src/prefect/server/events/storage/database.py | 15 +- src/prefect/server/events/triggers.py | 9 +- src/prefect/server/models/agents.py | 30 +- src/prefect/server/models/artifacts.py | 126 +- src/prefect/server/models/block_documents.py | 93 +- src/prefect/server/models/block_schemas.py | 122 +- src/prefect/server/models/block_types.py | 43 +- .../server/models/concurrency_limits.py | 56 +- .../server/models/concurrency_limits_v2.py | 106 +- src/prefect/server/models/configuration.py | 14 +- src/prefect/server/models/csrf_token.py | 23 +- src/prefect/server/models/deployments.py | 248 +-- src/prefect/server/models/flow_run_input.py | 34 +- .../models/flow_run_notification_policies.py | 38 +- src/prefect/server/models/flow_run_states.py | 20 +- src/prefect/server/models/flow_runs.py | 102 +- src/prefect/server/models/flows.py | 78 +- src/prefect/server/models/logs.py | 12 +- src/prefect/server/models/saved_searches.py | 37 +- src/prefect/server/models/task_run_states.py | 22 +- src/prefect/server/models/task_runs.py | 91 +- src/prefect/server/models/variables.py | 59 +- src/prefect/server/models/work_queues.py | 65 +- src/prefect/server/models/workers.py | 121 +- .../server/orchestration/core_policy.py | 3 +- src/prefect/server/orchestration/rules.py | 3 +- src/prefect/server/schemas/filters.py | 985 +++++----- src/prefect/server/schemas/graph.py | 24 +- src/prefect/server/schemas/sorting.py | 172 +- src/prefect/server/schemas/states.py | 141 +- .../server/services/cancellation_cleanup.py | 35 +- .../server/services/flow_run_notifications.py | 8 +- src/prefect/server/services/foreman.py | 54 +- src/prefect/server/services/late_runs.py | 3 +- src/prefect/server/services/loop_service.py | 3 +- .../server/services/pause_expirations.py | 3 +- src/prefect/server/services/scheduler.py | 4 +- .../server/services/task_run_recorder.py | 11 +- src/prefect/server/services/telemetry.py | 3 +- src/prefect/server/utilities/database.py | 17 +- .../test_pausing_resuming_work_pool.py | 2 +- tests/events/server/conftest.py | 2 +- .../events/server/models/test_automations.py | 2 +- tests/events/server/storage/test_database.py | 2 +- .../server/storage/test_event_persister.py | 3 +- tests/events/server/test_automations_api.py | 2 +- .../triggers/test_composite_triggers.py | 2 +- tests/fixtures/database.py | 6 +- tests/server/api/test_csrf_token.py | 2 +- tests/server/api/test_middleware.py | 2 +- tests/server/database/test_dependencies.py | 10 +- tests/server/database/test_queries.py | 2 +- .../models/test_concurrency_limits_v2.py | 2 +- tests/server/models/test_task_run_states.py | 10 +- .../api/test_concurrency_limits_v2.py | 2 +- .../api/test_deployment_schedules.py | 2 +- .../api/test_flow_run_graph_v2.py | 11 +- .../server/orchestration/api/test_workers.py | 12 +- tests/server/orchestration/test_rules.py | 2 +- tests/server/services/test_foreman.py | 3 +- tests/server/utilities/test_database.py | 15 +- ui-v2/src/api/prefect.ts | 9 +- 111 files changed, 2918 insertions(+), 2723 deletions(-) diff --git a/docs/v3/api-ref/rest-api/server/schema.json b/docs/v3/api-ref/rest-api/server/schema.json index 62986c65c9610..d740853819c7c 100644 --- a/docs/v3/api-ref/rest-api/server/schema.json +++ b/docs/v3/api-ref/rest-api/server/schema.json @@ -939,7 +939,7 @@ "type": "string", "format": "date-time", "description": "Only include runs that start or end after this time.", - "default": "0001-01-01T00:00:00", + "default": "0001-01-01T00:00:00+00:00", "title": "Since" }, "description": "Only include runs that start or end after this time." @@ -20044,8 +20044,15 @@ "Graph": { "properties": { "start_time": { - "type": "string", - "format": "date-time", + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], "title": "Start Time" }, "end_time": { @@ -20136,7 +20143,14 @@ "title": "Key" }, "type": { - "type": "string", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], "title": "Type" }, "is_latest": { diff --git a/src/prefect/cli/server.py b/src/prefect/cli/server.py index 65627f3d16794..fb66042be9192 100644 --- a/src/prefect/cli/server.py +++ b/src/prefect/cli/server.py @@ -340,7 +340,7 @@ async def stop(): @database_app.command() async def reset(yes: bool = typer.Option(False, "--yes", "-y")): """Drop and recreate all Prefect database tables""" - from prefect.server.database.dependencies import provide_database_interface + from prefect.server.database import provide_database_interface db = provide_database_interface() engine = await db.engine() @@ -378,8 +378,8 @@ async def upgrade( ), ): """Upgrade the Prefect database""" + from prefect.server.database import provide_database_interface from prefect.server.database.alembic_commands import alembic_upgrade - from prefect.server.database.dependencies import provide_database_interface db = provide_database_interface() engine = await db.engine() @@ -418,8 +418,8 @@ async def downgrade( ), ): """Downgrade the Prefect database""" + from prefect.server.database import provide_database_interface from prefect.server.database.alembic_commands import alembic_downgrade - from prefect.server.database.dependencies import provide_database_interface db = provide_database_interface() diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index 18bd75ab305f8..58a271ed73a25 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -237,6 +237,15 @@ def result( ) -> Union[R, Exception]: ... + @overload + def result( + self: "State[R]", + raise_on_failure: bool = ..., + fetch: bool = ..., + retry_result_failure: bool = ..., + ) -> Union[R, Exception]: + ... + @deprecated.deprecated_parameter( "fetch", when=lambda fetch: fetch is not True, diff --git a/src/prefect/events/filters.py b/src/prefect/events/filters.py index f969e9ccb6514..657407d43d206 100644 --- a/src/prefect/events/filters.py +++ b/src/prefect/events/filters.py @@ -2,7 +2,7 @@ from uuid import UUID import pendulum -from pydantic import Field, PrivateAttr +from pydantic import Field from prefect._internal.schemas.bases import PrefectBaseModel from prefect.types import DateTime @@ -41,18 +41,12 @@ class AutomationFilter(PrefectBaseModel): class EventDataFilter(PrefectBaseModel, extra="forbid"): # type: ignore[call-arg] """A base class for filtering event data.""" - _top_level_filter: Optional["EventFilter"] = PrivateAttr(None) - def get_filters(self) -> List["EventDataFilter"]: filters: List["EventDataFilter"] = [ filter - for filter in [ - getattr(self, name) for name, field in self.model_fields.items() - ] + for filter in [getattr(self, name) for name in self.model_fields] if isinstance(filter, EventDataFilter) ] - for filter in filters: - filter._top_level_filter = self._top_level_filter return filters def includes(self, event: Event) -> bool: diff --git a/src/prefect/server/__init__.py b/src/prefect/server/__init__.py index c9bc11f901d54..73ded801bcfed 100644 --- a/src/prefect/server/__init__.py +++ b/src/prefect/server/__init__.py @@ -1 +1,3 @@ from . import models, orchestration, schemas, services + +__all__ = ["models", "orchestration", "schemas", "services"] diff --git a/src/prefect/server/api/admin.py b/src/prefect/server/api/admin.py index 6b35f2baf8b2c..b55eb0bafd495 100644 --- a/src/prefect/server/api/admin.py +++ b/src/prefect/server/api/admin.py @@ -6,8 +6,7 @@ import prefect import prefect.settings -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/admin", tags=["Admin"]) diff --git a/src/prefect/server/api/artifacts.py b/src/prefect/server/api/artifacts.py index e7b4f4017fe04..d3300b1a348f1 100644 --- a/src/prefect/server/api/artifacts.py +++ b/src/prefect/server/api/artifacts.py @@ -10,8 +10,7 @@ import prefect.server.api.dependencies as dependencies from prefect.server import models -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.schemas import actions, core, filters, sorting from prefect.server.utilities.server import PrefectRouter diff --git a/src/prefect/server/api/automations.py b/src/prefect/server/api/automations.py index 8c74c47b1f22d..832e66515f17a 100644 --- a/src/prefect/server/api/automations.py +++ b/src/prefect/server/api/automations.py @@ -10,8 +10,7 @@ from prefect.server.api.validation import ( validate_job_variables_for_run_deployment_action, ) -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.events import actions from prefect.server.events.filters import AutomationFilter, AutomationFilterCreated from prefect.server.events.models import automations as automations_models diff --git a/src/prefect/server/api/block_capabilities.py b/src/prefect/server/api/block_capabilities.py index a9b26cf17a754..3da721adaf30e 100644 --- a/src/prefect/server/api/block_capabilities.py +++ b/src/prefect/server/api/block_capabilities.py @@ -7,10 +7,7 @@ from fastapi import Depends from prefect.server import models -from prefect.server.database.dependencies import ( - PrefectDBInterface, - provide_database_interface, -) +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/block_capabilities", tags=["Block capabilities"]) diff --git a/src/prefect/server/api/block_documents.py b/src/prefect/server/api/block_documents.py index ca0016be4dafa..dd8deb82b6dc5 100644 --- a/src/prefect/server/api/block_documents.py +++ b/src/prefect/server/api/block_documents.py @@ -9,8 +9,7 @@ from prefect.server import models, schemas from prefect.server.api import dependencies -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/block_documents", tags=["Block documents"]) diff --git a/src/prefect/server/api/block_schemas.py b/src/prefect/server/api/block_schemas.py index a71564f4f1862..d079f24187ce8 100644 --- a/src/prefect/server/api/block_schemas.py +++ b/src/prefect/server/api/block_schemas.py @@ -17,8 +17,7 @@ from prefect.server import models, schemas from prefect.server.api import dependencies -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.models.block_schemas import MissingBlockTypeException from prefect.server.utilities.server import PrefectRouter diff --git a/src/prefect/server/api/block_types.py b/src/prefect/server/api/block_types.py index 40ea261d2cfee..50351781dc646 100644 --- a/src/prefect/server/api/block_types.py +++ b/src/prefect/server/api/block_types.py @@ -7,8 +7,7 @@ from prefect.blocks.core import _should_update_block_type from prefect.server import models, schemas from prefect.server.api import dependencies -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/block_types", tags=["Block types"]) diff --git a/src/prefect/server/api/concurrency_limits.py b/src/prefect/server/api/concurrency_limits.py index 9414ff018a15f..990cac57feeba 100644 --- a/src/prefect/server/api/concurrency_limits.py +++ b/src/prefect/server/api/concurrency_limits.py @@ -12,8 +12,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas from prefect.server.api.concurrency_limits_v2 import MinimalConcurrencyLimitResponse -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.models import concurrency_limits from prefect.server.utilities.server import PrefectRouter from prefect.settings import PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS diff --git a/src/prefect/server/api/concurrency_limits_v2.py b/src/prefect/server/api/concurrency_limits_v2.py index 67ae928b04ae8..76d96313c060a 100644 --- a/src/prefect/server/api/concurrency_limits_v2.py +++ b/src/prefect/server/api/concurrency_limits_v2.py @@ -6,8 +6,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas from prefect.server.api.dependencies import LimitBody -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.schemas import actions from prefect.server.utilities.schemas import PrefectBaseModel from prefect.server.utilities.server import PrefectRouter diff --git a/src/prefect/server/api/csrf_token.py b/src/prefect/server/api/csrf_token.py index 08b6540e0a4c8..94eaa399da7f3 100644 --- a/src/prefect/server/api/csrf_token.py +++ b/src/prefect/server/api/csrf_token.py @@ -3,8 +3,7 @@ from prefect.logging import get_logger from prefect.server import models, schemas -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter from prefect.settings import PREFECT_SERVER_CSRF_PROTECTION_ENABLED diff --git a/src/prefect/server/api/deployments.py b/src/prefect/server/api/deployments.py index 347b8f33a60a1..b0b89b721f79f 100644 --- a/src/prefect/server/api/deployments.py +++ b/src/prefect/server/api/deployments.py @@ -19,8 +19,7 @@ validate_job_variables_for_deployment_flow_run, ) from prefect.server.api.workers import WorkerLookups -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.exceptions import MissingVariableError, ObjectNotFoundError from prefect.server.models.deployments import mark_deployments_ready from prefect.server.models.workers import DEFAULT_AGENT_WORK_POOL_NAME diff --git a/src/prefect/server/api/events.py b/src/prefect/server/api/events.py index 70815fa50a192..87f4f914f484a 100644 --- a/src/prefect/server/api/events.py +++ b/src/prefect/server/api/events.py @@ -11,8 +11,7 @@ from prefect.logging import get_logger from prefect.server.api.dependencies import is_ephemeral_request -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.events import messaging, stream from prefect.server.events.counting import ( Countable, diff --git a/src/prefect/server/api/flow_run_notification_policies.py b/src/prefect/server/api/flow_run_notification_policies.py index 7ce50351164b3..c0a6da90274ba 100644 --- a/src/prefect/server/api/flow_run_notification_policies.py +++ b/src/prefect/server/api/flow_run_notification_policies.py @@ -10,8 +10,7 @@ import prefect.server.api.dependencies as dependencies import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter( diff --git a/src/prefect/server/api/flow_run_states.py b/src/prefect/server/api/flow_run_states.py index d4eb1785d2773..72784f0175a11 100644 --- a/src/prefect/server/api/flow_run_states.py +++ b/src/prefect/server/api/flow_run_states.py @@ -9,8 +9,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/flow_run_states", tags=["Flow Run States"]) diff --git a/src/prefect/server/api/flow_runs.py b/src/prefect/server/api/flow_runs.py index bace90d0f8e4c..b3e36cf2a0b27 100644 --- a/src/prefect/server/api/flow_runs.py +++ b/src/prefect/server/api/flow_runs.py @@ -20,6 +20,7 @@ Response, status, ) +from fastapi.encoders import jsonable_encoder from fastapi.responses import ORJSONResponse, PlainTextResponse, StreamingResponse from sqlalchemy.exc import IntegrityError @@ -29,8 +30,7 @@ from prefect.logging import get_logger from prefect.server.api.run_history import run_history from prefect.server.api.validation import validate_job_variables_for_deployment_flow_run -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.exceptions import FlowRunGraphTooLarge from prefect.server.models.flow_runs import ( DependencyResult, @@ -215,6 +215,7 @@ async def average_flow_run_lateness( base_query = db.FlowRun.estimated_start_time_delta query = await models.flow_runs._apply_flow_run_filters( + db, sa.select(sa.func.avg(base_query)), flow_filter=flows, flow_run_filter=flow_runs, @@ -321,8 +322,8 @@ async def read_flow_run_graph_v1( @router.get("/{id:uuid}/graph-v2") async def read_flow_run_graph_v2( flow_run_id: UUID = Path(..., description="The flow run id", alias="id"), - since: datetime.datetime = Query( - datetime.datetime.min, + since: DateTime = Query( + default=jsonable_encoder(DateTime.min), description="Only include runs that start or end after this time.", ), db: PrefectDBInterface = Depends(provide_database_interface), diff --git a/src/prefect/server/api/flows.py b/src/prefect/server/api/flows.py index aedb2eaf88a5b..4f7b5696c88d8 100644 --- a/src/prefect/server/api/flows.py +++ b/src/prefect/server/api/flows.py @@ -12,8 +12,7 @@ import prefect.server.api.dependencies as dependencies import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.schemas.responses import FlowPaginationResponse from prefect.server.utilities.server import PrefectRouter diff --git a/src/prefect/server/api/logs.py b/src/prefect/server/api/logs.py index 71f97820761ab..18b723c1f91f8 100644 --- a/src/prefect/server/api/logs.py +++ b/src/prefect/server/api/logs.py @@ -9,8 +9,7 @@ import prefect.server.api.dependencies as dependencies import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/logs", tags=["Logs"]) diff --git a/src/prefect/server/api/middleware.py b/src/prefect/server/api/middleware.py index 6a34d8c81e3d2..75f1786154daa 100644 --- a/src/prefect/server/api/middleware.py +++ b/src/prefect/server/api/middleware.py @@ -7,7 +7,7 @@ from prefect import settings from prefect.server import models -from prefect.server.database.dependencies import provide_database_interface +from prefect.server.database import provide_database_interface NextMiddlewareFunction = Callable[[Request], Awaitable[Response]] diff --git a/src/prefect/server/api/root.py b/src/prefect/server/api/root.py index 1fd399d399b72..8a52153a4b8b6 100644 --- a/src/prefect/server/api/root.py +++ b/src/prefect/server/api/root.py @@ -5,8 +5,7 @@ from fastapi import Depends, status from fastapi.responses import JSONResponse -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="", tags=["Root"]) diff --git a/src/prefect/server/api/run_history.py b/src/prefect/server/api/run_history.py index 932a2076e4ea6..b707b06ed4fbb 100644 --- a/src/prefect/server/api/run_history.py +++ b/src/prefect/server/api/run_history.py @@ -13,8 +13,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas from prefect.logging import get_logger -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.types import DateTime logger = get_logger("server.api") @@ -57,7 +56,7 @@ async def run_history( ) # create a CTE for timestamp intervals - intervals = db.make_timestamp_intervals( + intervals = db.queries.make_timestamp_intervals( history_start, history_end, history_interval, @@ -66,6 +65,7 @@ async def run_history( # apply filters to the flow runs (and related states) runs = ( await run_filter_function( + db, sa.select( run_model.id, run_model.expected_start_time, @@ -92,7 +92,7 @@ async def run_history( # build a JSON object, ignoring the case where the count of runs is 0 sa.case( (sa.func.count(runs.c.id) == 0, None), - else_=db.build_json_object( + else_=db.queries.build_json_object( "state_type", runs.c.state_type, "state_name", @@ -140,10 +140,10 @@ async def run_history( counts.c.interval_start, counts.c.interval_end, sa.func.coalesce( - db.json_arr_agg(db.cast_to_json(counts.c.state_agg)).filter( - counts.c.state_agg.is_not(None) - ), - sa.text("'[]'"), + db.queries.json_arr_agg( + db.queries.cast_to_json(counts.c.state_agg) + ).filter(counts.c.state_agg.is_not(None)), + sa.literal("[]", literal_execute=True), ).label("states"), ) .group_by(counts.c.interval_start, counts.c.interval_end) @@ -157,7 +157,7 @@ async def run_history( records = result.mappings() # load and parse the record if the database returns JSON as strings - if db.uses_json_strings: + if db.queries.uses_json_strings: records = [dict(r) for r in records] for r in records: r["states"] = json.loads(r["states"]) diff --git a/src/prefect/server/api/saved_searches.py b/src/prefect/server/api/saved_searches.py index 1c5928fa33287..fce1314845902 100644 --- a/src/prefect/server/api/saved_searches.py +++ b/src/prefect/server/api/saved_searches.py @@ -11,8 +11,7 @@ import prefect.server.api.dependencies as dependencies import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/saved_searches", tags=["SavedSearches"]) diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py index f6fa3b5879553..1cbae4af2af47 100644 --- a/src/prefect/server/api/server.py +++ b/src/prefect/server/api/server.py @@ -528,7 +528,7 @@ def create_app( async def run_migrations(): """Ensure the database is created and up to date with the current migrations""" if prefect.settings.PREFECT_API_DATABASE_MIGRATE_ON_START: - from prefect.server.database.dependencies import provide_database_interface + from prefect.server.database import provide_database_interface db = provide_database_interface() await db.create_db() @@ -539,7 +539,7 @@ async def add_block_types(): if not prefect.settings.PREFECT_API_BLOCKS_REGISTER_ON_START: return - from prefect.server.database.dependencies import provide_database_interface + from prefect.server.database import provide_database_interface from prefect.server.models.block_registration import run_block_auto_registration db = provide_database_interface() diff --git a/src/prefect/server/api/task_run_states.py b/src/prefect/server/api/task_run_states.py index 041a133d2abb2..ef68e8b4d7a00 100644 --- a/src/prefect/server/api/task_run_states.py +++ b/src/prefect/server/api/task_run_states.py @@ -9,8 +9,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import PrefectRouter router = PrefectRouter(prefix="/task_run_states", tags=["Task Run States"]) diff --git a/src/prefect/server/api/task_runs.py b/src/prefect/server/api/task_runs.py index 8912b3f5fdf02..69d5e14adbd60 100644 --- a/src/prefect/server/api/task_runs.py +++ b/src/prefect/server/api/task_runs.py @@ -24,8 +24,7 @@ import prefect.server.schemas as schemas from prefect.logging import get_logger from prefect.server.api.run_history import run_history -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.orchestration import dependencies as orchestration_dependencies from prefect.server.orchestration.core_policy import CoreTaskPolicy from prefect.server.orchestration.policies import BaseOrchestrationPolicy diff --git a/src/prefect/server/api/ui/flow_runs.py b/src/prefect/server/api/ui/flow_runs.py index c76db8719dca4..a67999f929192 100644 --- a/src/prefect/server/api/ui/flow_runs.py +++ b/src/prefect/server/api/ui/flow_runs.py @@ -7,12 +7,10 @@ from pydantic import Field import prefect.server.schemas as schemas -from prefect._internal.schemas.bases import PrefectBaseModel from prefect.logging import get_logger from prefect.server import models -from prefect.server.database import orm_models -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface +from prefect.server.utilities.schemas.bases import PrefectBaseModel from prefect.server.utilities.server import PrefectRouter from prefect.types import DateTime @@ -101,22 +99,22 @@ async def count_task_runs_by_flow_run( async with db.session_context() as session: query = ( sa.select( - orm_models.TaskRun.flow_run_id, - sa.func.count(orm_models.TaskRun.id).label("task_run_count"), + db.TaskRun.flow_run_id, + sa.func.count(db.TaskRun.id).label("task_run_count"), ) .where( sa.and_( - orm_models.TaskRun.flow_run_id.in_(flow_run_ids), - sa.not_(orm_models.TaskRun.subflow_run.has()), + db.TaskRun.flow_run_id.in_(flow_run_ids), + sa.not_(db.TaskRun.subflow_run.has()), ) ) - .group_by(orm_models.TaskRun.flow_run_id) + .group_by(db.TaskRun.flow_run_id) ) results = await session.execute(query) task_run_counts_by_flow_run = { - flow_run_id: task_run_count for flow_run_id, task_run_count in results.all() + flow_run_id: task_run_count for flow_run_id, task_run_count in results.t } return { diff --git a/src/prefect/server/api/ui/flows.py b/src/prefect/server/api/ui/flows.py index 8128abd45ccf9..8c0b3375d4802 100644 --- a/src/prefect/server/api/ui/flows.py +++ b/src/prefect/server/api/ui/flows.py @@ -8,9 +8,7 @@ from pydantic import Field, field_validator from prefect.logging import get_logger -from prefect.server.database import orm_models -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.schemas.states import StateType from prefect.server.utilities.database import UUID as UUIDTypeDecorator from prefect.server.utilities.schemas import PrefectBaseModel @@ -51,11 +49,11 @@ async def count_deployments_by_flow( async with db.session_context() as session: query = ( sa.select( - orm_models.Deployment.flow_id, - sa.func.count(orm_models.Deployment.id).label("deployment_count"), + db.Deployment.flow_id, + sa.func.count(db.Deployment.id).label("deployment_count"), ) - .where(orm_models.Deployment.flow_id.in_(flow_ids)) - .group_by(orm_models.Deployment.flow_id) + .where(db.Deployment.flow_id.in_(flow_ids)) + .group_by(db.Deployment.flow_id) ) results = await session.execute(query) diff --git a/src/prefect/server/api/ui/schemas.py b/src/prefect/server/api/ui/schemas.py index 2d362f8162da7..768b9c7231d2c 100644 --- a/src/prefect/server/api/ui/schemas.py +++ b/src/prefect/server/api/ui/schemas.py @@ -3,8 +3,7 @@ from fastapi import Body, Depends, HTTPException, status from prefect.logging import get_logger -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.utilities.server import APIRouter from prefect.utilities.schema_tools.hydration import HydrationContext, hydrate from prefect.utilities.schema_tools.validation import ( diff --git a/src/prefect/server/api/ui/task_runs.py b/src/prefect/server/api/ui/task_runs.py index f39565499ba84..938d2a27d57e9 100644 --- a/src/prefect/server/api/ui/task_runs.py +++ b/src/prefect/server/api/ui/task_runs.py @@ -7,11 +7,10 @@ from pydantic import Field, model_serializer import prefect.server.schemas as schemas -from prefect._internal.schemas.bases import PrefectBaseModel from prefect.logging import get_logger from prefect.server import models -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface +from prefect.server.utilities.schemas.bases import PrefectBaseModel from prefect.server.utilities.server import PrefectRouter from prefect.types import DateTime @@ -97,6 +96,7 @@ async def read_dashboard_task_run_counts( raw_counts = ( ( await models.task_runs._apply_task_run_filters( + db, sa.select( bucket_expression, sa.func.min(db.TaskRun.end_time).label("oldest"), diff --git a/src/prefect/server/api/variables.py b/src/prefect/server/api/variables.py index 7b755d6b8b520..b05cb8c648e2b 100644 --- a/src/prefect/server/api/variables.py +++ b/src/prefect/server/api/variables.py @@ -11,8 +11,7 @@ from prefect.server import models from prefect.server.api.dependencies import LimitBody -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.schemas import actions, core, filters, sorting from prefect.server.utilities.server import PrefectRouter diff --git a/src/prefect/server/api/work_queues.py b/src/prefect/server/api/work_queues.py index 101b717c6081c..ad9e17df16453 100644 --- a/src/prefect/server/api/work_queues.py +++ b/src/prefect/server/api/work_queues.py @@ -19,8 +19,11 @@ import prefect.server.api.dependencies as dependencies import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database.dependencies import db_injector, provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import ( + PrefectDBInterface, + db_injector, + provide_database_interface, +) from prefect.server.models.deployments import mark_deployments_ready from prefect.server.models.work_queues import ( emit_work_queue_status_event, diff --git a/src/prefect/server/api/workers.py b/src/prefect/server/api/workers.py index cc40dea8a9d83..0f0d8546dd192 100644 --- a/src/prefect/server/api/workers.py +++ b/src/prefect/server/api/workers.py @@ -21,8 +21,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas from prefect.server.api.validation import validate_job_variable_defaults_for_work_pool -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.server.models.deployments import mark_deployments_ready from prefect.server.models.work_queues import ( emit_work_queue_status_event, diff --git a/src/prefect/server/database/__init__.py b/src/prefect/server/database/__init__.py index e69de29bb2d1d..00100c34478d0 100644 --- a/src/prefect/server/database/__init__.py +++ b/src/prefect/server/database/__init__.py @@ -0,0 +1,13 @@ +from prefect.server.database.dependencies import ( + db_injector, + inject_db, + provide_database_interface, +) +from prefect.server.database.interface import PrefectDBInterface + +__all__ = [ + "PrefectDBInterface", + "db_injector", + "inject_db", + "provide_database_interface", +] diff --git a/src/prefect/server/database/alembic_commands.py b/src/prefect/server/database/alembic_commands.py index bfeec523ee904..fca6c75c4448d 100644 --- a/src/prefect/server/database/alembic_commands.py +++ b/src/prefect/server/database/alembic_commands.py @@ -2,16 +2,24 @@ from functools import wraps from pathlib import Path from threading import Lock -from typing import Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from sqlalchemy.exc import SAWarning +from typing_extensions import ParamSpec, TypeVar import prefect.server.database +if TYPE_CHECKING: + from alembic.config import Config + + +P = ParamSpec("P") +R = TypeVar("R", infer_variance=True) + ALEMBIC_LOCK = Lock() -def with_alembic_lock(fn): +def with_alembic_lock(fn: Callable[P, R]) -> Callable[P, R]: """ Decorator that prevents alembic commands from running concurrently. This is necessary because alembic uses a global configuration object @@ -23,14 +31,14 @@ def with_alembic_lock(fn): """ @wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: with ALEMBIC_LOCK: return fn(*args, **kwargs) return wrapper -def alembic_config(): +def alembic_config() -> "Config": from alembic.config import Config alembic_dir = Path(prefect.server.database.__file__).parent @@ -43,7 +51,7 @@ def alembic_config(): @with_alembic_lock -def alembic_upgrade(revision: str = "head", dry_run: bool = False): +def alembic_upgrade(revision: str = "head", dry_run: bool = False) -> None: """ Run alembic upgrades on Prefect REST API database @@ -65,7 +73,7 @@ def alembic_upgrade(revision: str = "head", dry_run: bool = False): @with_alembic_lock -def alembic_downgrade(revision: str = "-1", dry_run: bool = False): +def alembic_downgrade(revision: str = "-1", dry_run: bool = False) -> None: """ Run alembic downgrades on Prefect REST API database @@ -81,8 +89,8 @@ def alembic_downgrade(revision: str = "-1", dry_run: bool = False): @with_alembic_lock def alembic_revision( - message: Optional[str] = None, autogenerate: bool = False, **kwargs -): + message: Optional[str] = None, autogenerate: bool = False, **kwargs: Any +) -> None: """ Create a new revision file for the database. @@ -99,7 +107,7 @@ def alembic_revision( @with_alembic_lock -def alembic_stamp(revision): +def alembic_stamp(revision: Union[str, list[str], tuple[str, ...]]) -> None: """ Stamp the revision table with the given revision; don't run any migrations diff --git a/src/prefect/server/database/configurations.py b/src/prefect/server/database/configurations.py index 296bf8ba4e5c5..31c93b5eddf8d 100644 --- a/src/prefect/server/database/configurations.py +++ b/src/prefect/server/database/configurations.py @@ -2,16 +2,25 @@ import traceback from abc import ABC, abstractmethod from asyncio import AbstractEventLoop, get_running_loop -from contextlib import asynccontextmanager +from collections.abc import AsyncGenerator, Hashable +from contextlib import AbstractAsyncContextManager, asynccontextmanager from contextvars import ContextVar from functools import partial -from typing import Dict, Hashable, Optional, Tuple +from typing import Any, Optional import sqlalchemy as sa -from sqlalchemy import AdaptedConnection -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy import AdaptedConnection, event +from sqlalchemy.dialects.sqlite import aiosqlite +from sqlalchemy.engine.interfaces import DBAPIConnection +from sqlalchemy.ext.asyncio import ( + AsyncConnection, + AsyncEngine, + AsyncSession, + AsyncSessionTransaction, + create_async_engine, +) from sqlalchemy.pool import ConnectionPoolEntry -from typing_extensions import Literal +from typing_extensions import TypeAlias from prefect.settings import ( PREFECT_API_DATABASE_CONNECTION_TIMEOUT, @@ -27,16 +36,17 @@ "SQLITE_BEGIN_MODE", default=None ) -ENGINES: Dict[Tuple[AbstractEventLoop, str, bool, float], AsyncEngine] = {} +_EngineCacheKey: TypeAlias = tuple[AbstractEventLoop, str, bool, Optional[float]] +ENGINES: dict[_EngineCacheKey, AsyncEngine] = {} class ConnectionTracker: """A test utility which tracks the connections given out by a connection pool, to make it easy to see which connections are currently checked out and open.""" - all_connections: Dict[AdaptedConnection, str] - open_connections: Dict[AdaptedConnection, str] - left_field_closes: Dict[AdaptedConnection, str] + all_connections: dict[AdaptedConnection, list[str]] + open_connections: dict[AdaptedConnection, list[str]] + left_field_closes: dict[AdaptedConnection, list[str]] connects: int closes: int active: bool @@ -49,16 +59,16 @@ def __init__(self) -> None: self.connects = 0 self.closes = 0 - def track_pool(self, pool: sa.pool.Pool): - sa.event.listen(pool, "connect", self.on_connect) - sa.event.listen(pool, "close", self.on_close) - sa.event.listen(pool, "close_detached", self.on_close_detached) + def track_pool(self, pool: sa.pool.Pool) -> None: + event.listen(pool, "connect", self.on_connect) + event.listen(pool, "close", self.on_close) + event.listen(pool, "close_detached", self.on_close_detached) def on_connect( self, adapted_connection: AdaptedConnection, connection_record: ConnectionPoolEntry, - ): + ) -> None: self.all_connections[adapted_connection] = traceback.format_stack() self.open_connections[adapted_connection] = traceback.format_stack() self.connects += 1 @@ -67,7 +77,7 @@ def on_close( self, adapted_connection: AdaptedConnection, connection_record: ConnectionPoolEntry, - ): + ) -> None: try: del self.open_connections[adapted_connection] except KeyError: @@ -77,14 +87,14 @@ def on_close( def on_close_detached( self, adapted_connection: AdaptedConnection, - ): + ) -> None: try: del self.open_connections[adapted_connection] except KeyError: self.left_field_closes[adapted_connection] = traceback.format_stack() self.closes += 1 - def clear(self): + def clear(self) -> None: self.all_connections.clear() self.open_connections.clear() self.left_field_closes.clear() @@ -111,21 +121,21 @@ def __init__( connection_timeout: Optional[float] = None, sqlalchemy_pool_size: Optional[int] = None, sqlalchemy_max_overflow: Optional[int] = None, - ): + ) -> None: self.connection_url = connection_url - self.echo = echo or PREFECT_API_DATABASE_ECHO.value() - self.timeout = timeout or PREFECT_API_DATABASE_TIMEOUT.value() - self.connection_timeout = ( + self.echo: bool = echo or PREFECT_API_DATABASE_ECHO.value() + self.timeout: Optional[float] = timeout or PREFECT_API_DATABASE_TIMEOUT.value() + self.connection_timeout: Optional[float] = ( connection_timeout or PREFECT_API_DATABASE_CONNECTION_TIMEOUT.value() ) - self.sqlalchemy_pool_size = ( + self.sqlalchemy_pool_size: Optional[int] = ( sqlalchemy_pool_size or PREFECT_SQLALCHEMY_POOL_SIZE.value() ) - self.sqlalchemy_max_overflow = ( + self.sqlalchemy_max_overflow: Optional[int] = ( sqlalchemy_max_overflow or PREFECT_SQLALCHEMY_MAX_OVERFLOW.value() ) - def _unique_key(self) -> Tuple[Hashable, ...]: + def unique_key(self) -> tuple[Hashable, ...]: """ Returns a key used to determine whether to instantiate a new DB interface. """ @@ -142,11 +152,15 @@ async def session(self, engine: AsyncEngine) -> AsyncSession: """ @abstractmethod - async def create_db(self, connection, base_metadata): + async def create_db( + self, connection: AsyncConnection, base_metadata: sa.MetaData + ) -> None: """Create the database""" @abstractmethod - async def drop_db(self, connection, base_metadata): + async def drop_db( + self, connection: AsyncConnection, base_metadata: sa.MetaData + ) -> None: """Drop the database""" @abstractmethod @@ -154,9 +168,9 @@ def is_inmemory(self) -> bool: """Returns true if database is run in memory""" @abstractmethod - async def begin_transaction( + def begin_transaction( self, session: AsyncSession, with_for_update: bool = False - ): + ) -> AbstractAsyncContextManager[AsyncSessionTransaction]: """Enter a transaction for a session""" pass @@ -187,8 +201,8 @@ async def engine(self) -> AsyncEngine: ) if cache_key not in ENGINES: # apply database timeout - kwargs = dict() - connect_args = dict() + kwargs: dict[str, Any] = dict() + connect_args: dict[str, Any] = dict() if self.timeout is not None: connect_args["command_timeout"] = self.timeout @@ -226,7 +240,7 @@ async def engine(self) -> AsyncEngine: await self.schedule_engine_disposal(cache_key) return ENGINES[cache_key] - async def schedule_engine_disposal(self, cache_key): + async def schedule_engine_disposal(self, cache_key: _EngineCacheKey) -> None: """ Dispose of an engine once the event loop is closing. @@ -243,7 +257,7 @@ async def schedule_engine_disposal(self, cache_key): encountered should be encouraged to use a standalone server. """ - async def dispose_engine(cache_key): + async def dispose_engine(cache_key: _EngineCacheKey) -> None: engine = ENGINES.pop(cache_key, None) if engine: await engine.dispose() @@ -262,23 +276,27 @@ async def session(self, engine: AsyncEngine) -> AsyncSession: @asynccontextmanager async def begin_transaction( self, session: AsyncSession, with_for_update: bool = False - ): + ) -> AsyncGenerator[AsyncSessionTransaction, None]: # `with_for_update` is for SQLite only. For Postgres, lock the row on read # for update instead. async with session.begin() as transaction: yield transaction - async def create_db(self, connection, base_metadata): + async def create_db( + self, connection: AsyncConnection, base_metadata: sa.MetaData + ) -> None: """Create the database""" await connection.run_sync(base_metadata.create_all) - async def drop_db(self, connection, base_metadata): + async def drop_db( + self, connection: AsyncConnection, base_metadata: sa.MetaData + ) -> None: """Drop the database""" await connection.run_sync(base_metadata.drop_all) - def is_inmemory(self) -> Literal[False]: + def is_inmemory(self) -> bool: """Returns true if database is run in memory""" return False @@ -309,16 +327,11 @@ async def engine(self) -> AsyncEngine: f"{sqlite3.sqlite_version}" ) - kwargs = {} + kwargs: dict[str, Any] = {} loop = get_running_loop() - cache_key = ( - loop, - self.connection_url, - self.echo, - self.timeout, - ) + cache_key = (loop, self.connection_url, self.echo, self.timeout) if cache_key not in ENGINES: # apply database timeout if self.timeout is not None: @@ -351,7 +364,7 @@ async def engine(self) -> AsyncEngine: await self.schedule_engine_disposal(cache_key) return ENGINES[cache_key] - async def schedule_engine_disposal(self, cache_key): + async def schedule_engine_disposal(self, cache_key: _EngineCacheKey) -> None: """ Dispose of an engine once the event loop is closing. @@ -368,19 +381,20 @@ async def schedule_engine_disposal(self, cache_key): encountered should be encouraged to use a standalone server. """ - async def dispose_engine(cache_key): + async def dispose_engine(cache_key: _EngineCacheKey) -> None: engine = ENGINES.pop(cache_key, None) if engine: await engine.dispose() await add_event_loop_shutdown_callback(partial(dispose_engine, cache_key)) - def setup_sqlite(self, conn, record): + def setup_sqlite(self, conn: DBAPIConnection, record: ConnectionPoolEntry) -> None: """Issue PRAGMA statements to SQLITE on connect. PRAGMAs only last for the duration of the connection. See https://www.sqlite.org/pragma.html for more info. """ # workaround sqlite transaction behavior - self.begin_sqlite_conn(conn, record) + if isinstance(conn, aiosqlite.AsyncAdapt_aiosqlite_connection): + self.begin_sqlite_conn(conn) cursor = conn.cursor() @@ -425,14 +439,16 @@ def setup_sqlite(self, conn, record): cursor.close() - def begin_sqlite_conn(self, conn, record): + def begin_sqlite_conn( + self, conn: aiosqlite.AsyncAdapt_aiosqlite_connection + ) -> None: # disable pysqlite's emitting of the BEGIN statement entirely. # also stops it from emitting COMMIT before any DDL. # requires `begin_sqlite_stmt` # see https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl conn.isolation_level = None - def begin_sqlite_stmt(self, conn): + def begin_sqlite_stmt(self, conn: sa.Connection) -> None: # emit our own BEGIN # requires `begin_sqlite_conn` # see https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl @@ -447,7 +463,7 @@ def begin_sqlite_stmt(self, conn): @asynccontextmanager async def begin_transaction( self, session: AsyncSession, with_for_update: bool = False - ): + ) -> AsyncGenerator[AsyncSessionTransaction, None]: token = SQLITE_BEGIN_MODE.set("IMMEDIATE" if with_for_update else "DEFERRED") try: @@ -465,17 +481,21 @@ async def session(self, engine: AsyncEngine) -> AsyncSession: """ return AsyncSession(engine, expire_on_commit=False) - async def create_db(self, connection, base_metadata): + async def create_db( + self, connection: AsyncConnection, base_metadata: sa.MetaData + ) -> None: """Create the database""" await connection.run_sync(base_metadata.create_all) - async def drop_db(self, connection, base_metadata): + async def drop_db( + self, connection: AsyncConnection, base_metadata: sa.MetaData + ) -> None: """Drop the database""" await connection.run_sync(base_metadata.drop_all) - def is_inmemory(self): + def is_inmemory(self) -> bool: """Returns true if database is run in memory""" return ":memory:" in self.connection_url or "mode=memory" in self.connection_url diff --git a/src/prefect/server/database/dependencies.py b/src/prefect/server/database/dependencies.py index 2bf4ca9766f34..6dc3082df930d 100644 --- a/src/prefect/server/database/dependencies.py +++ b/src/prefect/server/database/dependencies.py @@ -32,36 +32,34 @@ AsyncPostgresConfiguration, BaseDatabaseConfiguration, ) -from prefect.server.database.interface import PrefectDBInterface from prefect.server.database.orm_models import ( AioSqliteORMConfiguration, AsyncPostgresORMConfiguration, BaseORMConfiguration, ) -from prefect.server.database.query_components import ( - AioSqliteQueryComponents, - AsyncPostgresQueryComponents, - BaseQueryComponents, -) from prefect.server.utilities.database import get_dialect from prefect.server.utilities.schemas import PrefectDescriptorBase from prefect.settings import PREFECT_API_DATABASE_CONNECTION_URL +if TYPE_CHECKING: + from prefect.server.database.interface import PrefectDBInterface + from prefect.server.database.query_components import BaseQueryComponents + P = ParamSpec("P") R = TypeVar("R", infer_variance=True) T = TypeVar("T", infer_variance=True) _Function = Callable[P, R] _Method = Callable[Concatenate[T, P], R] -_DBFunction: TypeAlias = Callable[Concatenate[PrefectDBInterface, P], R] -_DBMethod: TypeAlias = Callable[Concatenate[T, PrefectDBInterface, P], R] +_DBFunction: TypeAlias = Callable[Concatenate["PrefectDBInterface", P], R] +_DBMethod: TypeAlias = Callable[Concatenate[T, "PrefectDBInterface", P], R] class _ModelDependencies(TypedDict): database_config: Optional[BaseDatabaseConfiguration] - query_components: Optional[BaseQueryComponents] + query_components: Optional["BaseQueryComponents"] orm: Optional[BaseORMConfiguration] - interface_class: Optional[type[PrefectDBInterface]] + interface_class: Optional[type["PrefectDBInterface"]] MODELS_DEPENDENCIES: _ModelDependencies = { @@ -72,13 +70,19 @@ class _ModelDependencies(TypedDict): } -def provide_database_interface() -> PrefectDBInterface: +def provide_database_interface() -> "PrefectDBInterface": """ Get the current Prefect REST API database interface. If components of the interface are not set, defaults will be inferred based on the dialect of the connection URL. """ + from prefect.server.database.interface import PrefectDBInterface + from prefect.server.database.query_components import ( + AioSqliteQueryComponents, + AsyncPostgresQueryComponents, + ) + connection_url = PREFECT_API_DATABASE_CONNECTION_URL.value() database_config = MODELS_DEPENDENCIES.get("database_config") @@ -384,7 +388,7 @@ def temporary_database_config( @contextmanager def temporary_query_components( - tmp_queries: Optional[BaseQueryComponents], + tmp_queries: Optional["BaseQueryComponents"], ) -> Generator[None, object, None]: """ Temporarily override the Prefect REST API database query components. @@ -426,7 +430,7 @@ def temporary_orm_config( @contextmanager def temporary_interface_class( - tmp_interface_class: Optional[type[PrefectDBInterface]], + tmp_interface_class: Optional[type["PrefectDBInterface"]], ) -> Generator[None, object, None]: """ Temporarily override the Prefect REST API interface class When the context is closed, @@ -447,9 +451,9 @@ def temporary_interface_class( @contextmanager def temporary_database_interface( tmp_database_config: Optional[BaseDatabaseConfiguration] = None, - tmp_queries: Optional[BaseQueryComponents] = None, + tmp_queries: Optional["BaseQueryComponents"] = None, tmp_orm_config: Optional[BaseORMConfiguration] = None, - tmp_interface_class: Optional[type[PrefectDBInterface]] = None, + tmp_interface_class: Optional[type["PrefectDBInterface"]] = None, ) -> Generator[None, object, None]: """ Temporarily override the Prefect REST API database interface. @@ -485,7 +489,7 @@ def set_database_config(database_config: Optional[BaseDatabaseConfiguration]) -> MODELS_DEPENDENCIES["database_config"] = database_config -def set_query_components(query_components: Optional[BaseQueryComponents]) -> None: +def set_query_components(query_components: Optional["BaseQueryComponents"]) -> None: """Set Prefect REST API query components.""" MODELS_DEPENDENCIES["query_components"] = query_components @@ -495,6 +499,6 @@ def set_orm_config(orm_config: Optional[BaseORMConfiguration]) -> None: MODELS_DEPENDENCIES["orm"] = orm_config -def set_interface_class(interface_class: Optional[type[PrefectDBInterface]]) -> None: +def set_interface_class(interface_class: Optional[type["PrefectDBInterface"]]) -> None: """Set Prefect REST API interface class.""" MODELS_DEPENDENCIES["interface_class"] = interface_class diff --git a/src/prefect/server/database/interface.py b/src/prefect/server/database/interface.py index d4a427abdb504..239064b7fe144 100644 --- a/src/prefect/server/database/interface.py +++ b/src/prefect/server/database/interface.py @@ -1,34 +1,55 @@ -import datetime +from collections.abc import Hashable from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from typing_extensions import TypeAlias from prefect.server.database import orm_models from prefect.server.database.alembic_commands import alembic_downgrade, alembic_upgrade from prefect.server.database.configurations import BaseDatabaseConfiguration -from prefect.server.database.query_components import BaseQueryComponents from prefect.server.utilities.database import get_dialect from prefect.utilities.asyncutils import run_sync_in_worker_thread +if TYPE_CHECKING: + from prefect.server.database.query_components import BaseQueryComponents + +_UniqueKey: TypeAlias = tuple[Hashable, ...] + class DBSingleton(type): """Ensures that only one database interface is created per unique key""" - _instances = dict() + _instances: dict[ + tuple[str, _UniqueKey, _UniqueKey, _UniqueKey], "DBSingleton" + ] = dict() - def __call__(cls, *args, **kwargs): - unique_key = ( + def __call__( + cls, + *args: Any, + database_config: BaseDatabaseConfiguration, + query_components: "BaseQueryComponents", + orm: orm_models.BaseORMConfiguration, + **kwargs: Any, + ) -> "DBSingleton": + instance_key = ( cls.__name__, - kwargs["database_config"]._unique_key(), - kwargs["query_components"]._unique_key(), - kwargs["orm"]._unique_key(), + database_config.unique_key(), + query_components.unique_key(), + orm.unique_key(), ) - if unique_key not in cls._instances: - cls._instances[unique_key] = super(DBSingleton, cls).__call__( - *args, **kwargs + try: + instance = cls._instances[instance_key] + except KeyError: + instance = cls._instances[instance_key] = super().__call__( + *args, + database_config=database_config, + query_components=query_components, + orm=orm, + **kwargs, ) - return cls._instances[unique_key] + return instance class PrefectDBInterface(metaclass=DBSingleton): @@ -44,30 +65,30 @@ class PrefectDBInterface(metaclass=DBSingleton): def __init__( self, database_config: BaseDatabaseConfiguration, - query_components: BaseQueryComponents, + query_components: "BaseQueryComponents", orm: orm_models.BaseORMConfiguration, ): self.database_config = database_config self.queries = query_components self.orm = orm - async def create_db(self): + async def create_db(self) -> None: """Create the database""" await self.run_migrations_upgrade() - async def drop_db(self): + async def drop_db(self) -> None: """Drop the database""" await self.run_migrations_downgrade(revision="base") - async def run_migrations_upgrade(self): + async def run_migrations_upgrade(self) -> None: """Run all upgrade migrations""" await run_sync_in_worker_thread(alembic_upgrade) - async def run_migrations_downgrade(self, revision: str = "-1"): + async def run_migrations_downgrade(self, revision: str = "-1") -> None: """Run all downgrade migrations""" await run_sync_in_worker_thread(alembic_downgrade, revision=revision) - async def is_db_connectable(self): + async def is_db_connectable(self) -> bool: """ Returns boolean indicating if the database is connectable. This method is used to determine if the server is ready to accept requests. @@ -87,7 +108,7 @@ async def engine(self) -> AsyncEngine: return engine - async def session(self): + async def session(self) -> AsyncSession: """ Provides a SQLAlchemy session. """ @@ -117,292 +138,192 @@ async def session_context( yield session @property - def dialect(self) -> sa.engine.Dialect: + def dialect(self) -> type[sa.engine.Dialect]: return get_dialect(self.database_config.connection_url) @property - def Base(self): + def Base(self) -> type[orm_models.Base]: """Base class for orm models""" return orm_models.Base @property - def Flow(self): + def Flow(self) -> type[orm_models.Flow]: """A flow orm model""" return orm_models.Flow @property - def FlowRun(self): + def FlowRun(self) -> type[orm_models.FlowRun]: """A flow run orm model""" return orm_models.FlowRun @property - def FlowRunState(self): + def FlowRunState(self) -> type[orm_models.FlowRunState]: """A flow run state orm model""" return orm_models.FlowRunState @property - def TaskRun(self): + def TaskRun(self) -> type[orm_models.TaskRun]: """A task run orm model""" return orm_models.TaskRun @property - def TaskRunState(self): + def TaskRunState(self) -> type[orm_models.TaskRunState]: """A task run state orm model""" return orm_models.TaskRunState @property - def Artifact(self): + def Artifact(self) -> type[orm_models.Artifact]: """An artifact orm model""" return orm_models.Artifact @property - def ArtifactCollection(self): + def ArtifactCollection(self) -> type[orm_models.ArtifactCollection]: """An artifact collection orm model""" return orm_models.ArtifactCollection @property - def TaskRunStateCache(self): + def TaskRunStateCache(self) -> type[orm_models.TaskRunStateCache]: """A task run state cache orm model""" return orm_models.TaskRunStateCache @property - def Deployment(self): + def Deployment(self) -> type[orm_models.Deployment]: """A deployment orm model""" return orm_models.Deployment @property - def DeploymentSchedule(self): + def DeploymentSchedule(self) -> type[orm_models.DeploymentSchedule]: """A deployment schedule orm model""" return orm_models.DeploymentSchedule @property - def SavedSearch(self): + def SavedSearch(self) -> type[orm_models.SavedSearch]: """A saved search orm model""" return orm_models.SavedSearch @property - def WorkPool(self): + def WorkPool(self) -> type[orm_models.WorkPool]: """A work pool orm model""" return orm_models.WorkPool @property - def Worker(self): + def Worker(self) -> type[orm_models.Worker]: """A worker process orm model""" return orm_models.Worker @property - def Log(self): + def Log(self) -> type[orm_models.Log]: """A log orm model""" return orm_models.Log @property - def ConcurrencyLimit(self): + def ConcurrencyLimit(self) -> type[orm_models.ConcurrencyLimit]: """A concurrency model""" return orm_models.ConcurrencyLimit @property - def ConcurrencyLimitV2(self): + def ConcurrencyLimitV2(self) -> type[orm_models.ConcurrencyLimitV2]: """A v2 concurrency model""" return orm_models.ConcurrencyLimitV2 @property - def CsrfToken(self): + def CsrfToken(self) -> type[orm_models.CsrfToken]: """A csrf token model""" return orm_models.CsrfToken @property - def WorkQueue(self): + def WorkQueue(self) -> type[orm_models.WorkQueue]: """A work queue model""" return orm_models.WorkQueue @property - def Agent(self): + def Agent(self) -> type[orm_models.Agent]: """An agent model""" return orm_models.Agent @property - def BlockType(self): + def BlockType(self) -> type[orm_models.BlockType]: """A block type model""" return orm_models.BlockType @property - def BlockSchema(self): + def BlockSchema(self) -> type[orm_models.BlockSchema]: """A block schema model""" return orm_models.BlockSchema @property - def BlockSchemaReference(self): + def BlockSchemaReference(self) -> type[orm_models.BlockSchemaReference]: """A block schema reference model""" return orm_models.BlockSchemaReference @property - def BlockDocument(self): + def BlockDocument(self) -> type[orm_models.BlockDocument]: """A block document model""" return orm_models.BlockDocument @property - def BlockDocumentReference(self): + def BlockDocumentReference(self) -> type[orm_models.BlockDocumentReference]: """A block document reference model""" return orm_models.BlockDocumentReference @property - def FlowRunNotificationPolicy(self): + def FlowRunNotificationPolicy(self) -> type[orm_models.FlowRunNotificationPolicy]: """A flow run notification policy model""" return orm_models.FlowRunNotificationPolicy @property - def FlowRunNotificationQueue(self): + def FlowRunNotificationQueue(self) -> type[orm_models.FlowRunNotificationQueue]: """A flow run notification queue model""" return orm_models.FlowRunNotificationQueue @property - def Configuration(self): + def Configuration(self) -> type[orm_models.Configuration]: """An configuration model""" return orm_models.Configuration @property - def Variable(self): + def Variable(self) -> type[orm_models.Variable]: """A variable model""" return orm_models.Variable @property - def FlowRunInput(self): + def FlowRunInput(self) -> type[orm_models.FlowRunInput]: """A flow run input model""" return orm_models.FlowRunInput @property - def Automation(self): + def Automation(self) -> type[orm_models.Automation]: """An automation model""" return orm_models.Automation @property - def AutomationBucket(self): + def AutomationBucket(self) -> type[orm_models.AutomationBucket]: """An automation bucket model""" return orm_models.AutomationBucket @property - def AutomationRelatedResource(self): + def AutomationRelatedResource(self) -> type[orm_models.AutomationRelatedResource]: """An automation related resource model""" return orm_models.AutomationRelatedResource @property - def CompositeTriggerChildFiring(self): + def CompositeTriggerChildFiring( + self, + ) -> type[orm_models.CompositeTriggerChildFiring]: """A model capturing a composite trigger's child firing""" return orm_models.CompositeTriggerChildFiring @property - def AutomationEventFollower(self): + def AutomationEventFollower(self) -> type[orm_models.AutomationEventFollower]: """A model capturing one event following another event""" return orm_models.AutomationEventFollower @property - def Event(self): + def Event(self) -> type[orm_models.Event]: """An event model""" return orm_models.Event @property - def EventResource(self): + def EventResource(self) -> type[orm_models.EventResource]: """An event resource model""" return orm_models.EventResource - - @property - def deployment_unique_upsert_columns(self): - """Unique columns for upserting a Deployment""" - return self.orm.deployment_unique_upsert_columns - - @property - def concurrency_limit_unique_upsert_columns(self): - """Unique columns for upserting a ConcurrencyLimit""" - return self.orm.concurrency_limit_unique_upsert_columns - - @property - def flow_run_unique_upsert_columns(self): - """Unique columns for upserting a FlowRun""" - return self.orm.flow_run_unique_upsert_columns - - @property - def artifact_collection_unique_upsert_columns(self): - """Unique columns for upserting an ArtifactCollection""" - return self.orm.artifact_collection_unique_upsert_columns - - @property - def block_type_unique_upsert_columns(self): - """Unique columns for upserting a BlockType""" - return self.orm.block_type_unique_upsert_columns - - @property - def block_schema_unique_upsert_columns(self): - """Unique columns for upserting a BlockSchema""" - return self.orm.block_schema_unique_upsert_columns - - @property - def flow_unique_upsert_columns(self): - """Unique columns for upserting a Flow""" - return self.orm.flow_unique_upsert_columns - - @property - def saved_search_unique_upsert_columns(self): - """Unique columns for upserting a SavedSearch""" - return self.orm.saved_search_unique_upsert_columns - - @property - def task_run_unique_upsert_columns(self): - """Unique columns for upserting a TaskRun""" - return self.orm.task_run_unique_upsert_columns - - @property - def block_document_unique_upsert_columns(self): - """Unique columns for upserting a BlockDocument""" - return self.orm.block_document_unique_upsert_columns - - def insert(self, model): - """INSERTs a model into the database""" - return self.queries.insert(model) - - def make_timestamp_intervals( - self, - start_time: datetime.datetime, - end_time: datetime.datetime, - interval: datetime.timedelta, - ): - return self.queries.make_timestamp_intervals(start_time, end_time, interval) - - def set_state_id_on_inserted_flow_runs_statement( - self, inserted_flow_run_ids, insert_flow_run_states - ): - """Given a list of flow run ids and associated states, set the state_id - to the appropriate state for all flow runs""" - return self.queries.set_state_id_on_inserted_flow_runs_statement( - orm_models.FlowRun, - orm_models.FlowRunState, - inserted_flow_run_ids, - insert_flow_run_states, - ) - - @property - def uses_json_strings(self): - return self.queries.uses_json_strings - - def cast_to_json(self, json_obj): - return self.queries.cast_to_json(json_obj) - - def build_json_object(self, *args): - return self.queries.build_json_object(*args) - - def json_arr_agg(self, json_array): - return self.queries.json_arr_agg(json_array) - - async def get_flow_run_notifications_from_queue( - self, session: sa.orm.Session, limit: int - ): - return await self.queries.get_flow_run_notifications_from_queue( - session=session, limit=limit - ) - - async def read_configuration_value(self, session: AsyncSession, key: str): - """Read a configuration value""" - return await self.queries.read_configuration_value(session=session, key=key) - - def clear_configuration_value_cache_for_key(self, key: str): - """Removes a configuration key from the cache.""" - return self.queries.clear_configuration_value_cache_for_key(key=key) diff --git a/src/prefect/server/database/orm_models.py b/src/prefect/server/database/orm_models.py index 7cf8b3fcd2eaa..ef504438ca083 100644 --- a/src/prefect/server/database/orm_models.py +++ b/src/prefect/server/database/orm_models.py @@ -3,14 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Hashable, Iterable from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - Optional, - Union, -) +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union import pendulum import sqlalchemy as sa @@ -55,9 +48,12 @@ from prefect.server.utilities.encryption import decrypt_fernet, encrypt_fernet from prefect.utilities.names import generate_slug +if TYPE_CHECKING: + DateTime = pendulum.DateTime + # for 'plain JSON' columns, use the postgresql variant (which comes with an # extra operator) and fall back to the generic JSON variant for SQLite -sa_JSON = postgresql.JSON().with_variant(sa.JSON(), "sqlite") +sa_JSON: postgresql.JSON = postgresql.JSON().with_variant(sa.JSON(), "sqlite") class Base(DeclarativeBase): @@ -824,7 +820,7 @@ def job_variables(self) -> Mapped[dict[str, Any]]: paused: Mapped[bool] = mapped_column(server_default="0", default=False, index=True) schedules: Mapped[list["DeploymentSchedule"]] = relationship( - lazy="selectin", order_by=sa.desc(sa.text("updated")) + lazy="selectin", order_by=lambda: DeploymentSchedule.updated.desc() ) # deprecated in favor of `concurrency_limit_id` FK @@ -1075,7 +1071,7 @@ class BlockDocumentReference(Base): class Configuration(Base): key: Mapped[str] = mapped_column(index=True) - value: Mapped[Dict[str, Any]] = mapped_column(JSON) + value: Mapped[dict[str, Any]] = mapped_column(JSON) __table_args__: Any = (sa.UniqueConstraint("key"),) @@ -1119,7 +1115,7 @@ class WorkQueue(Base): lazy="selectin", foreign_keys=[work_pool_id] ) - __table_args__ = ( + __table_args__: ClassVar[Any] = ( sa.UniqueConstraint("work_pool_id", "name"), sa.Index("ix_work_queue__work_pool_id_priority", "work_pool_id", "priority"), sa.Index("trgm_ix_work_queue_name", "name", postgresql_using="gin").ddl_if( @@ -1496,7 +1492,7 @@ class BaseORMConfiguration(ABC): Use with caution. """ - def _unique_key(self) -> tuple[Hashable, ...]: + def unique_key(self) -> tuple[Hashable, ...]: """ Returns a key used to determine whether to instantiate a new DB interface. """ diff --git a/src/prefect/server/database/query_components.py b/src/prefect/server/database/query_components.py index 61bc7a9238c41..0f4fe13fa446b 100644 --- a/src/prefect/server/database/query_components.py +++ b/src/prefect/server/database/query_components.py @@ -1,14 +1,15 @@ import datetime -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Hashable, Iterable, Sequence +from functools import cached_property from typing import ( TYPE_CHECKING, - Dict, - Hashable, - List, + Any, + ClassVar, + Literal, + NamedTuple, Optional, - Sequence, - Tuple, Union, cast, ) @@ -16,26 +17,60 @@ import pendulum import sqlalchemy as sa -from cachetools import TTLCache +from cachetools import Cache, TTLCache from jinja2 import Environment, PackageLoader, select_autoescape +from sqlalchemy import orm from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.type_api import TypeEngine +from typing_extensions import TypeVar from prefect.server import models, schemas from prefect.server.database import orm_models +from prefect.server.database.dependencies import db_injector +from prefect.server.database.interface import PrefectDBInterface from prefect.server.exceptions import FlowRunGraphTooLarge, ObjectNotFoundError from prefect.server.schemas.graph import Edge, Graph, GraphArtifact, GraphState, Node +from prefect.server.schemas.states import StateType from prefect.server.utilities.database import UUID as UUIDTypeDecorator -from prefect.server.utilities.database import Timestamp +from prefect.server.utilities.database import Timestamp, bindparams_from_clause + +T = TypeVar("T", infer_variance=True) + + +class FlowRunNotificationsFromQueue(NamedTuple): + queue_id: UUID + flow_run_notification_policy_id: UUID + flow_run_notification_policy_message_template: Optional[str] + block_document_id: UUID + flow_id: UUID + flow_name: str + flow_run_id: UUID + flow_run_name: str + flow_run_parameters: dict[str, Any] + flow_run_state_type: StateType + flow_run_state_name: str + flow_run_state_timestamp: pendulum.DateTime + flow_run_state_message: Optional[str] + + +class FlowRunGraphV2Node(NamedTuple): + kind: Literal["flow-run", "task-run"] + id: UUID + label: str + state_type: StateType + start_time: pendulum.DateTime + end_time: Optional[pendulum.DateTime] + parent_ids: Optional[list[UUID]] + child_ids: Optional[list[UUID]] + encapsulating_ids: Optional[list[UUID]] -if TYPE_CHECKING: - from prefect.server.database.interface import PrefectDBInterface ONE_HOUR = 60 * 60 -jinja_env = Environment( +jinja_env: Environment = Environment( loader=PackageLoader("prefect.server.database", package_path="sql"), autoescape=select_autoescape(), trim_blocks=True, @@ -47,9 +82,11 @@ class BaseQueryComponents(ABC): Abstract base class used to inject dialect-specific SQL operations into Prefect. """ - CONFIGURATION_CACHE = TTLCache(maxsize=100, ttl=ONE_HOUR) + _configuration_cache: ClassVar[Cache[str, dict[str, Any]]] = TTLCache( + maxsize=100, ttl=ONE_HOUR + ) - def _unique_key(self) -> Tuple[Hashable, ...]: + def unique_key(self) -> tuple[Hashable, ...]: """ Returns a key used to determine whether to instantiate a new DB interface. """ @@ -58,25 +95,30 @@ def _unique_key(self) -> Tuple[Hashable, ...]: # --- dialect-specific SqlAlchemy bindings @abstractmethod - def insert(self, obj) -> Union[postgresql.Insert, sqlite.Insert]: + def insert( + self, obj: type[orm_models.Base] + ) -> Union[postgresql.Insert, sqlite.Insert]: """dialect-specific insert statement""" # --- dialect-specific JSON handling - @abstractproperty + @property + @abstractmethod def uses_json_strings(self) -> bool: """specifies whether the configured dialect returns JSON as strings""" @abstractmethod - def cast_to_json(self, json_obj): + def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: """casts to JSON object if necessary""" @abstractmethod - def build_json_object(self, *args): + def build_json_object( + self, *args: Union[str, sa.ColumnElement[Any]] + ) -> sa.ColumnElement[Any]: """builds a JSON object from sequential key-value pairs""" @abstractmethod - def json_arr_agg(self, json_array): + def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: """aggregates a JSON array""" # --- dialect-optimized subqueries @@ -84,34 +126,33 @@ def json_arr_agg(self, json_array): @abstractmethod def make_timestamp_intervals( self, - start_time: datetime.datetime, - end_time: datetime.datetime, + start_time: pendulum.DateTime, + end_time: pendulum.DateTime, interval: datetime.timedelta, - ): + ) -> sa.Select[tuple[pendulum.DateTime, pendulum.DateTime]]: ... @abstractmethod def set_state_id_on_inserted_flow_runs_statement( self, - fr_model, - frs_model, - inserted_flow_run_ids, - insert_flow_run_states, - ): + inserted_flow_run_ids: Sequence[UUID], + insert_flow_run_states: Iterable[dict[str, Any]], + ) -> sa.Update: ... @abstractmethod async def get_flow_run_notifications_from_queue( self, session: AsyncSession, limit: int - ): + ) -> Sequence[FlowRunNotificationsFromQueue]: """Database-specific implementation of reading notifications from the queue and deleting them""" + @db_injector async def queue_flow_run_notifications( self, - session: sa.orm.session, + db: PrefectDBInterface, + session: AsyncSession, flow_run: Union[schemas.core.FlowRun, orm_models.FlowRun], - db: "PrefectDBInterface", - ): + ) -> None: """Database-specific implementation of queueing notifications for a flow run""" def as_array(elems: Sequence[str]) -> sa.ColumnElement[Sequence[str]]: @@ -120,35 +161,36 @@ def as_array(elems: Sequence[str]) -> sa.ColumnElement[Sequence[str]]: if TYPE_CHECKING: assert flow_run.state_name is not None + FlowRunNotificationQueue = db.FlowRunNotificationQueue + FlowRunNotificationPolicy = db.FlowRunNotificationPolicy + # insert a pair into the notification queue - stmt = db.insert(orm_models.FlowRunNotificationQueue).from_select( + stmt = self.insert(FlowRunNotificationQueue).from_select( [ - orm_models.FlowRunNotificationQueue.flow_run_notification_policy_id, - orm_models.FlowRunNotificationQueue.flow_run_state_id, + FlowRunNotificationQueue.flow_run_notification_policy_id, + FlowRunNotificationQueue.flow_run_state_id, ], # ... by selecting from any notification policy that matches the criteria sa.select( - orm_models.FlowRunNotificationPolicy.id, + FlowRunNotificationPolicy.id, sa.cast(sa.literal(str(flow_run.state_id)), UUIDTypeDecorator), ) - .select_from(orm_models.FlowRunNotificationPolicy) + .select_from(FlowRunNotificationPolicy) .where( sa.and_( # the policy is active - orm_models.FlowRunNotificationPolicy.is_active.is_(True), + FlowRunNotificationPolicy.is_active.is_(True), # the policy state names aren't set or match the current state name sa.or_( - orm_models.FlowRunNotificationPolicy.state_names == [], - orm_models.FlowRunNotificationPolicy.state_names.has_any( + FlowRunNotificationPolicy.state_names == [], + FlowRunNotificationPolicy.state_names.has_any( as_array([flow_run.state_name]) ), ), # the policy tags aren't set, or the tags match the flow run tags sa.or_( - orm_models.FlowRunNotificationPolicy.tags == [], - orm_models.FlowRunNotificationPolicy.tags.has_any( - as_array(flow_run.tags) - ), + FlowRunNotificationPolicy.tags == [], + FlowRunNotificationPolicy.tags.has_any(as_array(flow_run.tags)), ), ) ), @@ -158,12 +200,14 @@ def as_array(elems: Sequence[str]) -> sa.ColumnElement[Sequence[str]]: ) await session.execute(stmt) + @db_injector def get_scheduled_flow_runs_from_work_queues( self, + db: PrefectDBInterface, limit_per_queue: Optional[int] = None, - work_queue_ids: Optional[List[UUID]] = None, - scheduled_before: Optional[datetime.datetime] = None, - ): + work_queue_ids: Optional[list[UUID]] = None, + scheduled_before: Optional[pendulum.DateTime] = None, + ) -> sa.Select[tuple[orm_models.FlowRun, UUID]]: """ Returns all scheduled runs in work queues, subject to provided parameters. @@ -172,32 +216,31 @@ def get_scheduled_flow_runs_from_work_queues( will return only the flow run because it grabs the first result. """ + FlowRun, WorkQueue = db.FlowRun, db.WorkQueue + # get any work queues that have a concurrency limit, and compute available # slots as their limit less the number of running flows concurrency_queues = ( sa.select( - orm_models.WorkQueue.id, + WorkQueue.id, sa.func.greatest( 0, - orm_models.WorkQueue.concurrency_limit - - sa.func.count(orm_models.FlowRun.id), + WorkQueue.concurrency_limit - sa.func.count(FlowRun.id), ).label("available_slots"), ) - .select_from(orm_models.WorkQueue) + .select_from(WorkQueue) .join( - orm_models.FlowRun, + FlowRun, sa.and_( - self._flow_run_work_queue_join_clause( - orm_models.FlowRun, orm_models.WorkQueue - ), - orm_models.FlowRun.state_type.in_( - ["RUNNING", "PENDING", "CANCELLING"] + FlowRun.work_queue_name == WorkQueue.name, + FlowRun.state_type.in_( + (StateType.RUNNING, StateType.PENDING, StateType.CANCELLING) ), ), isouter=True, ) - .where(orm_models.WorkQueue.concurrency_limit.is_not(None)) - .group_by(orm_models.WorkQueue.id) + .where(WorkQueue.concurrency_limit.is_not(None)) + .group_by(WorkQueue.id) .cte("concurrency_queues") ) @@ -215,19 +258,18 @@ def get_scheduled_flow_runs_from_work_queues( query = ( # return a flow run and work queue id sa.select( - sa.orm.aliased(orm_models.FlowRun, scheduled_flow_runs), - orm_models.WorkQueue.id.label("wq_id"), + orm.aliased(FlowRun, scheduled_flow_runs), WorkQueue.id.label("wq_id") ) - .select_from(orm_models.WorkQueue) + .select_from(WorkQueue) .join( concurrency_queues, - orm_models.WorkQueue.id == concurrency_queues.c.id, + WorkQueue.id == concurrency_queues.c.id, isouter=True, ) .join(scheduled_flow_runs, join_criteria) .where( - orm_models.WorkQueue.is_paused.is_(False), - orm_models.WorkQueue.id.in_(work_queue_ids) if work_queue_ids else True, + WorkQueue.is_paused.is_(False), + WorkQueue.id.in_(work_queue_ids) if work_queue_ids else sa.true(), ) .order_by( scheduled_flow_runs.c.next_scheduled_start_time, @@ -237,77 +279,73 @@ def get_scheduled_flow_runs_from_work_queues( return query + @db_injector def _get_scheduled_flow_runs_join( self, - work_queue_query, + db: PrefectDBInterface, + work_queue_query: sa.CTE, limit_per_queue: Optional[int], - scheduled_before: Optional[datetime.datetime], - ): + scheduled_before: Optional[pendulum.DateTime], + ) -> tuple[sa.FromClause, sa.ColumnExpressionArgument[bool]]: """Used by self.get_scheduled_flow_runs_from_work_queue, allowing just this function to be changed on a per-dialect basis""" + FlowRun = db.FlowRun + # precompute for readability scheduled_before_clause = ( - orm_models.FlowRun.next_scheduled_start_time <= scheduled_before + FlowRun.next_scheduled_start_time <= scheduled_before if scheduled_before is not None - else True + else sa.true() ) # get scheduled flow runs with lateral join where the limit is the # available slots per queue scheduled_flow_runs = ( - sa.select(orm_models.FlowRun) + sa.select(FlowRun) .where( - self._flow_run_work_queue_join_clause( - orm_models.FlowRun, orm_models.WorkQueue - ), - orm_models.FlowRun.state_type == "SCHEDULED", + FlowRun.work_queue_name == db.WorkQueue.name, + FlowRun.state_type == StateType.SCHEDULED, scheduled_before_clause, ) .with_for_update(skip_locked=True) # priority given to runs with earlier next_scheduled_start_time - .order_by(orm_models.FlowRun.next_scheduled_start_time) + .order_by(FlowRun.next_scheduled_start_time) # if null, no limit will be applied .limit(sa.func.least(limit_per_queue, work_queue_query.c.available_slots)) .lateral("scheduled_flow_runs") ) # Perform a cross-join - join_criteria = sa.literal(True) + join_criteria = sa.true() return scheduled_flow_runs, join_criteria - def _flow_run_work_queue_join_clause(self, flow_run, work_queue): - """ - On clause for for joining flow runs to work queues - - Used by self.get_scheduled_flow_runs_from_work_queue, allowing just - this function to be changed on a per-dialect basis - """ - return sa.and_(flow_run.work_queue_name == work_queue.name) - # ------------------------------------------------------- # Workers # ------------------------------------------------------- - @abstractproperty - def _get_scheduled_flow_runs_from_work_pool_template_path(self): + @property + @abstractmethod + def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: """ Template for the query to get scheduled flow runs from a work pool """ + @db_injector async def get_scheduled_flow_runs_from_work_pool( self, - session, + db: PrefectDBInterface, + session: AsyncSession, limit: Optional[int] = None, worker_limit: Optional[int] = None, queue_limit: Optional[int] = None, - work_pool_ids: Optional[List[UUID]] = None, - work_queue_ids: Optional[List[UUID]] = None, - scheduled_before: Optional[datetime.datetime] = None, - scheduled_after: Optional[datetime.datetime] = None, + work_pool_ids: Optional[list[UUID]] = None, + work_queue_ids: Optional[list[UUID]] = None, + scheduled_before: Optional[pendulum.DateTime] = None, + scheduled_after: Optional[pendulum.DateTime] = None, respect_queue_priorities: bool = False, - ) -> List[schemas.responses.WorkerFlowRunResponse]: + ) -> list[schemas.responses.WorkerFlowRunResponse]: template = jinja_env.get_template( self._get_scheduled_flow_runs_from_work_pool_template_path ) @@ -322,7 +360,7 @@ async def get_scheduled_flow_runs_from_work_pool( ) ) - bindparams = [] + bindparams: list[sa.BindParameter[Any]] = [] if scheduled_before: bindparams.append( @@ -365,153 +403,37 @@ async def get_scheduled_flow_runs_from_work_pool( queue_limit=1000 if queue_limit is None else queue_limit, ) + FlowRun = db.FlowRun orm_query = ( sa.select( - sa.column("run_work_pool_id"), - sa.column("run_work_queue_id"), - orm_models.FlowRun, + sa.column("run_work_pool_id", UUIDTypeDecorator), + sa.column("run_work_queue_id", UUIDTypeDecorator), + FlowRun, ) .from_statement(query) # indicate that the state relationship isn't being loaded - .options(sa.orm.noload(orm_models.FlowRun.state)) + .options(orm.noload(FlowRun.state)) ) - result = await session.execute(orm_query) + result: sa.Result[ + tuple[UUID, UUID, orm_models.FlowRun] + ] = await session.execute(orm_query) return [ schemas.responses.WorkerFlowRunResponse( - work_pool_id=r.run_work_pool_id, - work_queue_id=r.run_work_queue_id, + work_pool_id=run_work_pool_id, + work_queue_id=run_work_queue_id, flow_run=schemas.core.FlowRun.model_validate( - r.FlowRun, from_attributes=True + flow_run, from_attributes=True ), ) - for r in result + for (run_work_pool_id, run_work_queue_id, flow_run) in result.t ] - async def read_block_documents( - self, - session: sa.orm.Session, - block_document_filter: Optional[schemas.filters.BlockDocumentFilter] = None, - block_type_filter: Optional[schemas.filters.BlockTypeFilter] = None, - block_schema_filter: Optional[schemas.filters.BlockSchemaFilter] = None, - include_secrets: bool = False, - offset: Optional[int] = None, - limit: Optional[int] = None, - ): - # if no filter is provided, one is created that excludes anonymous blocks - if block_document_filter is None: - block_document_filter = schemas.filters.BlockDocumentFilter( - is_anonymous=schemas.filters.BlockDocumentFilterIsAnonymous(eq_=False) - ) - - # --- Query for Parent Block Documents - # begin by building a query for only those block documents that are selected - # by the provided filters - filtered_block_documents_query = sa.select(orm_models.BlockDocument.id).where( - block_document_filter.as_sql_filter() - ) - - if block_type_filter is not None: - block_type_exists_clause = sa.select(orm_models.BlockType).where( - orm_models.BlockType.id == orm_models.BlockDocument.block_type_id, - block_type_filter.as_sql_filter(), - ) - filtered_block_documents_query = filtered_block_documents_query.where( - block_type_exists_clause.exists() - ) - - if block_schema_filter is not None: - block_schema_exists_clause = sa.select(orm_models.BlockSchema).where( - orm_models.BlockSchema.id == orm_models.BlockDocument.block_schema_id, - block_schema_filter.as_sql_filter(), - ) - filtered_block_documents_query = filtered_block_documents_query.where( - block_schema_exists_clause.exists() - ) - - if offset is not None: - filtered_block_documents_query = filtered_block_documents_query.offset( - offset - ) - - if limit is not None: - filtered_block_documents_query = filtered_block_documents_query.limit(limit) - - filtered_block_documents_query = filtered_block_documents_query.cte( - "filtered_block_documents" - ) - - # --- Query for Referenced Block Documents - # next build a recursive query for (potentially nested) block documents - # that reference the filtered block documents - block_document_references_query = ( - sa.select(orm_models.BlockDocumentReference) - .filter( - orm_models.BlockDocumentReference.parent_block_document_id.in_( - sa.select(filtered_block_documents_query.c.id) - ) - ) - .cte("block_document_references", recursive=True) - ) - block_document_references_join = sa.select( - orm_models.BlockDocumentReference - ).join( - block_document_references_query, - orm_models.BlockDocumentReference.parent_block_document_id - == block_document_references_query.c.reference_block_document_id, - ) - recursive_block_document_references_cte = ( - block_document_references_query.union_all(block_document_references_join) - ) - - # --- Final Query for All Block Documents - # build a query that unions: - # - the filtered block documents - # - with any block documents that are discovered as (potentially nested) references - all_block_documents_query = sa.union_all( - # first select the parent block - sa.select( - orm_models.BlockDocument, - sa.null().label("reference_name"), - sa.null().label("reference_parent_block_document_id"), - ) - .select_from(orm_models.BlockDocument) - .where( - orm_models.BlockDocument.id.in_( - sa.select(filtered_block_documents_query.c.id) - ) - ), - # - # then select any referenced blocks - sa.select( - orm_models.BlockDocument, - recursive_block_document_references_cte.c.name, - recursive_block_document_references_cte.c.parent_block_document_id, - ) - .select_from(orm_models.BlockDocument) - .join( - recursive_block_document_references_cte, - orm_models.BlockDocument.id - == recursive_block_document_references_cte.c.reference_block_document_id, - ), - ).cte("all_block_documents_query") - - # the final union query needs to be `aliased` for proper ORM unpacking - # and also be sorted - return ( - sa.select( - sa.orm.aliased(orm_models.BlockDocument, all_block_documents_query), - all_block_documents_query.c.reference_name, - all_block_documents_query.c.reference_parent_block_document_id, - ) - .select_from(all_block_documents_query) - .order_by(all_block_documents_query.c.name) - ) - + @db_injector async def read_configuration_value( - self, session: sa.orm.Session, key: str - ) -> Optional[Dict]: + self, db: PrefectDBInterface, session: AsyncSession, key: str + ) -> Optional[dict[str, Any]]: """ Read a configuration value by key. @@ -521,69 +443,156 @@ async def read_configuration_value( The main use of configurations is encrypting blocks, this speeds up nested block document queries. """ + Configuration = db.Configuration + value = None try: - return self.CONFIGURATION_CACHE[key] + value = self._configuration_cache[key] except KeyError: - query = sa.select(orm_models.Configuration).where( - orm_models.Configuration.key == key - ) - result = await session.execute(query) - configuration = result.scalar() - if configuration is not None: - self.CONFIGURATION_CACHE[key] = configuration.value - return configuration.value - return configuration - - def clear_configuration_value_cache_for_key(self, key: str): + query = sa.select(Configuration).where(Configuration.key == key) + if (configuration := await session.scalar(query)) is not None: + value = self._configuration_cache[key] = configuration.value + return value + + def clear_configuration_value_cache_for_key(self, key: str) -> None: """Removes a configuration key from the cache.""" - self.CONFIGURATION_CACHE.pop(key, None) + self._configuration_cache.pop(key, None) + + @cached_property + def _flow_run_graph_v2_query(self): + query = self._build_flow_run_graph_v2_query() + param_names = set(bindparams_from_clause(query)) + required = {"flow_run_id", "max_nodes", "since"} + assert param_names >= required, ( + "_build_flow_run_graph_v2_query result is missing required bind params: " + f"{sorted(required - param_names)}" + ) + return query @abstractmethod + def _build_flow_run_graph_v2_query(self) -> sa.Select[FlowRunGraphV2Node]: + """The flow run graph query, per database flavour + + The query must accept the following bind parameters: + + flow_run_id: UUID + since: pendulum.DateTime + max_nodes: int + + """ + + @db_injector async def flow_run_graph_v2( self, + db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID, - since: datetime.datetime, + since: pendulum.DateTime, max_nodes: int, max_artifacts: int, ) -> Graph: """Returns the query that selects all of the nodes and edges for a flow run graph (version 2).""" - ... + FlowRun = db.FlowRun + result = await session.execute( + sa.select( + sa.func.coalesce( + FlowRun.start_time, FlowRun.expected_start_time, type_=Timestamp + ), + FlowRun.end_time, + ).where(FlowRun.id == flow_run_id) + ) + try: + start_time, end_time = result.t.one() + except NoResultFound: + raise ObjectNotFoundError(f"Flow run {flow_run_id} not found") + + query = self._flow_run_graph_v2_query + results = await session.execute( + query, + params=dict(flow_run_id=flow_run_id, since=since, max_nodes=max_nodes + 1), + ) + + graph_artifacts = await self._get_flow_run_graph_artifacts( + db, session, flow_run_id, max_artifacts + ) + graph_states = await self._get_flow_run_graph_states(session, flow_run_id) + + nodes: list[tuple[UUID, Node]] = [] + root_node_ids: list[UUID] = [] + + for row in results.t: + if not row.parent_ids: + root_node_ids.append(row.id) + + nodes.append( + ( + row.id, + Node( + kind=row.kind, + id=row.id, + label=row.label, + state_type=row.state_type, + start_time=row.start_time, + end_time=row.end_time, + parents=[Edge(id=id) for id in row.parent_ids or []], + children=[Edge(id=id) for id in row.child_ids or []], + encapsulating=[ + Edge(id=id) + # ensure encapsulating_ids is deduplicated + for id in dict.fromkeys(row.encapsulating_ids or ()) + ], + artifacts=graph_artifacts.get(row.id, []), + ), + ) + ) + + if len(nodes) > max_nodes: + raise FlowRunGraphTooLarge( + f"The graph of flow run {flow_run_id} has more than " + f"{max_nodes} nodes." + ) + + return Graph( + start_time=start_time, + end_time=end_time, + root_node_ids=root_node_ids, + nodes=nodes, + artifacts=graph_artifacts.get(None, []), + states=graph_states, + ) async def _get_flow_run_graph_artifacts( self, + db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID, max_artifacts: int, - ): + ) -> dict[Optional[UUID], list[GraphArtifact]]: """Get the artifacts for a flow run grouped by task run id. Does not recurse into subflows. Artifacts for the flow run without a task run id are grouped under None. """ + Artifact, ArtifactCollection = db.Artifact, db.ArtifactCollection + query = ( - sa.select( - orm_models.Artifact, - orm_models.ArtifactCollection.id.label("latest_in_collection_id"), - ) - .where( - orm_models.Artifact.flow_run_id == flow_run_id, - orm_models.Artifact.type != "result", - ) + sa.select(Artifact, ArtifactCollection.id.label("latest_in_collection_id")) + .where(Artifact.flow_run_id == flow_run_id, Artifact.type != "result") .join( - orm_models.ArtifactCollection, - (orm_models.ArtifactCollection.key == orm_models.Artifact.key) - & (orm_models.ArtifactCollection.latest_id == orm_models.Artifact.id), + ArtifactCollection, + onclause=sa.and_( + ArtifactCollection.key == Artifact.key, + ArtifactCollection.latest_id == Artifact.id, + ), isouter=True, ) - .order_by(orm_models.Artifact.created.asc()) + .order_by(Artifact.created.asc()) .limit(max_artifacts) ) results = await session.execute(query) - artifacts_by_task = defaultdict(list) - for artifact, latest_in_collection_id in results: + artifacts_by_task: dict[Optional[UUID], list[GraphArtifact]] = defaultdict(list) + for artifact, latest_in_collection_id in results.t: artifacts_by_task[artifact.task_run_id].append( GraphArtifact( id=artifact.id, @@ -597,458 +606,402 @@ async def _get_flow_run_graph_artifacts( ) ) - return artifacts_by_task + return dict(artifacts_by_task) async def _get_flow_run_graph_states( - self, - session: AsyncSession, - flow_run_id: UUID, - ): + self, session: AsyncSession, flow_run_id: UUID + ) -> list[GraphState]: """Get the flow run states for a flow run graph.""" - flow_run_states = await models.flow_run_states.read_flow_run_states( - session=session, flow_run_id=flow_run_id - ) - + states = await models.flow_run_states.read_flow_run_states(session, flow_run_id) return [ - GraphState( - id=state.id, - timestamp=state.timestamp, - type=state.type, - name=state.name, - ) - for state in flow_run_states + GraphState.model_validate(state, from_attributes=True) for state in states ] class AsyncPostgresQueryComponents(BaseQueryComponents): # --- Postgres-specific SqlAlchemy bindings - def insert(self, obj) -> postgresql.Insert: + def insert(self, obj: type[orm_models.Base]) -> postgresql.Insert: return postgresql.insert(obj) # --- Postgres-specific JSON handling @property - def uses_json_strings(self): + def uses_json_strings(self) -> bool: return False - def cast_to_json(self, json_obj): + def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: return json_obj - def build_json_object(self, *args): + def build_json_object( + self, *args: Union[str, sa.ColumnElement[Any]] + ) -> sa.ColumnElement[Any]: return sa.func.jsonb_build_object(*args) - def json_arr_agg(self, json_array): + def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: return sa.func.jsonb_agg(json_array) # --- Postgres-optimized subqueries def make_timestamp_intervals( self, - start_time: datetime.datetime, - end_time: datetime.datetime, + start_time: pendulum.DateTime, + end_time: pendulum.DateTime, interval: datetime.timedelta, - ): - # validate inputs - start_time = pendulum.instance(start_time) - end_time = pendulum.instance(end_time) - assert isinstance(interval, datetime.timedelta) + ) -> sa.Select[tuple[pendulum.DateTime, pendulum.DateTime]]: + dt = sa.func.generate_series( + start_time, end_time, interval, type_=Timestamp() + ).column_valued("dt") return ( sa.select( - sa.literal_column("dt").label("interval_start"), - (sa.literal_column("dt") + interval).label("interval_end"), + dt.label("interval_start"), + sa.type_coerce( + dt + sa.bindparam("interval", interval, type_=sa.Interval()), + type_=Timestamp(), + ).label("interval_end"), ) - .select_from( - sa.func.generate_series(start_time, end_time, interval).alias("dt") - ) - .where(sa.literal_column("dt") < end_time) - # grab at most 500 intervals - .limit(500) + .where(dt < end_time) + .limit(500) # grab at most 500 intervals ) + @db_injector def set_state_id_on_inserted_flow_runs_statement( self, - fr_model, - frs_model, - inserted_flow_run_ids, - insert_flow_run_states, - ): + db: PrefectDBInterface, + inserted_flow_run_ids: Sequence[UUID], + insert_flow_run_states: Iterable[dict[str, Any]], + ) -> sa.Update: """Given a list of flow run ids and associated states, set the state_id to the appropriate state for all flow runs""" # postgres supports `UPDATE ... FROM` syntax + FlowRun, FlowRunState = db.FlowRun, db.FlowRunState stmt = ( - sa.update(fr_model) + sa.update(FlowRun) .where( - fr_model.id.in_(inserted_flow_run_ids), - frs_model.flow_run_id == fr_model.id, - frs_model.id.in_([r["id"] for r in insert_flow_run_states]), + FlowRun.id.in_(inserted_flow_run_ids), + FlowRunState.flow_run_id == FlowRun.id, + FlowRunState.id.in_([r["id"] for r in insert_flow_run_states]), ) - .values(state_id=frs_model.id) + .values(state_id=FlowRunState.id) # no need to synchronize as these flow runs are entirely new .execution_options(synchronize_session=False) ) return stmt + @db_injector async def get_flow_run_notifications_from_queue( - self, session: AsyncSession, limit: int - ) -> List: + self, db: PrefectDBInterface, session: AsyncSession, limit: int + ) -> Sequence[FlowRunNotificationsFromQueue]: + Flow, FlowRun = db.Flow, db.FlowRun + FlowRunNotificationPolicy = db.FlowRunNotificationPolicy + FlowRunNotificationQueue = db.FlowRunNotificationQueue + FlowRunState = db.FlowRunState # including this as a subquery in the where clause of the # `queued_notifications` statement below, leads to errors where the limit # is not respected if it is 1. pulling this out into a CTE statement # prevents this. see link for more details: # https://www.postgresql.org/message-id/16497.1553640836%40sss.pgh.pa.us queued_notifications_ids = ( - sa.select(orm_models.FlowRunNotificationQueue.id) - .select_from(orm_models.FlowRunNotificationQueue) - .order_by(orm_models.FlowRunNotificationQueue.updated) + sa.select(FlowRunNotificationQueue.id) + .select_from(FlowRunNotificationQueue) + .order_by(FlowRunNotificationQueue.updated) .limit(limit) .with_for_update(skip_locked=True) ).cte("queued_notifications_ids") queued_notifications = ( - sa.delete(orm_models.FlowRunNotificationQueue) + sa.delete(FlowRunNotificationQueue) .returning( - orm_models.FlowRunNotificationQueue.id, - orm_models.FlowRunNotificationQueue.flow_run_notification_policy_id, - orm_models.FlowRunNotificationQueue.flow_run_state_id, - ) - .where( - orm_models.FlowRunNotificationQueue.id.in_( - sa.select(queued_notifications_ids) - ) + FlowRunNotificationQueue.id, + FlowRunNotificationQueue.flow_run_notification_policy_id, + FlowRunNotificationQueue.flow_run_state_id, ) + .where(FlowRunNotificationQueue.id.in_(sa.select(queued_notifications_ids))) .cte("queued_notifications") ) - notification_details_stmt = ( + notification_details_stmt: sa.Select[FlowRunNotificationsFromQueue] = ( sa.select( queued_notifications.c.id.label("queue_id"), - orm_models.FlowRunNotificationPolicy.id.label( - "flow_run_notification_policy_id" - ), - orm_models.FlowRunNotificationPolicy.message_template.label( + FlowRunNotificationPolicy.id.label("flow_run_notification_policy_id"), + FlowRunNotificationPolicy.message_template.label( "flow_run_notification_policy_message_template" ), - orm_models.FlowRunNotificationPolicy.block_document_id, - orm_models.Flow.id.label("flow_id"), - orm_models.Flow.name.label("flow_name"), - orm_models.FlowRun.id.label("flow_run_id"), - orm_models.FlowRun.name.label("flow_run_name"), - orm_models.FlowRun.parameters.label("flow_run_parameters"), - orm_models.FlowRunState.type.label("flow_run_state_type"), - orm_models.FlowRunState.name.label("flow_run_state_name"), - orm_models.FlowRunState.timestamp.label("flow_run_state_timestamp"), - orm_models.FlowRunState.message.label("flow_run_state_message"), + FlowRunNotificationPolicy.block_document_id, + Flow.id.label("flow_id"), + Flow.name.label("flow_name"), + FlowRun.id.label("flow_run_id"), + FlowRun.name.label("flow_run_name"), + FlowRun.parameters.label("flow_run_parameters"), + FlowRunState.type.label("flow_run_state_type"), + FlowRunState.name.label("flow_run_state_name"), + FlowRunState.timestamp.label("flow_run_state_timestamp"), + FlowRunState.message.label("flow_run_state_message"), ) .select_from(queued_notifications) .join( - orm_models.FlowRunNotificationPolicy, + FlowRunNotificationPolicy, queued_notifications.c.flow_run_notification_policy_id - == orm_models.FlowRunNotificationPolicy.id, - ) - .join( - orm_models.FlowRunState, - queued_notifications.c.flow_run_state_id == orm_models.FlowRunState.id, - ) - .join( - orm_models.FlowRun, - orm_models.FlowRunState.flow_run_id == orm_models.FlowRun.id, + == FlowRunNotificationPolicy.id, ) .join( - orm_models.Flow, - orm_models.FlowRun.flow_id == orm_models.Flow.id, + FlowRunState, + queued_notifications.c.flow_run_state_id == FlowRunState.id, ) + .join(FlowRun, FlowRunState.flow_run_id == FlowRun.id) + .join(Flow, FlowRun.flow_id == Flow.id) ) result = await session.execute(notification_details_stmt) - return result.fetchall() + return result.t.fetchall() @property - def _get_scheduled_flow_runs_from_work_pool_template_path(self): + def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: """ Template for the query to get scheduled flow runs from a work pool """ return "postgres/get-runs-from-worker-queues.sql.jinja" - async def flow_run_graph_v2( - self, - session: AsyncSession, - flow_run_id: UUID, - since: datetime.datetime, - max_nodes: int, - max_artifacts: int, - ) -> Graph: - """Returns the query that selects all of the nodes and edges for a flow run - graph (version 2).""" - result = await session.execute( + @db_injector + def _build_flow_run_graph_v2_query( + self, db: PrefectDBInterface + ) -> sa.Select[FlowRunGraphV2Node]: + """Postgresql version of the V2 FlowRun graph data query + + This SQLA query is built just once and then cached per DB interface + + """ + # the parameters this query takes as inputs + param_flow_run_id = sa.bindparam("flow_run_id", type_=UUIDTypeDecorator) + param_since = sa.bindparam("since", type_=Timestamp) + param_max_nodes = sa.bindparam("max_nodes", type_=sa.Integer) + + Flow, FlowRun, TaskRun = db.Flow, db.FlowRun, db.TaskRun + input = sa.func.jsonb_each(TaskRun.task_inputs).table_valued( + "key", "value", name="input" + ) + argument = ( + sa.func.jsonb_array_elements(input.c.value, type_=postgresql.JSONB()) + .table_valued(sa.column("value", postgresql.JSONB())) + .render_derived(name="argument") + ) + edges = ( sa.select( + sa.case((FlowRun.id.is_not(None), "flow-run"), else_="task-run").label( + "kind" + ), + sa.func.coalesce(FlowRun.id, TaskRun.id).label("id"), + sa.func.coalesce(Flow.name + " / " + FlowRun.name, TaskRun.name).label( + "label" + ), + sa.func.coalesce(FlowRun.state_type, TaskRun.state_type).label( + "state_type" + ), + sa.func.coalesce( + FlowRun.start_time, + FlowRun.expected_start_time, + TaskRun.start_time, + TaskRun.expected_start_time, + ).label("start_time"), sa.func.coalesce( - orm_models.FlowRun.start_time, - orm_models.FlowRun.expected_start_time, + FlowRun.end_time, + TaskRun.end_time, + sa.case( + ( + TaskRun.state_type == StateType.COMPLETED, + TaskRun.expected_start_time, + ), + else_=sa.null(), + ), + ).label("end_time"), + sa.cast(argument.c.value["id"].astext, type_=UUIDTypeDecorator).label( + "parent" ), - orm_models.FlowRun.end_time, - ).where( - orm_models.FlowRun.id == flow_run_id, + (input.c.key == "__parents__").label("has_encapsulating_task"), + ) + .join_from(TaskRun, input, onclause=sa.true(), isouter=True) + .join(argument, onclause=sa.true(), isouter=True) + .join( + FlowRun, + isouter=True, + onclause=FlowRun.parent_task_run_id == TaskRun.id, + ) + .join(Flow, isouter=True, onclause=Flow.id == FlowRun.flow_id) + .where( + TaskRun.flow_run_id == param_flow_run_id, + TaskRun.state_type != StateType.PENDING, + sa.func.coalesce( + FlowRun.start_time, + FlowRun.expected_start_time, + TaskRun.start_time, + TaskRun.expected_start_time, + ).is_not(None), + ) + # -- the order here is important to speed up building the two sets of + # -- edges in the with_parents and with_children CTEs below + .order_by(sa.func.coalesce(FlowRun.id, TaskRun.id)) + ).cte("edges") + children, parents = edges.alias("children"), edges.alias("parents") + with_encapsulating = ( + sa.select( + children.c.id, + sa.func.array_agg( + postgresql.aggregate_order_by(parents.c.id, parents.c.start_time) + ).label("encapsulating_ids"), + ) + .join(parents, onclause=parents.c.id == children.c.parent) + .where(children.c.has_encapsulating_task.is_(True)) + .group_by(children.c.id) + ).cte("with_encapsulating") + with_parents = ( + sa.select( + children.c.id, + sa.func.array_agg( + postgresql.aggregate_order_by(parents.c.id, parents.c.start_time) + ).label("parent_ids"), ) + .join(parents, onclause=parents.c.id == children.c.parent) + .where(children.c.has_encapsulating_task.is_distinct_from(True)) + .group_by(children.c.id) + .cte("with_parents") ) - try: - start_time, end_time = result.one() - except NoResultFound: - raise ObjectNotFoundError(f"Flow run {flow_run_id} not found") - - query = sa.text( - """ - WITH - edges AS ( - SELECT CASE - WHEN subflow.id IS NOT NULL THEN 'flow-run' - ELSE 'task-run' - END as kind, - COALESCE(subflow.id, task_run.id) as id, - COALESCE(flow.name || ' / ' || subflow.name, task_run.name) as label, - COALESCE(subflow.state_type, task_run.state_type) as state_type, - COALESCE( - subflow.start_time, - subflow.expected_start_time, - task_run.start_time, - task_run.expected_start_time - ) as start_time, - COALESCE( - subflow.end_time, - task_run.end_time, - CASE - WHEN task_run.state_type = 'COMPLETED' - THEN task_run.expected_start_time - ELSE NULL - END - ) as end_time, - (argument->>'id')::uuid as parent, - input.key = '__parents__' as has_encapsulating_task - FROM task_run - LEFT JOIN jsonb_each(task_run.task_inputs) as input ON true - LEFT JOIN jsonb_array_elements(input.value) as argument ON true - LEFT JOIN flow_run as subflow - ON subflow.parent_task_run_id = task_run.id - LEFT JOIN flow - ON flow.id = subflow.flow_id - WHERE task_run.flow_run_id = :flow_run_id AND - task_run.state_type <> 'PENDING' AND - COALESCE( - subflow.start_time, - subflow.expected_start_time, - task_run.start_time, - task_run.expected_start_time - ) IS NOT NULL - - -- the order here is important to speed up building the two sets of - -- edges in the with_parents and with_children CTEs below - ORDER BY COALESCE(subflow.id, task_run.id) - ), - with_encapsulating AS ( - SELECT children.id, - array_agg(parents.id order by parents.start_time) as encapsulating_ids - FROM edges as children - INNER JOIN edges as parents - ON parents.id = children.parent - WHERE children.has_encapsulating_task is True - GROUP BY children.id - ), - with_parents AS ( - SELECT children.id, - array_agg(parents.id order by parents.start_time) as parent_ids - FROM edges as children - INNER JOIN edges as parents - ON parents.id = children.parent - WHERE children.has_encapsulating_task is FALSE OR children.has_encapsulating_task is NULL - GROUP BY children.id - ), - with_children AS ( - SELECT parents.id, - array_agg(children.id order by children.start_time) as child_ids - FROM edges as parents - INNER JOIN edges as children - ON children.parent = parents.id - WHERE children.has_encapsulating_task is FALSE OR children.has_encapsulating_task is NULL - GROUP BY parents.id - ), - nodes AS ( - SELECT DISTINCT ON (edges.id) - edges.kind, - edges.id, - edges.label, - edges.state_type, - edges.start_time, - edges.end_time, - with_parents.parent_ids, - with_children.child_ids, - with_encapsulating.encapsulating_ids - FROM edges - LEFT JOIN with_parents - ON with_parents.id = edges.id - LEFT JOIN with_children - ON with_children.id = edges.id - LEFT JOIN with_encapsulating - ON with_encapsulating.id = edges.id + with_children = ( + sa.select( + parents.c.id, + sa.func.array_agg( + postgresql.aggregate_order_by(children.c.id, children.c.start_time) + ).label("child_ids"), ) - SELECT kind, - id, - label, - state_type, - start_time, - end_time, - parent_ids, - child_ids, - encapsulating_ids - FROM nodes - WHERE end_time IS NULL OR end_time >= :since - ORDER BY start_time, end_time - LIMIT :max_nodes - ; - """ + .join(children, onclause=children.c.parent == parents.c.id) + .where(children.c.has_encapsulating_task.is_distinct_from(True)) + .group_by(parents.c.id) + .cte("with_children") ) - query = query.bindparams( - sa.bindparam("flow_run_id", value=flow_run_id), - sa.bindparam("since", value=since), - sa.bindparam("max_nodes", value=max_nodes + 1), + graph = ( + sa.select( + edges.c.kind, + edges.c.id, + edges.c.label, + edges.c.state_type, + edges.c.start_time, + edges.c.end_time, + with_parents.c.parent_ids, + with_children.c.child_ids, + with_encapsulating.c.encapsulating_ids, + ) + .distinct(edges.c.id) + .join(with_parents, isouter=True, onclause=with_parents.c.id == edges.c.id) + .join( + with_children, isouter=True, onclause=with_children.c.id == edges.c.id + ) + .join( + with_encapsulating, + isouter=True, + onclause=with_encapsulating.c.id == edges.c.id, + ) + .cte("nodes") ) - - results = await session.execute(query) - - graph_artifacts = await self._get_flow_run_graph_artifacts( - session, flow_run_id, max_artifacts + query = ( + sa.select( + graph.c.kind, + graph.c.id, + graph.c.label, + graph.c.state_type, + graph.c.start_time, + graph.c.end_time, + graph.c.parent_ids, + graph.c.child_ids, + graph.c.encapsulating_ids, + ) + .where(sa.or_(graph.c.end_time.is_(None), graph.c.end_time >= param_since)) + .order_by(graph.c.start_time, graph.c.end_time) + .limit(param_max_nodes) ) - graph_states = await self._get_flow_run_graph_states(session, flow_run_id) + return cast(sa.Select[FlowRunGraphV2Node], query) - nodes: List[Tuple[UUID, Node]] = [] - root_node_ids: List[UUID] = [] - - for row in results: - if not row.parent_ids: - root_node_ids.append(row.id) - nodes.append( - ( - row.id, - Node( - kind=row.kind, - id=row.id, - label=row.label, - state_type=row.state_type, - start_time=row.start_time, - end_time=row.end_time, - parents=[Edge(id=id) for id in row.parent_ids or []], - children=[Edge(id=id) for id in row.child_ids or []], - # ensure encapsulating_ids is deduplicated - # so parents only show up once - encapsulating=[ - Edge(id=id) - for id in ( - list(set(row.encapsulating_ids)) - if row.encapsulating_ids - else [] - ) - ], - artifacts=graph_artifacts.get(row.id, []), - ), - ) - ) +class UUIDList(sa.TypeDecorator[list[UUID]]): + """Map a JSON list of strings back to a list of UUIDs at the result loading stage""" - if len(nodes) > max_nodes: - raise FlowRunGraphTooLarge( - f"The graph of flow run {flow_run_id} has more than " - f"{max_nodes} nodes." - ) + impl: Union[TypeEngine[Any], type[TypeEngine[Any]]] = sa.JSON() - return Graph( - start_time=start_time, - end_time=end_time, - root_node_ids=root_node_ids, - nodes=nodes, - artifacts=graph_artifacts.get(None, []), - states=graph_states, - ) + def process_result_value( + self, value: Optional[list[Union[str, UUID]]], dialect: sa.Dialect + ) -> Optional[list[UUID]]: + if value is None: + return value + return [v if isinstance(v, UUID) else UUID(v) for v in value] class AioSqliteQueryComponents(BaseQueryComponents): # --- Sqlite-specific SqlAlchemy bindings - def insert(self, obj) -> sqlite.Insert: + def insert(self, obj: type[orm_models.Base]) -> sqlite.Insert: return sqlite.insert(obj) # --- Sqlite-specific JSON handling @property - def uses_json_strings(self): + def uses_json_strings(self) -> bool: return True - def cast_to_json(self, json_obj): + def cast_to_json(self, json_obj: sa.ColumnElement[T]) -> sa.ColumnElement[T]: return sa.func.json(json_obj) - def build_json_object(self, *args): + def build_json_object( + self, *args: Union[str, sa.ColumnElement[Any]] + ) -> sa.ColumnElement[Any]: return sa.func.json_object(*args) - def json_arr_agg(self, json_array): + def json_arr_agg(self, json_array: sa.ColumnElement[Any]) -> sa.ColumnElement[Any]: return sa.func.json_group_array(json_array) # --- Sqlite-optimized subqueries def make_timestamp_intervals( self, - start_time: datetime.datetime, - end_time: datetime.datetime, + start_time: pendulum.DateTime, + end_time: pendulum.DateTime, interval: datetime.timedelta, - ): - from prefect.server.utilities.database import Timestamp - - # validate inputs - start_time = pendulum.instance(start_time) - end_time = pendulum.instance(end_time) - assert isinstance(interval, datetime.timedelta) - - return ( - sa.text( - r""" - -- recursive CTE to mimic the behavior of `generate_series`, - -- which is only available as a compiled extension - WITH RECURSIVE intervals(interval_start, interval_end, counter) AS ( - VALUES( - strftime('%Y-%m-%d %H:%M:%f000', :start_time), - strftime('%Y-%m-%d %H:%M:%f000', :start_time, :interval), - 1 - ) - - UNION ALL - - SELECT interval_end, strftime('%Y-%m-%d %H:%M:%f000', interval_end, :interval), counter + 1 - FROM intervals - -- subtract interval because recursive where clauses are effectively evaluated on a t-1 lag - WHERE - interval_start < strftime('%Y-%m-%d %H:%M:%f000', :end_time, :negative_interval) - -- don't compute more than 500 intervals - AND counter < 500 - ) - SELECT * FROM intervals - """ - ) - .bindparams( - start_time=str(start_time), - end_time=str(end_time), - interval=f"+{interval.total_seconds()} seconds", - negative_interval=f"-{interval.total_seconds()} seconds", - ) - .columns(interval_start=Timestamp(), interval_end=Timestamp()) + ) -> sa.Select[tuple[pendulum.DateTime, pendulum.DateTime]]: + start = sa.bindparam("start_time", start_time, Timestamp) + # subtract interval because recursive where clauses are effectively evaluated on a t-1 lag + stop = sa.bindparam("end_time", end_time - interval, Timestamp) + step = sa.bindparam("interval", interval, sa.Interval) + + one = sa.literal(1, literal_execute=True) + + # recursive CTE to mimic the behavior of `generate_series`, which is + # only available as a compiled extension + base_case = sa.select( + start.label("interval_start"), + sa.func.date_add(start, step).label("interval_end"), + one.label("counter"), + ).cte(recursive=True) + recursive_case = sa.select( + base_case.c.interval_end, + sa.func.date_add(base_case.c.interval_end, step), + base_case.c.counter + one, + ).where( + base_case.c.interval_start < stop, + # don't compute more than 500 intervals + base_case.c.counter < 500, ) + cte = base_case.union_all(recursive_case) + + return sa.select(cte.c.interval_start, cte.c.interval_end) + @db_injector def set_state_id_on_inserted_flow_runs_statement( self, - fr_model, - frs_model, - inserted_flow_run_ids, - insert_flow_run_states, - ): + db: PrefectDBInterface, + inserted_flow_run_ids: Sequence[UUID], + insert_flow_run_states: Iterable[dict[str, Any]], + ) -> sa.Update: """Given a list of flow run ids and associated states, set the state_id to the appropriate state for all flow runs""" + fr_model, frs_model = db.FlowRun, db.FlowRunState # sqlite requires a correlated subquery to update from another table subquery = ( sa.select(frs_model.id) @@ -1070,9 +1023,10 @@ def set_state_id_on_inserted_flow_runs_statement( ) return stmt + @db_injector async def get_flow_run_notifications_from_queue( - self, session: AsyncSession, limit: int - ) -> List: + self, db: PrefectDBInterface, session: AsyncSession, limit: int + ) -> Sequence[FlowRunNotificationsFromQueue]: """ Sqlalchemy has no support for DELETE RETURNING in sqlite (as of May 2022) so instead we issue two queries; one to get queued notifications and a second to delete @@ -1080,60 +1034,52 @@ async def get_flow_run_notifications_from_queue( running. """ - notification_details_stmt = ( + Flow, FlowRun = db.Flow, db.FlowRun + FlowRunNotificationPolicy = db.FlowRunNotificationPolicy + FlowRunNotificationQueue = db.FlowRunNotificationQueue + FlowRunState = db.FlowRunState + + notification_details_stmt: sa.Select[FlowRunNotificationsFromQueue] = ( sa.select( - orm_models.FlowRunNotificationQueue.id.label("queue_id"), - orm_models.FlowRunNotificationPolicy.id.label( - "flow_run_notification_policy_id" - ), - orm_models.FlowRunNotificationPolicy.message_template.label( + FlowRunNotificationQueue.id.label("queue_id"), + FlowRunNotificationPolicy.id.label("flow_run_notification_policy_id"), + FlowRunNotificationPolicy.message_template.label( "flow_run_notification_policy_message_template" ), - orm_models.FlowRunNotificationPolicy.block_document_id, - orm_models.Flow.id.label("flow_id"), - orm_models.Flow.name.label("flow_name"), - orm_models.FlowRun.id.label("flow_run_id"), - orm_models.FlowRun.name.label("flow_run_name"), - orm_models.FlowRun.parameters.label("flow_run_parameters"), - orm_models.FlowRunState.type.label("flow_run_state_type"), - orm_models.FlowRunState.name.label("flow_run_state_name"), - orm_models.FlowRunState.timestamp.label("flow_run_state_timestamp"), - orm_models.FlowRunState.message.label("flow_run_state_message"), + FlowRunNotificationPolicy.block_document_id, + Flow.id.label("flow_id"), + Flow.name.label("flow_name"), + FlowRun.id.label("flow_run_id"), + FlowRun.name.label("flow_run_name"), + FlowRun.parameters.label("flow_run_parameters"), + FlowRunState.type.label("flow_run_state_type"), + FlowRunState.name.label("flow_run_state_name"), + FlowRunState.timestamp.label("flow_run_state_timestamp"), + FlowRunState.message.label("flow_run_state_message"), ) - .select_from(orm_models.FlowRunNotificationQueue) + .select_from(FlowRunNotificationQueue) .join( - orm_models.FlowRunNotificationPolicy, - orm_models.FlowRunNotificationQueue.flow_run_notification_policy_id - == orm_models.FlowRunNotificationPolicy.id, + FlowRunNotificationPolicy, + FlowRunNotificationQueue.flow_run_notification_policy_id + == FlowRunNotificationPolicy.id, ) .join( - orm_models.FlowRunState, - orm_models.FlowRunNotificationQueue.flow_run_state_id - == orm_models.FlowRunState.id, + FlowRunState, + FlowRunNotificationQueue.flow_run_state_id == FlowRunState.id, ) - .join( - orm_models.FlowRun, - orm_models.FlowRunState.flow_run_id == orm_models.FlowRun.id, - ) - .join( - orm_models.Flow, - orm_models.FlowRun.flow_id == orm_models.Flow.id, - ) - .order_by(orm_models.FlowRunNotificationQueue.updated) + .join(FlowRun, FlowRunState.flow_run_id == FlowRun.id) + .join(Flow, FlowRun.flow_id == Flow.id) + .order_by(FlowRunNotificationQueue.updated) .limit(limit) ) result = await session.execute(notification_details_stmt) - notifications = result.fetchall() + notifications = result.t.fetchall() # delete the notifications delete_stmt = ( - sa.delete(orm_models.FlowRunNotificationQueue) - .where( - orm_models.FlowRunNotificationQueue.id.in_( - [n.queue_id for n in notifications] - ) - ) + sa.delete(FlowRunNotificationQueue) + .where(FlowRunNotificationQueue.id.in_([n.queue_id for n in notifications])) .execution_options(synchronize_session="fetch") ) @@ -1141,31 +1087,21 @@ async def get_flow_run_notifications_from_queue( return notifications - async def _handle_filtered_block_document_ids( - self, session, filtered_block_documents_query - ): - """ - On SQLite, including the filtered block document parameters confuses the - compiler and it passes positional parameters in the wrong order (it is - unclear why; SQLalchemy manual compilation works great. Switching to - `named` paramstyle also works but fails elsewhere in the codebase). To - resolve this, we materialize the filtered id query into a literal set of - IDs rather than leaving it as a SQL select. - """ - result = await session.execute(filtered_block_documents_query) - return result.scalars().all() - + @db_injector def _get_scheduled_flow_runs_join( self, - work_queue_query, + db: PrefectDBInterface, + work_queue_query: sa.CTE, limit_per_queue: Optional[int], - scheduled_before: Optional[datetime.datetime], - ): + scheduled_before: Optional[pendulum.DateTime], + ) -> tuple[sa.FromClause, sa.ColumnExpressionArgument[bool]]: # precompute for readability + FlowRun = db.FlowRun + scheduled_before_clause = ( - orm_models.FlowRun.next_scheduled_start_time <= scheduled_before + FlowRun.next_scheduled_start_time <= scheduled_before if scheduled_before is not None - else True + else sa.true() ) # select scheduled flow runs, ordered by scheduled start time per queue @@ -1174,17 +1110,14 @@ def _get_scheduled_flow_runs_join( ( sa.func.row_number() .over( - partition_by=[orm_models.FlowRun.work_queue_name], - order_by=orm_models.FlowRun.next_scheduled_start_time, + partition_by=[FlowRun.work_queue_name], + order_by=FlowRun.next_scheduled_start_time, ) .label("rank") ), - orm_models.FlowRun, - ) - .where( - orm_models.FlowRun.state_type == "SCHEDULED", - scheduled_before_clause, + FlowRun, ) + .where(FlowRun.state_type == StateType.SCHEDULED, scheduled_before_clause) .subquery("scheduled_flow_runs") ) @@ -1195,9 +1128,7 @@ def _get_scheduled_flow_runs_join( # in the join, only keep flow runs whose rank is less than or equal to the # available slots for each queue join_criteria = sa.and_( - self._flow_run_work_queue_join_clause( - scheduled_flow_runs.c, orm_models.WorkQueue - ), + scheduled_flow_runs.c.work_queue_name == db.WorkQueue.name, scheduled_flow_runs.c.rank <= sa.func.min( sa.func.coalesce(work_queue_query.c.available_slots, limit), limit @@ -1210,239 +1141,156 @@ def _get_scheduled_flow_runs_join( # ------------------------------------------------------- @property - def _get_scheduled_flow_runs_from_work_pool_template_path(self): + def _get_scheduled_flow_runs_from_work_pool_template_path(self) -> str: """ Template for the query to get scheduled flow runs from a work pool """ return "sqlite/get-runs-from-worker-queues.sql.jinja" - async def flow_run_graph_v2( - self, - session: AsyncSession, - flow_run_id: UUID, - since: datetime.datetime, - max_nodes: int, - max_artifacts: int, - ) -> Graph: - """Returns the query that selects all of the nodes and edges for a flow run - graph (version 2).""" - result = await session.execute( + @db_injector + def _build_flow_run_graph_v2_query( + self, db: PrefectDBInterface + ) -> sa.Select[FlowRunGraphV2Node]: + """Postgresql version of the V2 FlowRun graph data query + + This SQLA query is built just once and then cached per DB interface + + """ + # the parameters this query takes as inputs + param_flow_run_id = sa.bindparam("flow_run_id", type_=UUIDTypeDecorator) + param_since = sa.bindparam("since", type_=Timestamp) + param_max_nodes = sa.bindparam("max_nodes", type_=sa.Integer) + + Flow, FlowRun, TaskRun = db.Flow, db.FlowRun, db.TaskRun + input = sa.func.json_each(TaskRun.task_inputs).table_valued( + "key", "value", name="input" + ) + argument = sa.func.json_each( + input.c.value, type_=postgresql.JSON() + ).table_valued("key", sa.column("value", postgresql.JSON()), name="argument") + edges = ( sa.select( - sa.func.coalesce( - orm_models.FlowRun.start_time, - orm_models.FlowRun.expected_start_time, + sa.case((FlowRun.id.is_not(None), "flow-run"), else_="task-run").label( + "kind" + ), + sa.func.coalesce(FlowRun.id, TaskRun.id).label("id"), + sa.func.coalesce(Flow.name + " / " + FlowRun.name, TaskRun.name).label( + "label" ), - orm_models.FlowRun.end_time, - ).where( - orm_models.FlowRun.id == flow_run_id, + sa.func.coalesce(FlowRun.state_type, TaskRun.state_type).label( + "state_type" + ), + sa.func.coalesce( + FlowRun.start_time, + FlowRun.expected_start_time, + TaskRun.start_time, + TaskRun.expected_start_time, + ).label("start_time"), + sa.func.coalesce( + FlowRun.end_time, + TaskRun.end_time, + sa.case( + ( + TaskRun.state_type == StateType.COMPLETED, + TaskRun.expected_start_time, + ), + else_=sa.null(), + ), + ).label("end_time"), + argument.c.value["id"].astext.label("parent"), + (input.c.key == "__parents__").label("has_encapsulating_task"), ) - ) - try: - start_time, end_time = result.one() - except NoResultFound: - raise ObjectNotFoundError(f"Flow run {flow_run_id} not found") - - query = sa.text( - """ - WITH - edges AS ( - SELECT CASE - WHEN subflow.id IS NOT NULL THEN 'flow-run' - ELSE 'task-run' - END as kind, - COALESCE(subflow.id, task_run.id) as id, - COALESCE(flow.name || ' / ' || subflow.name, task_run.name) as label, - COALESCE(subflow.state_type, task_run.state_type) as state_type, - COALESCE( - subflow.start_time, - subflow.expected_start_time, - task_run.start_time, - task_run.expected_start_time - ) as start_time, - COALESCE( - subflow.end_time, - task_run.end_time, - CASE - WHEN task_run.state_type = 'COMPLETED' - THEN task_run.expected_start_time - ELSE NULL - END - ) as end_time, - json_extract(argument.value, '$.id') as parent, - input.key = '__parents__' as has_encapsulating_task - FROM task_run - LEFT JOIN json_each(task_run.task_inputs) as input ON true - LEFT JOIN json_each(input.value) as argument ON true - LEFT JOIN flow_run as subflow - ON subflow.parent_task_run_id = task_run.id - LEFT JOIN flow - ON flow.id = subflow.flow_id - WHERE task_run.flow_run_id = :flow_run_id AND - task_run.state_type <> 'PENDING' AND - COALESCE( - subflow.start_time, - subflow.expected_start_time, - task_run.start_time, - task_run.expected_start_time - ) IS NOT NULL - - -- the order here is important to speed up building the two sets of - -- edges in the with_parents and with_children CTEs below - ORDER BY COALESCE(subflow.id, task_run.id) - ), - with_encapsulating AS ( - SELECT children.id, - group_concat(parents.id) as encapsulating_ids - FROM edges as children - INNER JOIN edges as parents - ON parents.id = children.parent - WHERE children.has_encapsulating_task IS TRUE - GROUP BY children.id - ), - with_parents AS ( - SELECT children.id, - group_concat(parents.id) as parent_ids - FROM edges as children - INNER JOIN edges as parents - ON parents.id = children.parent - WHERE children.has_encapsulating_task is FALSE OR children.has_encapsulating_task IS NULL - GROUP BY children.id - ), - with_children AS ( - SELECT parents.id, - group_concat(children.id) as child_ids - FROM edges as parents - INNER JOIN edges as children - ON children.parent = parents.id - WHERE children.has_encapsulating_task IS FALSE OR children.has_encapsulating_task IS NULL - GROUP BY parents.id - ), - nodes AS ( - SELECT DISTINCT - edges.id, - edges.kind, - edges.id, - edges.label, - edges.state_type, - edges.start_time, - edges.end_time, - with_parents.parent_ids, - with_children.child_ids, - with_encapsulating.encapsulating_ids - FROM edges - LEFT JOIN with_parents - ON with_parents.id = edges.id - LEFT JOIN with_children - ON with_children.id = edges.id - LEFT JOIN with_encapsulating - ON with_encapsulating.id = edges.id + .join_from(TaskRun, input, onclause=sa.true(), isouter=True) + .join(argument, onclause=sa.true(), isouter=True) + .join( + FlowRun, + isouter=True, + onclause=FlowRun.parent_task_run_id == TaskRun.id, ) - SELECT kind, - id, - label, - state_type, - start_time, - end_time, - parent_ids, - child_ids, - encapsulating_ids - FROM nodes - WHERE end_time IS NULL OR end_time >= :since - ORDER BY start_time, end_time - LIMIT :max_nodes - ; - """ - ) - - # SQLite needs this to be a Python datetime object - since = datetime.datetime( - since.year, - since.month, - since.day, - since.hour, - since.minute, - since.second, - since.microsecond, - tzinfo=since.tzinfo, + .join(Flow, isouter=True, onclause=Flow.id == FlowRun.flow_id) + .where( + TaskRun.flow_run_id == param_flow_run_id, + TaskRun.state_type != StateType.PENDING, + sa.func.coalesce( + FlowRun.start_time, + FlowRun.expected_start_time, + TaskRun.start_time, + TaskRun.expected_start_time, + ).is_not(None), + ) + # -- the order here is important to speed up building the two sets of + # -- edges in the with_parents and with_children CTEs below + .order_by(sa.func.coalesce(FlowRun.id, TaskRun.id)) + ).cte("edges") + children, parents = edges.alias("children"), edges.alias("parents") + with_encapsulating = ( + sa.select( + children.c.id, + sa.func.json_group_array(parents.c.id).label("encapsulating_ids"), + ) + .join(parents, onclause=parents.c.id == children.c.parent) + .where(children.c.has_encapsulating_task.is_(True)) + .group_by(children.c.id) + ).cte("with_encapsulating") + with_parents = ( + sa.select( + children.c.id, + sa.func.json_group_array(parents.c.id).label("parent_ids"), + ) + .join(parents, onclause=parents.c.id == children.c.parent) + .where(children.c.has_encapsulating_task.is_distinct_from(True)) + .group_by(children.c.id) + .cte("with_parents") ) - - query = query.bindparams( - sa.bindparam("flow_run_id", value=str(flow_run_id)), - sa.bindparam("since", value=since), - sa.bindparam("max_nodes", value=max_nodes + 1), + with_children = ( + sa.select( + parents.c.id, sa.func.json_group_array(children.c.id).label("child_ids") + ) + .join(children, onclause=children.c.parent == parents.c.id) + .where(children.c.has_encapsulating_task.is_distinct_from(True)) + .group_by(parents.c.id) + .cte("with_children") ) - results = await session.execute(query) - - graph_artifacts = await self._get_flow_run_graph_artifacts( - session, flow_run_id, max_artifacts + graph = ( + sa.select( + edges.c.kind, + edges.c.id, + edges.c.label, + edges.c.state_type, + edges.c.start_time, + edges.c.end_time, + with_parents.c.parent_ids, + with_children.c.child_ids, + with_encapsulating.c.encapsulating_ids, + ) + .distinct() + .join(with_parents, isouter=True, onclause=with_parents.c.id == edges.c.id) + .join( + with_children, isouter=True, onclause=with_children.c.id == edges.c.id + ) + .join( + with_encapsulating, + isouter=True, + onclause=with_encapsulating.c.id == edges.c.id, + ) + .cte("nodes") ) - graph_states = await self._get_flow_run_graph_states(session, flow_run_id) - nodes: List[Tuple[UUID, Node]] = [] - root_node_ids: List[UUID] = [] - - for row in results: - if not row.parent_ids: - root_node_ids.append(row.id) - - # With SQLite, some of the values are returned as strings rather than - # native Python objects, as they would be from PostgreSQL. These functions - # help smooth over those differences. - - def edges( - value: Union[str, Sequence[UUID], Sequence[str], None], - ) -> List[UUID]: - if not value: - return [] - if isinstance(value, str): - return [Edge(id=id) for id in value.split(",")] - return [Edge(id=id) for id in value] - - def time( - value: Union[str, datetime.datetime, None], - ) -> Optional[pendulum.DateTime]: - if not value: - return None - if isinstance(value, str): - return cast(pendulum.DateTime, pendulum.parse(value)) - return pendulum.instance(value) - - nodes.append( - ( - row.id, - Node( - kind=row.kind, - id=row.id, - label=row.label, - state_type=row.state_type, - start_time=time(row.start_time), - end_time=time(row.end_time), - parents=edges(row.parent_ids), - children=edges(row.child_ids), - # ensure encapsulating_ids is deduplicated - # so parents only show up once - encapsulating=edges( - list(set(row.encapsulating_ids.split(","))) - if row.encapsulating_ids - else None - ), - artifacts=graph_artifacts.get(UUID(row.id), []), - ), - ) + query = ( + sa.select( + graph.c.kind, + graph.c.id, + graph.c.label, + graph.c.state_type, + graph.c.start_time, + graph.c.end_time, + sa.type_coerce(graph.c.parent_ids, UUIDList), + sa.type_coerce(graph.c.child_ids, UUIDList), + sa.type_coerce(graph.c.encapsulating_ids, UUIDList), ) - - if len(nodes) > max_nodes: - raise FlowRunGraphTooLarge( - f"The graph of flow run {flow_run_id} has more than " - f"{max_nodes} nodes." - ) - - return Graph( - start_time=start_time, - end_time=end_time, - root_node_ids=root_node_ids, - nodes=nodes, - artifacts=graph_artifacts.get(None, []), - states=graph_states, + .where(sa.or_(graph.c.end_time.is_(None), graph.c.end_time >= param_since)) + .order_by(graph.c.start_time, graph.c.end_time) + .limit(param_max_nodes) ) + return cast(sa.Select[FlowRunGraphV2Node], query) diff --git a/src/prefect/server/events/counting.py b/src/prefect/server/events/counting.py index ec14ad7f70a0b..cbd0ce8ca2f93 100644 --- a/src/prefect/server/events/counting.py +++ b/src/prefect/server/events/counting.py @@ -6,8 +6,7 @@ import sqlalchemy as sa from sqlalchemy.sql.selectable import Select -from prefect.server.database.dependencies import provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, provide_database_interface from prefect.types import DateTime from prefect.utilities.collections import AutoEnum diff --git a/src/prefect/server/events/filters.py b/src/prefect/server/events/filters.py index 8b87829f9614e..cb6cdf3711787 100644 --- a/src/prefect/server/events/filters.py +++ b/src/prefect/server/events/filters.py @@ -1,7 +1,8 @@ import sys +from collections.abc import Iterable from dataclasses import dataclass, field from datetime import timedelta -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple from uuid import UUID import pendulum @@ -9,13 +10,12 @@ from pydantic import Field, PrivateAttr from sqlalchemy.sql import Select -from prefect._internal.schemas.bases import PrefectBaseModel -from prefect.server.database import orm_models +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.schemas.filters import ( PrefectFilterBaseModel, PrefectOperatorFilterBaseModel, ) -from prefect.types import DateTime +from prefect.server.utilities.schemas.bases import PrefectBaseModel from prefect.utilities.collections import AutoEnum from .schemas.events import Event, Resource, ResourceSpecification @@ -23,6 +23,10 @@ if TYPE_CHECKING: from sqlalchemy.sql.expression import ColumnElement, ColumnExpressionArgument + DateTime = pendulum.DateTime +else: + from prefect.types import DateTime + class AutomationFilterCreated(PrefectFilterBaseModel): """Filter by `Automation.created`.""" @@ -32,11 +36,12 @@ class AutomationFilterCreated(PrefectFilterBaseModel): description="Only include automations created before this datetime", ) - def _get_filter_list(self) -> list: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnElement[bool]]: if self.before_ is not None: - filters.append(orm_models.Automation.created <= self.before_) - return filters + return [db.Automation.created <= self.before_] + return () class AutomationFilterName(PrefectFilterBaseModel): @@ -47,11 +52,10 @@ class AutomationFilterName(PrefectFilterBaseModel): description="Only include automations with names that match any of these strings", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list(self, db: PrefectDBInterface) -> list[sa.ColumnElement[bool]]: if self.any_ is not None: - filters.append(orm_models.Automation.name.in_(self.any_)) - return filters + return [db.Automation.name.in_(self.any_)] + return [] class AutomationFilter(PrefectOperatorFilterBaseModel): @@ -62,8 +66,10 @@ class AutomationFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `Automation.created`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.name is not None: filters.append(self.name.as_sql_filter()) @@ -76,14 +82,12 @@ def _get_filter_list(self) -> List: class EventDataFilter(PrefectBaseModel, extra="forbid"): """A base class for filtering event data.""" - _top_level_filter: Optional["EventFilter"] = PrivateAttr(None) + _top_level_filter: Optional[sa.Select[tuple[UUID]]] = PrivateAttr(None) def get_filters(self) -> List["EventDataFilter"]: filters: List[EventDataFilter] = [ filter - for filter in [ - getattr(self, name) for name, field in self.model_fields.items() - ] + for filter in [getattr(self, name) for name in self.model_fields] if isinstance(filter, EventDataFilter) ] for filter in filters: @@ -108,32 +112,30 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: class EventOccurredFilter(EventDataFilter): since: DateTime = Field( - default_factory=lambda: cast( - DateTime, - pendulum.now("UTC").start_of("day").subtract(days=180), - ), + default_factory=lambda: pendulum.now("UTC").start_of("day").subtract(days=180), description="Only include events after this time (inclusive)", ) until: DateTime = Field( - default_factory=lambda: cast(DateTime, pendulum.now("UTC")), + default_factory=lambda: pendulum.now("UTC"), description="Only include events prior to this time (inclusive)", ) def clamp(self, max_duration: timedelta): """Limit how far the query can look back based on the given duration""" earliest = pendulum.now("UTC") - max_duration - self.since = max(earliest, cast(pendulum.DateTime, self.since)) + self.since = max(earliest, self.since) def includes(self, event: Event) -> bool: return self.since <= event.occurred <= self.until - def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: - filters: List["ColumnExpressionArgument[bool]"] = [] - - filters.append(orm_models.Event.occurred >= self.since) - filters.append(orm_models.Event.occurred <= self.until) - - return filters + @db_injector + def build_where_clauses( + self, db: PrefectDBInterface + ) -> Sequence["ColumnExpressionArgument[bool]"]: + return [ + db.Event.occurred >= self.since, + db.Event.occurred <= self.until, + ] class EventNameFilter(EventDataFilter): @@ -170,34 +172,28 @@ def includes(self, event: Event) -> bool: return True - def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: + @db_injector + def build_where_clauses( + self, db: PrefectDBInterface + ) -> Sequence["ColumnExpressionArgument[bool]"]: filters: List["ColumnExpressionArgument[bool]"] = [] if self.prefix: filters.append( - sa.or_( - *[ - orm_models.Event.event.startswith(prefix) - for prefix in self.prefix - ] - ) + sa.or_(*(db.Event.event.startswith(prefix) for prefix in self.prefix)) ) if self.exclude_prefix: - filters.append( - sa.and_( - *[ - sa.not_(orm_models.Event.event.startswith(prefix)) - for prefix in self.exclude_prefix - ] - ) + filters.extend( + sa.not_(db.Event.event.startswith(prefix)) + for prefix in self.exclude_prefix ) if self.name: - filters.append(orm_models.Event.event.in_(self.name)) + filters.append(db.Event.event.in_(self.name)) if self.exclude_name: - filters.append(orm_models.Event.event.not_in(self.exclude_name)) + filters.append(db.Event.event.not_in(self.exclude_name)) return filters @@ -262,22 +258,25 @@ def includes(self, event: Event) -> bool: return True - def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: + @db_injector + def build_where_clauses( + self, db: PrefectDBInterface + ) -> Sequence["ColumnExpressionArgument[bool]"]: filters: List["ColumnExpressionArgument[bool]"] = [] # If we're doing an exact or prefix search on resource_id, this is efficient # enough to do on the events table without going to the event_resources table if self.id: - filters.append(orm_models.Event.resource_id.in_(self.id)) + filters.append(db.Event.resource_id.in_(self.id)) if self.id_prefix: filters.append( sa.or_( - *[ - orm_models.Event.resource_id.startswith(prefix) + *( + db.Event.resource_id.startswith(prefix) for prefix in self.id_prefix - ] + ) ) ) @@ -286,14 +285,14 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: # We are explicitly searching for the primary resource here so the # resource_role must be '' - label_filters = [orm_models.EventResource.resource_role == ""] + label_filters = [db.EventResource.resource_role == ""] # On the event_resources table, resource_id is unpacked # into a column, so we should search for it there if resource_ids := labels.pop("prefect.resource.id", None): label_ops = LabelOperations(resource_ids) - resource_id_column = orm_models.EventResource.resource_id + resource_id_column = db.EventResource.resource_id if values := label_ops.positive.simple: label_filters.append(resource_id_column.in_(values)) @@ -308,7 +307,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: for _, (label, values) in enumerate(labels.items()): label_ops = LabelOperations(values) - label_column = orm_models.EventResource.resource[label].astext + label_column = db.EventResource.resource[label].astext # With negative labels, the resource _must_ have the label if label_ops.negative.simple or label_ops.negative.prefixes: @@ -323,13 +322,9 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: for prefix in label_ops.negative.prefixes: label_filters.append(sa.not_(label_column.startswith(prefix))) - assert self._top_level_filter + assert self._top_level_filter is not None filters.append( - orm_models.Event.id.in_( - self._top_level_filter._scoped_event_resources().where( - *label_filters - ) - ) + db.Event.id.in_(self._top_level_filter.where(*label_filters)) ) return filters @@ -352,25 +347,28 @@ class EventRelatedFilter(EventDataFilter): None, description="Only include events for related resources with these labels" ) - def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: + @db_injector + def build_where_clauses( + self, db: PrefectDBInterface + ) -> Sequence["ColumnExpressionArgument[bool]"]: filters: List["ColumnExpressionArgument[bool]"] = [] if self.id: - filters.append(orm_models.EventResource.resource_id.in_(self.id)) + filters.append(db.EventResource.resource_id.in_(self.id)) if self.role: - filters.append(orm_models.EventResource.resource_role.in_(self.role)) + filters.append(db.EventResource.resource_role.in_(self.role)) if self.resources_in_roles: filters.append( sa.or_( - *[ + *( sa.and_( - orm_models.EventResource.resource_id == resource_id, - orm_models.EventResource.resource_role == role, + db.EventResource.resource_id == resource_id, + db.EventResource.resource_role == role, ) for resource_id, role in self.resources_in_roles - ] + ) ) ) @@ -383,7 +381,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: if resource_ids := labels.pop("prefect.resource.id", None): label_ops = LabelOperations(resource_ids) - resource_id_column = orm_models.EventResource.resource_id + resource_id_column = db.EventResource.resource_id if values := label_ops.positive.simple: label_filters.append(resource_id_column.in_(values)) @@ -395,13 +393,13 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: label_filters.append(sa.not_(resource_id_column.startswith(prefix))) if roles := labels.pop("prefect.resource.role", None): - label_filters.append(orm_models.EventResource.resource_role.in_(roles)) + label_filters.append(db.EventResource.resource_role.in_(roles)) if labels: for _, (label, values) in enumerate(labels.items()): label_ops = LabelOperations(values) - label_column = orm_models.EventResource.resource[label].astext + label_column = db.EventResource.resource[label].astext if label_ops.negative.simple or label_ops.negative.prefixes: label_filters.append(label_column.is_not(None)) @@ -423,14 +421,10 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: # also filter out primary resources (those with an empty role) for any of # these queries if not self.role: - filters.append(orm_models.EventResource.resource_role != "") + filters.append(db.EventResource.resource_role != "") - assert self._top_level_filter - filters = [ - orm_models.Event.id.in_( - self._top_level_filter._scoped_event_resources().where(*filters) - ) - ] + assert self._top_level_filter is not None + filters = [db.Event.id.in_(self._top_level_filter.where(*filters))] return filters @@ -470,17 +464,20 @@ def _includes(self, resource: Resource) -> bool: return True - def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: + @db_injector + def build_where_clauses( + self, db: PrefectDBInterface + ) -> Sequence["ColumnExpressionArgument[bool]"]: filters: List["ColumnExpressionArgument[bool]"] = [] if self.id: - filters.append(orm_models.EventResource.resource_id.in_(self.id)) + filters.append(db.EventResource.resource_id.in_(self.id)) if self.id_prefix: filters.append( sa.or_( *[ - orm_models.EventResource.resource_id.startswith(prefix) + db.EventResource.resource_id.startswith(prefix) for prefix in self.id_prefix ] ) @@ -495,7 +492,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: if resource_ids := labels.pop("prefect.resource.id", None): label_ops = LabelOperations(resource_ids) - resource_id_column = orm_models.EventResource.resource_id + resource_id_column = db.EventResource.resource_id if values := label_ops.positive.simple: label_filters.append(resource_id_column.in_(values)) @@ -507,13 +504,13 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: label_filters.append(sa.not_(resource_id_column.startswith(prefix))) if roles := labels.pop("prefect.resource.role", None): - label_filters.append(orm_models.EventResource.resource_role.in_(roles)) + label_filters.append(db.EventResource.resource_role.in_(roles)) if labels: for _, (label, values) in enumerate(labels.items()): label_ops = LabelOperations(values) - label_column = orm_models.EventResource.resource[label].astext + label_column = db.EventResource.resource[label].astext if label_ops.negative.simple or label_ops.negative.prefixes: label_filters.append(label_column.is_not(None)) @@ -530,19 +527,15 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: filters.append(sa.and_(*label_filters)) if filters: - assert self._top_level_filter - filters = [ - orm_models.Event.id.in_( - self._top_level_filter._scoped_event_resources().where(*filters) - ) - ] + assert self._top_level_filter is not None + filters = [db.Event.id.in_(self._top_level_filter.where(*filters))] return filters class EventIDFilter(EventDataFilter): id: Optional[List[UUID]] = Field( - None, description="Only include events with one of these IDs" + default=None, description="Only include events with one of these IDs" ) def includes(self, event: Event) -> bool: @@ -552,11 +545,14 @@ def includes(self, event: Event) -> bool: return True - def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: + @db_injector + def build_where_clauses( + self, db: PrefectDBInterface + ) -> Sequence["ColumnExpressionArgument[bool]"]: filters: List["ColumnExpressionArgument[bool]"] = [] if self.id: - filters.append(orm_models.Event.id.in_(self.id)) + filters.append(db.Event.id.in_(self.id)) return filters @@ -594,15 +590,20 @@ class EventFilter(EventDataFilter): description="The order to return filtered events", ) - def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: - self._top_level_filter = self - return super().build_where_clauses() + @db_injector + def build_where_clauses( + self, db: PrefectDBInterface + ) -> Sequence["ColumnExpressionArgument[bool]"]: + self._top_level_filter = self._scoped_event_resources(db) + result = super().build_where_clauses() + self._top_level_filter = None + return result - def _scoped_event_resources(self) -> Select: + def _scoped_event_resources(self, db: PrefectDBInterface) -> Select[tuple[UUID]]: """Returns an event_resources query that is scoped to this filter's scope by occurred.""" - query = sa.select(orm_models.EventResource.event_id).where( - orm_models.EventResource.occurred >= self.occurred.since, - orm_models.EventResource.occurred <= self.occurred.until, + query = sa.select(db.EventResource.event_id).where( + db.EventResource.occurred >= self.occurred.since, + db.EventResource.occurred <= self.occurred.until, ) return query diff --git a/src/prefect/server/events/models/automations.py b/src/prefect/server/events/models/automations.py index 25e306425bce5..af597d4bc5d68 100644 --- a/src/prefect/server/events/models/automations.py +++ b/src/prefect/server/events/models/automations.py @@ -1,13 +1,12 @@ from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, AsyncGenerator, Optional, Sequence, Union +from typing import AsyncGenerator, Optional, Sequence, Union from uuid import UUID import pendulum import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.events import filters from prefect.server.events.schemas.automations import ( Automation, @@ -18,9 +17,6 @@ from prefect.settings import PREFECT_API_SERVICES_TRIGGERS_ENABLED from prefect.utilities.asyncutils import run_coro_as_sync -if TYPE_CHECKING: - from prefect.server.database import orm_models - @asynccontextmanager @db_injector @@ -64,7 +60,7 @@ async def count_automations_for_workspace( db: PrefectDBInterface, session: AsyncSession, ) -> int: - query = sa.select(sa.func.count(sa.text("*"))).select_from(db.Automation) + query = sa.select(sa.func.count(None)).select_from(db.Automation) result = await session.execute(query) @@ -77,10 +73,9 @@ async def read_automation( session: AsyncSession, automation_id: UUID, ) -> Optional[Automation]: - result = await session.execute( + automation = await session.scalar( sa.select(db.Automation).where(db.Automation.id == automation_id) ) - automation: Optional[orm_models.Automation] = result.scalars().first() if not automation: return None return Automation.model_validate(automation, from_attributes=True) @@ -90,12 +85,11 @@ async def read_automation( async def read_automation_by_id( db: PrefectDBInterface, session: AsyncSession, automation_id: UUID ) -> Optional[Automation]: - result = await session.execute( + automation = await session.scalar( sa.select(db.Automation).where( db.Automation.id == automation_id, ) ) - automation: Optional[orm_models.Automation] = result.scalars().first() if not automation: return None return Automation.model_validate(automation, from_attributes=True) @@ -287,7 +281,7 @@ async def relate_automation_to_resource( owned_by_resource: bool, ) -> None: await session.execute( - db.insert(db.AutomationRelatedResource) + db.queries.insert(db.AutomationRelatedResource) .values( automation_id=automation_id, resource_id=resource_id, diff --git a/src/prefect/server/events/models/composite_trigger_child_firing.py b/src/prefect/server/events/models/composite_trigger_child_firing.py index bb7d0dc5c8ee3..748f3c2eb6e92 100644 --- a/src/prefect/server/events/models/composite_trigger_child_firing.py +++ b/src/prefect/server/events/models/composite_trigger_child_firing.py @@ -6,8 +6,7 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.events.schemas.automations import CompositeTrigger, Firing if TYPE_CHECKING: diff --git a/src/prefect/server/events/ordering.py b/src/prefect/server/events/ordering.py index 86c369e3ce008..c999ddf6c04f7 100644 --- a/src/prefect/server/events/ordering.py +++ b/src/prefect/server/events/ordering.py @@ -21,9 +21,7 @@ from cachetools import TTLCache from prefect.logging import get_logger -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface -from prefect.server.database.orm_models import AutomationEventFollower +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.events.schemas.events import Event, ReceivedEvent logger = get_logger(__name__) @@ -82,7 +80,7 @@ async def record_follower( async with db.session_context(begin_transaction=True) as session: await session.execute( - sa.insert(AutomationEventFollower).values( + sa.insert(db.AutomationEventFollower).values( scope=self.scope, leader_event_id=event.follows, follower_event_id=event.id, @@ -100,9 +98,9 @@ async def forget_follower( async with db.session_context(begin_transaction=True) as session: await session.execute( - sa.delete(AutomationEventFollower).where( - AutomationEventFollower.scope == self.scope, - AutomationEventFollower.follower_event_id == follower.id, + sa.delete(db.AutomationEventFollower).where( + db.AutomationEventFollower.scope == self.scope, + db.AutomationEventFollower.follower_event_id == follower.id, ) ) @@ -112,9 +110,9 @@ async def get_followers( ) -> List[ReceivedEvent]: """Returns events that were waiting on this leader event to arrive""" async with db.session_context() as session: - query = sa.select(AutomationEventFollower.follower).where( - AutomationEventFollower.scope == self.scope, - AutomationEventFollower.leader_event_id == leader.id, + query = sa.select(db.AutomationEventFollower.follower).where( + db.AutomationEventFollower.scope == self.scope, + db.AutomationEventFollower.leader_event_id == leader.id, ) result = await session.execute(query) followers = result.scalars().all() @@ -126,9 +124,9 @@ async def get_lost_followers(self, db: PrefectDBInterface) -> List[ReceivedEvent earlier = pendulum.now("UTC") - PRECEDING_EVENT_LOOKBACK async with db.session_context(begin_transaction=True) as session: - query = sa.select(AutomationEventFollower.follower).where( - AutomationEventFollower.scope == self.scope, - AutomationEventFollower.received < earlier, + query = sa.select(db.AutomationEventFollower.follower).where( + db.AutomationEventFollower.scope == self.scope, + db.AutomationEventFollower.received < earlier, ) result = await session.execute(query) followers = result.scalars().all() @@ -136,9 +134,9 @@ async def get_lost_followers(self, db: PrefectDBInterface) -> List[ReceivedEvent # forget these followers, since they are never going to see their leader event await session.execute( - sa.delete(AutomationEventFollower).where( - AutomationEventFollower.scope == self.scope, - AutomationEventFollower.received < earlier, + sa.delete(db.AutomationEventFollower).where( + db.AutomationEventFollower.scope == self.scope, + db.AutomationEventFollower.received < earlier, ) ) diff --git a/src/prefect/server/events/services/event_persister.py b/src/prefect/server/events/services/event_persister.py index b3f0fc1881756..2d810f37c2186 100644 --- a/src/prefect/server/events/services/event_persister.py +++ b/src/prefect/server/events/services/event_persister.py @@ -12,7 +12,7 @@ import sqlalchemy as sa from prefect.logging import get_logger -from prefect.server.database.dependencies import provide_database_interface +from prefect.server.database import provide_database_interface from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.events.storage.database import write_events from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer diff --git a/src/prefect/server/events/storage/database.py b/src/prefect/server/events/storage/database.py index ade565ff14f8e..a0b532e3a6a8d 100644 --- a/src/prefect/server/events/storage/database.py +++ b/src/prefect/server/events/storage/database.py @@ -6,8 +6,11 @@ from sqlalchemy.orm import aliased from prefect.logging.loggers import get_logger -from prefect.server.database.dependencies import db_injector, provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import ( + PrefectDBInterface, + db_injector, + provide_database_interface, +) from prefect.server.events.counting import Countable, TimeUnit from prefect.server.events.filters import EventFilter, EventOrder from prefect.server.events.schemas.events import EventCount, ReceivedEvent @@ -232,7 +235,7 @@ async def _write_sqlite_events( event for event in batch if event.id not in existing_event_ids ] event_rows = [event.as_database_row() for event in events_to_insert] - await session.execute(db.insert(db.Event).values(event_rows)) + await session.execute(db.queries.insert(db.Event).values(event_rows)) resource_rows: List[Dict[str, Any]] = [] for event in events_to_insert: @@ -241,7 +244,7 @@ async def _write_sqlite_events( if not resource_rows: continue - await session.execute(db.insert(db.EventResource).values(resource_rows)) + await session.execute(db.queries.insert(db.EventResource).values(resource_rows)) @db_injector @@ -258,7 +261,7 @@ async def _write_postgres_events( for batch in _in_safe_batches(events): event_rows = [event.as_database_row() for event in batch] result = await session.scalars( - db.insert(db.Event) + db.queries.insert(db.Event) .on_conflict_do_nothing() .returning(db.Event.id) .values(event_rows) @@ -277,7 +280,7 @@ async def _write_postgres_events( if not resource_rows: continue - await session.execute(db.insert(db.EventResource).values(resource_rows)) + await session.execute(db.queries.insert(db.EventResource).values(resource_rows)) def get_max_query_parameters() -> int: diff --git a/src/prefect/server/events/triggers.py b/src/prefect/server/events/triggers.py index 9a839bc5e66a2..ae25b8bd95dc2 100644 --- a/src/prefect/server/events/triggers.py +++ b/src/prefect/server/events/triggers.py @@ -24,8 +24,7 @@ from prefect._internal.retries import retry_async_fn from prefect.logging import get_logger -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.events import messaging from prefect.server.events.actions import ServerActionTypes from prefect.server.events.models.automations import ( @@ -803,7 +802,7 @@ async def increment_bucket( """Adds the given count to the bucket, returning the new bucket""" additional_updates: dict = {"last_event": last_event} if last_event else {} await session.execute( - db.insert(db.AutomationBucket) + db.queries.insert(db.AutomationBucket) .values( automation_id=bucket.automation_id, trigger_id=bucket.trigger_id, @@ -852,7 +851,7 @@ async def start_new_bucket( automation = trigger.automation await session.execute( - db.insert(db.AutomationBucket) + db.queries.insert(db.AutomationBucket) .values( automation_id=automation.id, trigger_id=trigger.id, @@ -904,7 +903,7 @@ async def ensure_bucket( automation = trigger.automation additional_updates: dict = {"last_event": last_event} if last_event else {} await session.execute( - db.insert(db.AutomationBucket) + db.queries.insert(db.AutomationBucket) .values( automation_id=automation.id, trigger_id=trigger.id, diff --git a/src/prefect/server/models/agents.py b/src/prefect/server/models/agents.py index f9ac3ed3e4438..bd12b1a55d063 100644 --- a/src/prefect/server/models/agents.py +++ b/src/prefect/server/models/agents.py @@ -12,12 +12,12 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.schemas as schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models +@db_injector async def create_agent( + db: PrefectDBInterface, session: AsyncSession, agent: schemas.core.Agent, ) -> orm_models.Agent: @@ -35,14 +35,16 @@ async def create_agent( """ - model = orm_models.Agent(**agent.model_dump()) + model = db.Agent(**agent.model_dump()) session.add(model) await session.flush() return model +@db_injector async def read_agent( + db: PrefectDBInterface, session: AsyncSession, agent_id: UUID, ) -> Union[orm_models.Agent, None]: @@ -57,10 +59,12 @@ async def read_agent( orm_models.Agent: the Agent """ - return await session.get(orm_models.Agent, agent_id) + return await session.get(db.Agent, agent_id) +@db_injector async def read_agents( + db: PrefectDBInterface, session: AsyncSession, offset: Union[int, None] = None, limit: Union[int, None] = None, @@ -77,7 +81,7 @@ async def read_agents( List[orm_models.Agent]: Agents """ - query = select(orm_models.Agent).order_by(orm_models.Agent.name) + query = select(db.Agent).order_by(db.Agent.name) if offset is not None: query = query.offset(offset) @@ -88,7 +92,9 @@ async def read_agents( return result.scalars().unique().all() +@db_injector async def update_agent( + db: PrefectDBInterface, session: AsyncSession, agent_id: UUID, agent: schemas.core.Agent, @@ -106,8 +112,8 @@ async def update_agent( """ update_stmt = ( - sa.update(orm_models.Agent) - .where(orm_models.Agent.id == agent_id) + sa.update(db.Agent) + .where(db.Agent.id == agent_id) # exclude_unset=True allows us to only update values provided by # the user, ignoring any defaults on the model .values(**agent.model_dump_for_orm(exclude_unset=True)) @@ -142,7 +148,7 @@ async def record_agent_poll( id=agent_id, work_queue_id=work_queue_id, last_activity_time=pendulum.now("UTC") ) insert_stmt = ( - db.insert(orm_models.Agent) + db.queries.insert(db.Agent) .values( **agent_data.model_dump( include={"id", "name", "work_queue_id", "last_activity_time"} @@ -158,7 +164,9 @@ async def record_agent_poll( await session.execute(insert_stmt) +@db_injector async def delete_agent( + db: PrefectDBInterface, session: AsyncSession, agent_id: UUID, ) -> bool: @@ -173,7 +181,5 @@ async def delete_agent( bool: whether or not the Agent was deleted """ - result = await session.execute( - delete(orm_models.Agent).where(orm_models.Agent.id == agent_id) - ) + result = await session.execute(delete(db.Agent).where(db.Agent.id == agent_id)) return result.rowcount > 0 diff --git a/src/prefect/server/models/artifacts.py b/src/prefect/server/models/artifacts.py index 772d5abe5bb9c..b009683e06fe3 100644 --- a/src/prefect/server/models/artifacts.py +++ b/src/prefect/server/models/artifacts.py @@ -7,9 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import Select -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.schemas import actions, filters, sorting from prefect.server.schemas.core import Artifact @@ -30,10 +28,10 @@ async def _insert_into_artifact_collection( exclude_unset=True, exclude={"id", "updated", "created"} ) upsert_new_latest_id = ( - db.insert(orm_models.ArtifactCollection) + db.queries.insert(db.ArtifactCollection) .values(latest_id=artifact.id, updated=now, created=now, **insert_values) .on_conflict_do_update( - index_elements=db.artifact_collection_unique_upsert_columns, + index_elements=db.orm.artifact_collection_unique_upsert_columns, set_=dict( latest_id=artifact.id, updated=now, @@ -45,12 +43,8 @@ async def _insert_into_artifact_collection( await session.execute(upsert_new_latest_id) query = ( - sa.select(orm_models.ArtifactCollection) - .where( - sa.and_( - orm_models.ArtifactCollection.key == artifact.key, - ) - ) + sa.select(db.ArtifactCollection) + .where(sa.and_(db.ArtifactCollection.key == artifact.key)) .execution_options(populate_existing=True) ) @@ -84,7 +78,7 @@ async def _insert_into_artifact( Inserts a new artifact into the artifact table. """ artifact_id = artifact.id - insert_stmt = db.insert(orm_models.Artifact).values( + insert_stmt = db.queries.insert(db.Artifact).values( created=now, updated=now, **artifact.model_dump_for_orm(exclude={"created", "updated"}), @@ -92,8 +86,8 @@ async def _insert_into_artifact( await session.execute(insert_stmt) query = ( - sa.select(orm_models.Artifact) - .where(orm_models.Artifact.id == artifact_id) + sa.select(db.Artifact) + .where(db.Artifact.id == artifact_id) .limit(1) .execution_options(populate_existing=True) ) @@ -122,7 +116,9 @@ async def create_artifact( return result +@db_injector async def read_latest_artifact( + db: PrefectDBInterface, session: AsyncSession, key: str, ) -> Union[orm_models.ArtifactCollection, None]: @@ -134,14 +130,16 @@ async def read_latest_artifact( Returns: Artifact: The latest artifact """ - latest_artifact_query = sa.select(orm_models.ArtifactCollection).where( - orm_models.ArtifactCollection.key == key + latest_artifact_query = sa.select(db.ArtifactCollection).where( + db.ArtifactCollection.key == key ) result = await session.execute(latest_artifact_query) return result.scalar() +@db_injector async def read_artifact( + db: PrefectDBInterface, session: AsyncSession, artifact_id: UUID, ) -> Union[orm_models.Artifact, None]: @@ -149,13 +147,14 @@ async def read_artifact( Reads an artifact by id. """ - query = sa.select(orm_models.Artifact).where(orm_models.Artifact.id == artifact_id) + query = sa.select(db.Artifact).where(db.Artifact.id == artifact_id) result = await session.execute(query) return result.scalar() async def _apply_artifact_filters( + db: PrefectDBInterface, query: Select[T], flow_run_filter: Optional[filters.FlowRunFilter] = None, task_run_filter: Optional[filters.TaskRunFilter] = None, @@ -168,8 +167,8 @@ async def _apply_artifact_filters( query = query.where(artifact_filter.as_sql_filter()) if flow_filter or flow_run_filter or deployment_filter: - flow_run_exists_clause = select(orm_models.FlowRun).where( - orm_models.Artifact.flow_run_id == orm_models.FlowRun.id + flow_run_exists_clause = select(db.FlowRun).where( + db.Artifact.flow_run_id == db.FlowRun.id ) if flow_run_filter: flow_run_exists_clause = flow_run_exists_clause.where( @@ -178,21 +177,19 @@ async def _apply_artifact_filters( if flow_filter: flow_run_exists_clause = flow_run_exists_clause.join( - orm_models.Flow, - orm_models.Flow.id == orm_models.FlowRun.flow_id, + db.Flow, db.Flow.id == db.FlowRun.flow_id ).where(flow_filter.as_sql_filter()) if deployment_filter: flow_run_exists_clause = flow_run_exists_clause.join( - orm_models.Deployment, - orm_models.Deployment.id == orm_models.FlowRun.deployment_id, + db.Deployment, db.Deployment.id == db.FlowRun.deployment_id ).where(deployment_filter.as_sql_filter()) query = query.where(flow_run_exists_clause.exists()) if task_run_filter: - task_run_exists_clause = select(orm_models.TaskRun).where( - orm_models.Artifact.task_run_id == orm_models.TaskRun.id + task_run_exists_clause = select(db.TaskRun).where( + db.Artifact.task_run_id == db.TaskRun.id ) task_run_exists_clause = task_run_exists_clause.where( task_run_filter.as_sql_filter() @@ -204,6 +201,7 @@ async def _apply_artifact_filters( async def _apply_artifact_collection_filters( + db: PrefectDBInterface, query: Select[T], flow_run_filter: Optional[filters.FlowRunFilter] = None, task_run_filter: Optional[filters.TaskRunFilter] = None, @@ -216,8 +214,8 @@ async def _apply_artifact_collection_filters( query = query.where(artifact_filter.as_sql_filter()) if flow_filter or flow_run_filter or deployment_filter: - flow_run_exists_clause = select(orm_models.FlowRun).where( - orm_models.ArtifactCollection.flow_run_id == orm_models.FlowRun.id + flow_run_exists_clause = select(db.FlowRun).where( + db.ArtifactCollection.flow_run_id == db.FlowRun.id ) if flow_run_filter: flow_run_exists_clause = flow_run_exists_clause.where( @@ -226,21 +224,19 @@ async def _apply_artifact_collection_filters( if flow_filter: flow_run_exists_clause = flow_run_exists_clause.join( - orm_models.Flow, - orm_models.Flow.id == orm_models.FlowRun.flow_id, + db.Flow, db.Flow.id == db.FlowRun.flow_id ).where(flow_filter.as_sql_filter()) if deployment_filter: flow_run_exists_clause = flow_run_exists_clause.join( - orm_models.Deployment, - orm_models.Deployment.id == orm_models.FlowRun.deployment_id, + db.Deployment, db.Deployment.id == db.FlowRun.deployment_id ).where(deployment_filter.as_sql_filter()) query = query.where(flow_run_exists_clause.exists()) if task_run_filter: - task_run_exists_clause = select(orm_models.TaskRun).where( - orm_models.ArtifactCollection.task_run_id == orm_models.TaskRun.id + task_run_exists_clause = select(db.TaskRun).where( + db.ArtifactCollection.task_run_id == db.TaskRun.id ) task_run_exists_clause = task_run_exists_clause.where( task_run_filter.as_sql_filter() @@ -251,7 +247,9 @@ async def _apply_artifact_collection_filters( return query +@db_injector async def read_artifacts( + db: PrefectDBInterface, session: AsyncSession, offset: Optional[int] = None, limit: Optional[int] = None, @@ -276,9 +274,10 @@ async def read_artifacts( flow_filter: Only select artifacts whose flow runs belong to flows matching this filter work_pool_filter: Only select artifacts whose flow runs belong to work pools matching this filter """ - query = sa.select(orm_models.Artifact).order_by(sort.as_sql_sort()) + query = sa.select(db.Artifact).order_by(*sort.as_sql_sort()) query = await _apply_artifact_filters( + db, query, artifact_filter=artifact_filter, flow_run_filter=flow_run_filter, @@ -296,7 +295,9 @@ async def read_artifacts( return result.scalars().unique().all() +@db_injector async def read_latest_artifacts( + db: PrefectDBInterface, session: AsyncSession, offset: Optional[int] = None, limit: Optional[int] = None, @@ -321,8 +322,9 @@ async def read_latest_artifacts( flow_filter: Only select artifacts whose flow runs belong to flows matching this filter work_pool_filter: Only select artifacts whose flow runs belong to work pools matching this filter """ - query = sa.select(orm_models.ArtifactCollection).order_by(sort.as_sql_sort()) + query = sa.select(db.ArtifactCollection).order_by(*sort.as_sql_sort()) query = await _apply_artifact_collection_filters( + db, query, artifact_filter=artifact_filter, flow_run_filter=flow_run_filter, @@ -340,7 +342,9 @@ async def read_latest_artifacts( return result.scalars().unique().all() +@db_injector async def count_artifacts( + db: PrefectDBInterface, session: AsyncSession, artifact_filter: Optional[filters.ArtifactFilter] = None, flow_run_filter: Optional[filters.FlowRunFilter] = None, @@ -356,9 +360,10 @@ async def count_artifacts( flow_run_filter: Only select artifacts whose flow runs matching this filter task_run_filter: Only select artifacts whose task runs matching this filter """ - query = sa.select(sa.func.count(orm_models.Artifact.id)) + query = sa.select(sa.func.count(db.Artifact.id)) query = await _apply_artifact_filters( + db, query, artifact_filter=artifact_filter, flow_run_filter=flow_run_filter, @@ -371,7 +376,9 @@ async def count_artifacts( return result.scalar_one() +@db_injector async def count_latest_artifacts( + db: PrefectDBInterface, session: AsyncSession, artifact_filter: Optional[filters.ArtifactCollectionFilter] = None, flow_run_filter: Optional[filters.FlowRunFilter] = None, @@ -387,9 +394,10 @@ async def count_latest_artifacts( flow_run_filter: Only select artifacts whose flow runs matching this filter task_run_filter: Only select artifacts whose task runs matching this filter """ - query = sa.select(sa.func.count(orm_models.ArtifactCollection.id)) + query = sa.select(sa.func.count(db.ArtifactCollection.id)) query = await _apply_artifact_collection_filters( + db, query, artifact_filter=artifact_filter, flow_run_filter=flow_run_filter, @@ -402,7 +410,9 @@ async def count_latest_artifacts( return result.scalar_one() +@db_injector async def update_artifact( + db: PrefectDBInterface, session: AsyncSession, artifact_id: UUID, artifact: actions.ArtifactUpdate, @@ -421,8 +431,8 @@ async def update_artifact( update_artifact_data = artifact.model_dump_for_orm(exclude_unset=True) update_artifact_stmt = ( - sa.update(orm_models.Artifact) - .where(orm_models.Artifact.id == artifact_id) + sa.update(db.Artifact) + .where(db.Artifact.id == artifact_id) .values(**update_artifact_data) ) @@ -430,8 +440,8 @@ async def update_artifact( update_artifact_collection_data = artifact.model_dump_for_orm(exclude_unset=True) update_artifact_collection_stmt = ( - sa.update(orm_models.ArtifactCollection) - .where(orm_models.ArtifactCollection.latest_id == artifact_id) + sa.update(db.ArtifactCollection) + .where(db.ArtifactCollection.latest_id == artifact_id) .values(**update_artifact_collection_data) ) collection_result = await session.execute(update_artifact_collection_stmt) @@ -439,9 +449,9 @@ async def update_artifact( return artifact_result.rowcount + collection_result.rowcount > 0 +@db_injector async def delete_artifact( - session: AsyncSession, - artifact_id: UUID, + db: PrefectDBInterface, session: AsyncSession, artifact_id: UUID ) -> bool: """ Deletes an artifact by id. @@ -470,33 +480,33 @@ async def delete_artifact( Returns: bool: True if the delete was successful, False otherwise """ - artifact = await session.get(orm_models.Artifact, artifact_id) + artifact = await session.get(db.Artifact, artifact_id) if artifact is None: return False is_latest_version = ( await session.execute( - sa.select(orm_models.ArtifactCollection) - .where(orm_models.ArtifactCollection.key == artifact.key) - .where(orm_models.ArtifactCollection.latest_id == artifact_id) + sa.select(db.ArtifactCollection) + .where(db.ArtifactCollection.key == artifact.key) + .where(db.ArtifactCollection.latest_id == artifact_id) ) ).scalar_one_or_none() is not None if is_latest_version: next_latest_version = ( await session.execute( - sa.select(orm_models.Artifact) - .where(orm_models.Artifact.key == artifact.key) - .where(orm_models.Artifact.id != artifact_id) - .order_by(orm_models.Artifact.created.desc()) + sa.select(db.Artifact) + .where(db.Artifact.key == artifact.key) + .where(db.Artifact.id != artifact_id) + .order_by(db.Artifact.created.desc()) .limit(1) ) ).scalar_one_or_none() if next_latest_version is not None: set_next_latest_version = ( - sa.update(orm_models.ArtifactCollection) - .where(orm_models.ArtifactCollection.key == artifact.key) + sa.update(db.ArtifactCollection) + .where(db.ArtifactCollection.key == artifact.key) .values( latest_id=next_latest_version.id, data=next_latest_version.data, @@ -513,14 +523,12 @@ async def delete_artifact( else: await session.execute( - sa.delete(orm_models.ArtifactCollection) - .where(orm_models.ArtifactCollection.key == artifact.key) - .where(orm_models.ArtifactCollection.latest_id == artifact_id) + sa.delete(db.ArtifactCollection) + .where(db.ArtifactCollection.key == artifact.key) + .where(db.ArtifactCollection.latest_id == artifact_id) ) - delete_stmt = sa.delete(orm_models.Artifact).where( - orm_models.Artifact.id == artifact_id - ) + delete_stmt = sa.delete(db.Artifact).where(db.Artifact.id == artifact_id) result = await session.execute(delete_stmt) return result.rowcount > 0 diff --git a/src/prefect/server/models/block_documents.py b/src/prefect/server/models/block_documents.py index 5fe00f0ce2784..638d2cb9183ec 100644 --- a/src/prefect/server/models/block_documents.py +++ b/src/prefect/server/models/block_documents.py @@ -13,9 +13,7 @@ import prefect.server.models as models from prefect.server import schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.schemas.actions import BlockDocumentReferenceCreate from prefect.server.schemas.core import BlockDocument from prefect.server.schemas.filters import BlockSchemaFilter @@ -26,7 +24,9 @@ T = TypeVar("T", bound=tuple) +@db_injector async def create_block_document( + db: PrefectDBInterface, session: AsyncSession, block_document: schemas.actions.BlockDocumentCreate, ) -> BlockDocument: @@ -43,7 +43,7 @@ async def create_block_document( else: name = block_document.name - orm_block = orm_models.BlockDocument( + orm_block = db.BlockDocument( name=name, block_schema_id=block_document.block_schema_id, block_type_id=block_document.block_type_id, @@ -85,13 +85,14 @@ async def create_block_document( return new_block_document +@db_injector async def block_document_with_unique_values_exists( - session: AsyncSession, block_type_id: UUID, name: str + db: PrefectDBInterface, session: AsyncSession, block_type_id: UUID, name: str ) -> bool: result = await session.execute( - sa.select(sa.exists(orm_models.BlockDocument)).where( - orm_models.BlockDocument.block_type_id == block_type_id, - orm_models.BlockDocument.name == name, + sa.select(sa.exists(db.BlockDocument)).where( + db.BlockDocument.block_type_id == block_type_id, + db.BlockDocument.name == name, ) ) return bool(result.scalar_one_or_none()) @@ -152,6 +153,7 @@ async def read_block_document_by_id( async def _construct_full_block_document( + db: PrefectDBInterface, session: AsyncSession, block_documents_with_references: Sequence[ Tuple[orm_models.ORMBlockDocument, Optional[str], Optional[UUID]] @@ -186,6 +188,7 @@ async def _construct_full_block_document( session, orm_block_document, include_secrets=include_secrets ) full_child_block_document = await _construct_full_block_document( + db, session, block_documents_with_references, parent_block_document=copy(block_document), @@ -264,6 +267,7 @@ async def read_block_document_by_name( def _apply_block_document_filters( + db: PrefectDBInterface, query: Select[T], block_document_filter: Optional[schemas.filters.BlockDocumentFilter] = None, block_schema_filter: Optional[schemas.filters.BlockSchemaFilter] = None, @@ -279,15 +283,15 @@ def _apply_block_document_filters( query = query.where(block_document_filter.as_sql_filter()) if block_type_filter is not None: - block_type_exists_clause = sa.select(orm_models.BlockType).where( - orm_models.BlockType.id == orm_models.BlockDocument.block_type_id, + block_type_exists_clause = sa.select(db.BlockType).where( + db.BlockType.id == db.BlockDocument.block_type_id, block_type_filter.as_sql_filter(), ) query = query.where(block_type_exists_clause.exists()) if block_schema_filter is not None: - block_schema_exists_clause = sa.select(orm_models.BlockSchema).where( - orm_models.BlockSchema.id == orm_models.BlockDocument.block_schema_id, + block_schema_exists_clause = sa.select(db.BlockSchema).where( + db.BlockSchema.id == db.BlockDocument.block_schema_id, block_schema_filter.as_sql_filter(), ) query = query.where(block_schema_exists_clause.exists()) @@ -295,7 +299,9 @@ def _apply_block_document_filters( return query +@db_injector async def read_block_documents( + db: PrefectDBInterface, session: AsyncSession, block_document_filter: Optional[schemas.filters.BlockDocumentFilter] = None, block_type_filter: Optional[schemas.filters.BlockTypeFilter] = None, @@ -309,15 +315,16 @@ async def read_block_documents( Read block documents with an optional limit and offset """ # --- Build an initial query that filters for the requested block documents - filtered_block_documents_query = sa.select(orm_models.BlockDocument.id) + filtered_block_documents_query = sa.select(db.BlockDocument.id) filtered_block_documents_query = _apply_block_document_filters( + db, query=filtered_block_documents_query, block_document_filter=block_document_filter, block_type_filter=block_type_filter, block_schema_filter=block_schema_filter, ) filtered_block_documents_query = filtered_block_documents_query.order_by( - sort.as_sql_sort() + *sort.as_sql_sort() ) if offset is not None: @@ -348,15 +355,14 @@ async def read_block_documents( # recursive part of query referenced_documents = ( sa.select( - orm_models.BlockDocumentReference.reference_block_document_id, - orm_models.BlockDocumentReference.name, - orm_models.BlockDocumentReference.parent_block_document_id, + db.BlockDocumentReference.reference_block_document_id, + db.BlockDocumentReference.name, + db.BlockDocumentReference.parent_block_document_id, ) .select_from(parent_documents) .join( - orm_models.BlockDocumentReference, - orm_models.BlockDocumentReference.parent_block_document_id - == parent_documents.c.id, + db.BlockDocumentReference, + db.BlockDocumentReference.parent_block_document_id == parent_documents.c.id, ) ) # union the recursive CTE @@ -367,16 +373,13 @@ async def read_block_documents( # and order by name final_query = ( sa.select( - orm_models.BlockDocument, + db.BlockDocument, all_block_documents_query.c.reference_name, all_block_documents_query.c.reference_parent_block_document_id, ) .select_from(all_block_documents_query) - .join( - orm_models.BlockDocument, - orm_models.BlockDocument.id == all_block_documents_query.c.id, - ) - .order_by(sort.as_sql_sort()) + .join(db.BlockDocument, db.BlockDocument.id == all_block_documents_query.c.id) + .order_by(*sort.as_sql_sort()) ) result = await session.execute( @@ -404,6 +407,7 @@ async def read_block_documents( session, root_orm_block_document, include_secrets=include_secrets ) constructed = await _construct_full_block_document( + db, session, block_documents_with_references, # type: ignore root_block_document, @@ -433,7 +437,9 @@ async def read_block_documents( return fully_constructed_block_documents +@db_injector async def count_block_documents( + db: PrefectDBInterface, session: AsyncSession, block_document_filter: Optional[schemas.filters.BlockDocumentFilter] = None, block_type_filter: Optional[schemas.filters.BlockTypeFilter] = None, @@ -442,9 +448,10 @@ async def count_block_documents( """ Count block documents that match the filters. """ - query = sa.select(sa.func.count()).select_from(orm_models.BlockDocument) + query = sa.select(sa.func.count()).select_from(db.BlockDocument) query = _apply_block_document_filters( + db, query=query, block_document_filter=block_document_filter, block_schema_filter=block_schema_filter, @@ -455,26 +462,26 @@ async def count_block_documents( return result.scalar() # type: ignore +@db_injector async def delete_block_document( + db: PrefectDBInterface, session: AsyncSession, block_document_id: UUID, ) -> bool: - query = sa.delete(orm_models.BlockDocument).where( - orm_models.BlockDocument.id == block_document_id - ) + query = sa.delete(db.BlockDocument).where(db.BlockDocument.id == block_document_id) result = await session.execute(query) return result.rowcount > 0 +@db_injector async def update_block_document( + db: PrefectDBInterface, session: AsyncSession, block_document_id: UUID, block_document: schemas.actions.BlockDocumentUpdate, ) -> bool: merge_existing_data = block_document.merge_existing_data - current_block_document = await session.get( - orm_models.BlockDocument, block_document_id - ) + current_block_document = await session.get(db.BlockDocument, block_document_id) if not current_block_document: return False @@ -525,7 +532,7 @@ async def update_block_document( current_block_document_references = ( ( await session.execute( - sa.select(orm_models.BlockDocumentReference).filter_by( + sa.select(db.BlockDocumentReference).filter_by( parent_block_document_id=block_document_id ) ) @@ -554,7 +561,7 @@ async def update_block_document( and proposed_block_schema_id != current_block_document.block_schema_id ): proposed_block_schema = await session.get( - orm_models.BlockSchema, proposed_block_schema_id + db.BlockSchema, proposed_block_schema_id ) assert ( proposed_block_schema @@ -570,8 +577,8 @@ async def update_block_document( " type." ) await session.execute( - sa.update(orm_models.BlockDocument) - .where(orm_models.BlockDocument.id == block_document_id) + sa.update(db.BlockDocument) + .where(db.BlockDocument.id == block_document_id) .values(block_schema_id=proposed_block_schema_id) ) @@ -628,7 +635,7 @@ async def create_block_document_reference( session: AsyncSession, block_document_reference: schemas.actions.BlockDocumentReferenceCreate, ) -> Union[orm_models.BlockDocumentReference, None]: - insert_stmt = db.insert(orm_models.BlockDocumentReference).values( + insert_stmt = db.queries.insert(db.BlockDocumentReference).values( **block_document_reference.model_dump_for_orm( exclude_unset=True, exclude={"created", "updated"} ) @@ -636,20 +643,22 @@ async def create_block_document_reference( await session.execute(insert_stmt) result = await session.execute( - sa.select(orm_models.BlockDocumentReference).where( - orm_models.BlockDocumentReference.id == block_document_reference.id + sa.select(db.BlockDocumentReference).where( + db.BlockDocumentReference.id == block_document_reference.id ) ) return result.scalar() +@db_injector async def delete_block_document_reference( + db: PrefectDBInterface, session: AsyncSession, block_document_reference_id: UUID, ) -> bool: - query = sa.delete(orm_models.BlockDocumentReference).where( - orm_models.BlockDocumentReference.id == block_document_reference_id + query = sa.delete(db.BlockDocumentReference).where( + db.BlockDocumentReference.id == block_document_reference_id ) result = await session.execute(query) return result.rowcount > 0 diff --git a/src/prefect/server/models/block_schemas.py b/src/prefect/server/models/block_schemas.py index 00ed0b21e3296..264dc4913fc87 100644 --- a/src/prefect/server/models/block_schemas.py +++ b/src/prefect/server/models/block_schemas.py @@ -13,9 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.models.block_types import read_block_type_by_slug from prefect.server.schemas.actions import BlockSchemaCreate from prefect.server.schemas.core import BlockSchema, BlockSchemaReference @@ -114,31 +112,32 @@ async def create_block_schema( "block_schema_references", {} ) - insert_stmt = db.insert(orm_models.BlockSchema).values(**insert_values) + insert_stmt = db.queries.insert(db.BlockSchema).values(**insert_values) if override: insert_stmt = insert_stmt.on_conflict_do_update( - index_elements=db.block_schema_unique_upsert_columns, + index_elements=db.orm.block_schema_unique_upsert_columns, set_=insert_values, ) await session.execute(insert_stmt) query = ( - sa.select(orm_models.BlockSchema) + sa.select(db.BlockSchema) .where( - orm_models.BlockSchema.checksum == insert_values["checksum"], + db.BlockSchema.checksum == insert_values["checksum"], ) - .order_by(orm_models.BlockSchema.created.desc()) + .order_by(db.BlockSchema.created.desc()) .limit(1) .execution_options(populate_existing=True) ) if block_schema.version is not None: - query = query.where(orm_models.BlockSchema.version == block_schema.version) + query = query.where(db.BlockSchema.version == block_schema.version) result = await session.execute(query) created_block_schema = copy(result.scalar_one()) await _register_nested_block_schemas( + db, session=session, parent_block_schema_id=created_block_schema.id, block_schema_references=block_schema_references, @@ -155,6 +154,7 @@ async def create_block_schema( async def _register_nested_block_schemas( + db: PrefectDBInterface, session: AsyncSession, parent_block_schema_id: UUID, block_schema_references: Dict[str, Union[Dict[str, str], List[Dict[str, str]]]], @@ -213,7 +213,7 @@ async def _register_nested_block_schemas( " definitions in root block schema fields" ) sub_block_schema_fields = _get_fields_for_child_schema( - definitions, base_fields, reference_name, reference_block_type + db, definitions, base_fields, reference_name, reference_block_type ) if sub_block_schema_fields is None: @@ -243,6 +243,7 @@ async def _register_nested_block_schemas( def _get_fields_for_child_schema( + db: PrefectDBInterface, definitions: Dict, base_fields: Dict, reference_name: str, @@ -281,7 +282,10 @@ def _get_fields_for_child_schema( return sub_block_schema_fields # type: ignore -async def delete_block_schema(session: AsyncSession, block_schema_id: UUID) -> bool: +@db_injector +async def delete_block_schema( + db: PrefectDBInterface, session: AsyncSession, block_schema_id: UUID +) -> bool: """ Delete a block schema by id. @@ -294,14 +298,14 @@ async def delete_block_schema(session: AsyncSession, block_schema_id: UUID) -> b """ result = await session.execute( - delete(orm_models.BlockSchema).where( - orm_models.BlockSchema.id == block_schema_id - ) + delete(db.BlockSchema).where(db.BlockSchema.id == block_schema_id) ) return result.rowcount > 0 +@db_injector async def read_block_schema( + db: PrefectDBInterface, session: AsyncSession, block_schema_id: UUID, ) -> Union[BlockSchema, None]: @@ -321,17 +325,17 @@ async def read_block_schema( # along with and nested block schemas coupled with the ID of their parent schema # the key that they reside under. block_schema_references_query = ( - sa.select(orm_models.BlockSchemaReference) - .select_from(orm_models.BlockSchemaReference) + sa.select(db.BlockSchemaReference) + .select_from(db.BlockSchemaReference) .filter_by(parent_block_schema_id=block_schema_id) .cte("block_schema_references", recursive=True) ) block_schema_references_join = ( - sa.select(orm_models.BlockSchemaReference) - .select_from(orm_models.BlockSchemaReference) + sa.select(db.BlockSchemaReference) + .select_from(db.BlockSchemaReference) .join( block_schema_references_query, - orm_models.BlockSchemaReference.parent_block_schema_id + db.BlockSchemaReference.parent_block_schema_id == block_schema_references_query.c.reference_block_schema_id, ) ) @@ -340,20 +344,20 @@ async def read_block_schema( ) nested_block_schemas_query = ( sa.select( - orm_models.BlockSchema, + db.BlockSchema, recursive_block_schema_references_cte.c.name, recursive_block_schema_references_cte.c.parent_block_schema_id, ) - .select_from(orm_models.BlockSchema) + .select_from(db.BlockSchema) .join( recursive_block_schema_references_cte, - orm_models.BlockSchema.id + db.BlockSchema.id == recursive_block_schema_references_cte.c.reference_block_schema_id, isouter=True, ) .filter( sa.or_( - orm_models.BlockSchema.id == block_schema_id, + db.BlockSchema.id == block_schema_id, recursive_block_schema_references_cte.c.parent_block_schema_id.is_not( None ), @@ -584,7 +588,9 @@ def _construct_block_schema_fields_with_block_references( return block_schema_fields_copy +@db_injector async def read_block_schemas( + db: PrefectDBInterface, session: AsyncSession, block_schema_filter: Optional[schemas.filters.BlockSchemaFilter] = None, limit: Optional[int] = None, @@ -604,8 +610,8 @@ async def read_block_schemas( """ # schemas are ordered by `created DESC` to get the most recently created # ones first (and to facilitate getting the newest one with `limit=1`). - filtered_block_schemas_query = select(orm_models.BlockSchema.id).order_by( - orm_models.BlockSchema.created.desc() + filtered_block_schemas_query = select(db.BlockSchema.id).order_by( + db.BlockSchema.created.desc() ) if block_schema_filter: @@ -623,21 +629,21 @@ async def read_block_schemas( ) block_schema_references_query = ( - sa.select(orm_models.BlockSchemaReference) - .select_from(orm_models.BlockSchemaReference) + sa.select(db.BlockSchemaReference) + .select_from(db.BlockSchemaReference) .filter( - orm_models.BlockSchemaReference.parent_block_schema_id.in_( + db.BlockSchemaReference.parent_block_schema_id.in_( filtered_block_schemas_query ) ) .cte("block_schema_references", recursive=True) ) block_schema_references_join = ( - sa.select(orm_models.BlockSchemaReference) - .select_from(orm_models.BlockSchemaReference) + sa.select(db.BlockSchemaReference) + .select_from(db.BlockSchemaReference) .join( block_schema_references_query, - orm_models.BlockSchemaReference.parent_block_schema_id + db.BlockSchemaReference.parent_block_schema_id == block_schema_references_query.c.reference_block_schema_id, ) ) @@ -647,24 +653,24 @@ async def read_block_schemas( nested_block_schemas_query = ( sa.select( - orm_models.BlockSchema, + db.BlockSchema, recursive_block_schema_references_cte.c.name, recursive_block_schema_references_cte.c.parent_block_schema_id, ) - .select_from(orm_models.BlockSchema) + .select_from(db.BlockSchema) # in order to reconstruct nested block schemas efficiently, we need to visit them # in the order they were created (so that we guarantee that nested/referenced schemas) # have already been seen. Therefore this second query sorts by created ASC - .order_by(orm_models.BlockSchema.created.asc()) + .order_by(db.BlockSchema.created.asc()) .join( recursive_block_schema_references_cte, - orm_models.BlockSchema.id + db.BlockSchema.id == recursive_block_schema_references_cte.c.reference_block_schema_id, isouter=True, ) .filter( sa.or_( - orm_models.BlockSchema.id.in_(filtered_block_schemas_query), + db.BlockSchema.id.in_(filtered_block_schemas_query), recursive_block_schema_references_cte.c.parent_block_schema_id.is_not( None ), @@ -695,7 +701,9 @@ async def read_block_schemas( return list(reversed(fully_constructed_block_schemas)) +@db_injector async def read_block_schema_by_checksum( + db: PrefectDBInterface, session: AsyncSession, checksum: str, version: Optional[str] = None, @@ -719,9 +727,9 @@ async def read_block_schema_by_checksum( # The same checksum with different versions can occur in the DB. Return only the # most recently created one. root_block_schema_query = ( - sa.select(orm_models.BlockSchema) + sa.select(db.BlockSchema) .filter_by(checksum=checksum) - .order_by(orm_models.BlockSchema.created.desc()) + .order_by(db.BlockSchema.created.desc()) .limit(1) ) @@ -731,17 +739,17 @@ async def read_block_schema_by_checksum( root_block_schema_cte = root_block_schema_query.cte("root_block_schema") block_schema_references_query = ( - sa.select(orm_models.BlockSchemaReference) - .select_from(orm_models.BlockSchemaReference) + sa.select(db.BlockSchemaReference) + .select_from(db.BlockSchemaReference) .filter_by(parent_block_schema_id=root_block_schema_cte.c.id) .cte("block_schema_references", recursive=True) ) block_schema_references_join = ( - sa.select(orm_models.BlockSchemaReference) - .select_from(orm_models.BlockSchemaReference) + sa.select(db.BlockSchemaReference) + .select_from(db.BlockSchemaReference) .join( block_schema_references_query, - orm_models.BlockSchemaReference.parent_block_schema_id + db.BlockSchemaReference.parent_block_schema_id == block_schema_references_query.c.reference_block_schema_id, ) ) @@ -750,20 +758,20 @@ async def read_block_schema_by_checksum( ) nested_block_schemas_query = ( sa.select( - orm_models.BlockSchema, + db.BlockSchema, recursive_block_schema_references_cte.c.name, recursive_block_schema_references_cte.c.parent_block_schema_id, ) - .select_from(orm_models.BlockSchema) + .select_from(db.BlockSchema) .join( recursive_block_schema_references_cte, - orm_models.BlockSchema.id + db.BlockSchema.id == recursive_block_schema_references_cte.c.reference_block_schema_id, isouter=True, ) .filter( sa.or_( - orm_models.BlockSchema.id == root_block_schema_cte.c.id, + db.BlockSchema.id == root_block_schema_cte.c.id, recursive_block_schema_references_cte.c.parent_block_schema_id.is_not( None ), @@ -789,10 +797,12 @@ async def read_available_block_capabilities( List[str]: List of all available block capabilities. """ query = sa.select( - db.json_arr_agg(db.cast_to_json(orm_models.BlockSchema.capabilities.distinct())) + db.queries.json_arr_agg( + db.queries.cast_to_json(db.BlockSchema.capabilities.distinct()) + ) ) capability_combinations = (await session.execute(query)).scalars().first() or list() - if db.uses_json_strings and isinstance(capability_combinations, str): + if db.queries.uses_json_strings and isinstance(capability_combinations, str): capability_combinations = json.loads(capability_combinations) return list({c for capabilities in capability_combinations for c in capabilities}) @@ -813,11 +823,11 @@ async def create_block_schema_reference( Returns: orm_models.BlockSchemaReference: The created BlockSchemaReference """ - query_stmt = sa.select(orm_models.BlockSchemaReference).where( - orm_models.BlockSchemaReference.name == block_schema_reference.name, - orm_models.BlockSchemaReference.parent_block_schema_id + query_stmt = sa.select(db.BlockSchemaReference).where( + db.BlockSchemaReference.name == block_schema_reference.name, + db.BlockSchemaReference.parent_block_schema_id == block_schema_reference.parent_block_schema_id, - orm_models.BlockSchemaReference.reference_block_schema_id + db.BlockSchemaReference.reference_block_schema_id == block_schema_reference.reference_block_schema_id, ) @@ -825,7 +835,7 @@ async def create_block_schema_reference( if existing_reference: return existing_reference - insert_stmt = db.insert(orm_models.BlockSchemaReference).values( + insert_stmt = db.queries.insert(db.BlockSchemaReference).values( **block_schema_reference.model_dump_for_orm( exclude_unset=True, exclude={"created", "updated"} ) @@ -833,8 +843,8 @@ async def create_block_schema_reference( await session.execute(insert_stmt) result = await session.execute( - sa.select(orm_models.BlockSchemaReference).where( - orm_models.BlockSchemaReference.id == block_schema_reference.id + sa.select(db.BlockSchemaReference).where( + db.BlockSchemaReference.id == block_schema_reference.id ) ) return result.scalar() diff --git a/src/prefect/server/models/block_types.py b/src/prefect/server/models/block_types.py index 18300ac9be853..abbb2eb257a45 100644 --- a/src/prefect/server/models/block_types.py +++ b/src/prefect/server/models/block_types.py @@ -11,9 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import schemas -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface -from prefect.server.database.orm_models import BlockSchema, BlockType +from prefect.server.database import PrefectDBInterface, db_injector +from prefect.server.database.orm_models import BlockType if TYPE_CHECKING: from prefect.client.schemas import BlockType as ClientBlockType @@ -56,19 +55,19 @@ async def create_block_type( insert_values["code_example"] = html.escape( insert_values["code_example"], quote=False ) - insert_stmt = db.insert(BlockType).values(**insert_values) + insert_stmt = db.queries.insert(db.BlockType).values(**insert_values) if override: insert_stmt = insert_stmt.on_conflict_do_update( - index_elements=db.block_type_unique_upsert_columns, + index_elements=db.orm.block_type_unique_upsert_columns, set_=insert_values, ) await session.execute(insert_stmt) query = ( - sa.select(BlockType) + sa.select(db.BlockType) .where( sa.and_( - BlockType.name == insert_values["name"], + db.BlockType.name == insert_values["name"], ) ) .execution_options(populate_existing=True) @@ -78,7 +77,9 @@ async def create_block_type( return result.scalar() +@db_injector async def read_block_type( + db: PrefectDBInterface, session: AsyncSession, block_type_id: UUID, ) -> Union[BlockType, None]: @@ -92,11 +93,12 @@ async def read_block_type( Returns: BlockType: an ORM block type model """ - return await session.get(BlockType, block_type_id) + return await session.get(db.BlockType, block_type_id) +@db_injector async def read_block_type_by_slug( - session: AsyncSession, block_type_slug: str + db: PrefectDBInterface, session: AsyncSession, block_type_slug: str ) -> Union[BlockType, None]: """ Reads a block type by slug. @@ -110,12 +112,14 @@ async def read_block_type_by_slug( """ result = await session.execute( - sa.select(BlockType).where(BlockType.slug == block_type_slug) + sa.select(db.BlockType).where(db.BlockType.slug == block_type_slug) ) return result.scalar() +@db_injector async def read_block_types( + db: PrefectDBInterface, session: AsyncSession, block_type_filter: Optional[schemas.filters.BlockTypeFilter] = None, block_schema_filter: Optional[schemas.filters.BlockSchemaFilter] = None, @@ -130,14 +134,14 @@ async def read_block_types( Returns: List[BlockType]: List of """ - query = sa.select(BlockType).order_by(BlockType.name) + query = sa.select(db.BlockType).order_by(db.BlockType.name) if block_type_filter is not None: query = query.where(block_type_filter.as_sql_filter()) if block_schema_filter is not None: - exists_clause = sa.select(BlockSchema).where( - BlockSchema.block_type_id == BlockType.id, + exists_clause = sa.select(db.BlockSchema).where( + db.BlockSchema.block_type_id == db.BlockType.id, block_schema_filter.as_sql_filter(), ) query = query.where(exists_clause.exists()) @@ -152,7 +156,9 @@ async def read_block_types( return result.scalars().unique().all() +@db_injector async def update_block_type( + db: PrefectDBInterface, session: AsyncSession, block_type_id: Union[str, UUID], block_type: Union[ @@ -186,15 +192,18 @@ async def update_block_type( ) update_statement = ( - sa.update(BlockType) - .where(BlockType.id == block_type_id) + sa.update(db.BlockType) + .where(db.BlockType.id == block_type_id) .values(**block_type.model_dump_for_orm(exclude_unset=True, exclude={"id"})) ) result = await session.execute(update_statement) return result.rowcount > 0 -async def delete_block_type(session: AsyncSession, block_type_id: str) -> bool: +@db_injector +async def delete_block_type( + db: PrefectDBInterface, session: AsyncSession, block_type_id: str +) -> bool: """ Delete a block type by id. @@ -207,6 +216,6 @@ async def delete_block_type(session: AsyncSession, block_type_id: str) -> bool: """ result = await session.execute( - sa.delete(BlockType).where(BlockType.id == block_type_id) + sa.delete(db.BlockType).where(db.BlockType.id == block_type_id) ) return result.rowcount > 0 diff --git a/src/prefect/server/models/concurrency_limits.py b/src/prefect/server/models/concurrency_limits.py index 1fed1a80c6583..af0fc63e05b1f 100644 --- a/src/prefect/server/models/concurrency_limits.py +++ b/src/prefect/server/models/concurrency_limits.py @@ -11,9 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.schemas as schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models @db_injector @@ -33,10 +31,10 @@ async def create_concurrency_limit( concurrency_limit.updated = pendulum.now("UTC") # type: ignore[assignment] insert_stmt = ( - db.insert(orm_models.ConcurrencyLimit) + db.queries.insert(db.ConcurrencyLimit) .values(**insert_values) .on_conflict_do_update( - index_elements=db.concurrency_limit_unique_upsert_columns, + index_elements=db.orm.concurrency_limit_unique_upsert_columns, set_=concurrency_limit.model_dump_for_orm( include={"concurrency_limit", "updated"} ), @@ -46,8 +44,8 @@ async def create_concurrency_limit( await session.execute(insert_stmt) query = ( - sa.select(orm_models.ConcurrencyLimit) - .where(orm_models.ConcurrencyLimit.tag == concurrency_tag) + sa.select(db.ConcurrencyLimit) + .where(db.ConcurrencyLimit.tag == concurrency_tag) .execution_options(populate_existing=True) ) @@ -55,7 +53,9 @@ async def create_concurrency_limit( return result.scalar_one() +@db_injector async def read_concurrency_limit( + db: PrefectDBInterface, session: AsyncSession, concurrency_limit_id: UUID, ) -> Union[orm_models.ConcurrencyLimit, None]: @@ -64,15 +64,17 @@ async def read_concurrency_limit( conditions might allow the concurrency limit to be temporarily exceeded. """ - query = sa.select(orm_models.ConcurrencyLimit).where( - orm_models.ConcurrencyLimit.id == concurrency_limit_id + query = sa.select(db.ConcurrencyLimit).where( + db.ConcurrencyLimit.id == concurrency_limit_id ) result = await session.execute(query) return result.scalar() +@db_injector async def read_concurrency_limit_by_tag( + db: PrefectDBInterface, session: AsyncSession, tag: str, ) -> Union[orm_models.ConcurrencyLimit, None]: @@ -81,15 +83,15 @@ async def read_concurrency_limit_by_tag( conditions might allow the concurrency limit to be temporarily exceeded. """ - query = sa.select(orm_models.ConcurrencyLimit).where( - orm_models.ConcurrencyLimit.tag == tag - ) + query = sa.select(db.ConcurrencyLimit).where(db.ConcurrencyLimit.tag == tag) result = await session.execute(query) return result.scalar() +@db_injector async def reset_concurrency_limit_by_tag( + db: PrefectDBInterface, session: AsyncSession, tag: str, slot_override: Optional[List[UUID]] = None, @@ -97,9 +99,7 @@ async def reset_concurrency_limit_by_tag( """ Resets a concurrency limit by tag. """ - query = sa.select(orm_models.ConcurrencyLimit).where( - orm_models.ConcurrencyLimit.tag == tag - ) + query = sa.select(db.ConcurrencyLimit).where(db.ConcurrencyLimit.tag == tag) result = await session.execute(query) concurrency_limit = result.scalar() if concurrency_limit: @@ -110,7 +110,9 @@ async def reset_concurrency_limit_by_tag( return concurrency_limit +@db_injector async def filter_concurrency_limits_for_orchestration( + db: PrefectDBInterface, session: AsyncSession, tags: List[str], ) -> Sequence[orm_models.ConcurrencyLimit]: @@ -121,40 +123,44 @@ async def filter_concurrency_limits_for_orchestration( """ query = ( - sa.select(orm_models.ConcurrencyLimit) - .filter(orm_models.ConcurrencyLimit.tag.in_(tags)) - .order_by(orm_models.ConcurrencyLimit.tag) + sa.select(db.ConcurrencyLimit) + .filter(db.ConcurrencyLimit.tag.in_(tags)) + .order_by(db.ConcurrencyLimit.tag) .with_for_update() ) result = await session.execute(query) return result.scalars().all() +@db_injector async def delete_concurrency_limit( + db: PrefectDBInterface, session: AsyncSession, concurrency_limit_id: UUID, ) -> bool: - query = sa.delete(orm_models.ConcurrencyLimit).where( - orm_models.ConcurrencyLimit.id == concurrency_limit_id + query = sa.delete(db.ConcurrencyLimit).where( + db.ConcurrencyLimit.id == concurrency_limit_id ) result = await session.execute(query) return result.rowcount > 0 +@db_injector async def delete_concurrency_limit_by_tag( + db: PrefectDBInterface, session: AsyncSession, tag: str, ) -> bool: - query = sa.delete(orm_models.ConcurrencyLimit).where( - orm_models.ConcurrencyLimit.tag == tag - ) + query = sa.delete(db.ConcurrencyLimit).where(db.ConcurrencyLimit.tag == tag) result = await session.execute(query) return result.rowcount > 0 +@db_injector async def read_concurrency_limits( + db: PrefectDBInterface, session: AsyncSession, limit: Optional[int] = None, offset: Optional[int] = None, @@ -172,9 +178,7 @@ async def read_concurrency_limits( List[orm_models.ConcurrencyLimit]: concurrency limits """ - query = sa.select(orm_models.ConcurrencyLimit).order_by( - orm_models.ConcurrencyLimit.tag - ) + query = sa.select(db.ConcurrencyLimit).order_by(db.ConcurrencyLimit.tag) if offset is not None: query = query.offset(offset) diff --git a/src/prefect/server/models/concurrency_limits_v2.py b/src/prefect/server/models/concurrency_limits_v2.py index 3c90e25c8b7b7..281a9e0eb3cd5 100644 --- a/src/prefect/server/models/concurrency_limits_v2.py +++ b/src/prefect/server/models/concurrency_limits_v2.py @@ -7,24 +7,22 @@ import prefect.server.schemas as schemas from prefect._internal.compatibility.deprecated import deprecated_parameter -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models -def active_slots_after_decay() -> ColumnElement[float]: +def active_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]: # Active slots will decay at a rate of `slot_decay_per_second` per second. return sa.func.greatest( 0, - orm_models.ConcurrencyLimitV2.active_slots + db.ConcurrencyLimitV2.active_slots - sa.func.floor( - orm_models.ConcurrencyLimitV2.slot_decay_per_second - * sa.func.date_diff_seconds(orm_models.ConcurrencyLimitV2.updated) + db.ConcurrencyLimitV2.slot_decay_per_second + * sa.func.date_diff_seconds(db.ConcurrencyLimitV2.updated) ), ) -def denied_slots_after_decay() -> ColumnElement[float]: +def denied_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]: # Denied slots decay at a rate of `slot_decay_per_second` per second if it's # greater than 0, otherwise it decays at a rate of `avg_slot_occupancy_seconds`. # The combination of `denied_slots` and `slot_decay_per_second` / @@ -32,22 +30,22 @@ def denied_slots_after_decay() -> ColumnElement[float]: # when slots will be available again. return sa.func.greatest( 0, - orm_models.ConcurrencyLimitV2.denied_slots + db.ConcurrencyLimitV2.denied_slots - sa.func.floor( sa.case( ( - orm_models.ConcurrencyLimitV2.slot_decay_per_second > 0.0, - orm_models.ConcurrencyLimitV2.slot_decay_per_second, + db.ConcurrencyLimitV2.slot_decay_per_second > 0.0, + db.ConcurrencyLimitV2.slot_decay_per_second, ), else_=( 1.0 / sa.cast( - orm_models.ConcurrencyLimitV2.avg_slot_occupancy_seconds, + db.ConcurrencyLimitV2.avg_slot_occupancy_seconds, sa.Float, ) ), ) - * sa.func.date_diff_seconds(orm_models.ConcurrencyLimitV2.updated) + * sa.func.date_diff_seconds(db.ConcurrencyLimitV2.updated) ), ) @@ -61,13 +59,15 @@ def denied_slots_after_decay() -> ColumnElement[float]: MINIMUM_OCCUPANCY_SECONDS_PER_SLOT = 0.1 +@db_injector async def create_concurrency_limit( + db: PrefectDBInterface, session: AsyncSession, concurrency_limit: Union[ schemas.actions.ConcurrencyLimitV2Create, schemas.core.ConcurrencyLimitV2 ], ) -> orm_models.ConcurrencyLimitV2: - model = orm_models.ConcurrencyLimitV2(**concurrency_limit.model_dump()) + model = db.ConcurrencyLimitV2(**concurrency_limit.model_dump()) session.add(model) await session.flush() @@ -75,7 +75,9 @@ async def create_concurrency_limit( return model +@db_injector async def read_concurrency_limit( + db: PrefectDBInterface, session: AsyncSession, concurrency_limit_id: Optional[UUID] = None, name: Optional[str] = None, @@ -84,23 +86,23 @@ async def read_concurrency_limit( raise ValueError("Must provide either concurrency_limit_id or name") where = ( - orm_models.ConcurrencyLimitV2.id == concurrency_limit_id + db.ConcurrencyLimitV2.id == concurrency_limit_id if concurrency_limit_id - else orm_models.ConcurrencyLimitV2.name == name + else db.ConcurrencyLimitV2.name == name ) - query = sa.select(orm_models.ConcurrencyLimitV2).where(where) + query = sa.select(db.ConcurrencyLimitV2).where(where) result = await session.execute(query) return result.scalar() +@db_injector async def read_all_concurrency_limits( + db: PrefectDBInterface, session: AsyncSession, limit: int, offset: int, ) -> Sequence[orm_models.ConcurrencyLimitV2]: - query = sa.select(orm_models.ConcurrencyLimitV2).order_by( - orm_models.ConcurrencyLimitV2.name - ) + query = sa.select(db.ConcurrencyLimitV2).order_by(db.ConcurrencyLimitV2.name) if offset is not None: query = query.offset(offset) @@ -111,7 +113,9 @@ async def read_all_concurrency_limits( return result.scalars().unique().all() +@db_injector async def update_concurrency_limit( + db: PrefectDBInterface, session: AsyncSession, concurrency_limit: schemas.actions.ConcurrencyLimitV2Update, concurrency_limit_id: Optional[UUID] = None, @@ -127,13 +131,13 @@ async def update_concurrency_limit( raise ValueError("Must provide either concurrency_limit_id or name") where = ( - orm_models.ConcurrencyLimitV2.id == concurrency_limit_id + db.ConcurrencyLimitV2.id == concurrency_limit_id if concurrency_limit_id - else orm_models.ConcurrencyLimitV2.name == name + else db.ConcurrencyLimitV2.name == name ) result = await session.execute( - sa.update(orm_models.ConcurrencyLimitV2) + sa.update(db.ConcurrencyLimitV2) .where(where) .values(**concurrency_limit.model_dump(exclude_unset=True)) ) @@ -141,7 +145,9 @@ async def update_concurrency_limit( return result.rowcount > 0 +@db_injector async def delete_concurrency_limit( + db: PrefectDBInterface, session: AsyncSession, concurrency_limit_id: Optional[UUID] = None, name: Optional[str] = None, @@ -150,16 +156,17 @@ async def delete_concurrency_limit( raise ValueError("Must provide either concurrency_limit_id or name") where = ( - orm_models.ConcurrencyLimitV2.id == concurrency_limit_id + db.ConcurrencyLimitV2.id == concurrency_limit_id if concurrency_limit_id - else orm_models.ConcurrencyLimitV2.name == name + else db.ConcurrencyLimitV2.name == name ) - query = sa.delete(orm_models.ConcurrencyLimitV2).where(where) + query = sa.delete(db.ConcurrencyLimitV2).where(where) result = await session.execute(query) return result.rowcount > 0 +@db_injector @deprecated_parameter( name="create_if_missing", start_date="Sep 2024", @@ -168,13 +175,14 @@ async def delete_concurrency_limit( help="Limits must be explicitly created before acquiring concurrency slots.", ) async def bulk_read_or_create_concurrency_limits( + db: PrefectDBInterface, session: AsyncSession, names: List[str], create_if_missing: Optional[bool] = None, ) -> List[orm_models.ConcurrencyLimitV2]: # Get all existing concurrency limits in `names`. - existing_query = sa.select(orm_models.ConcurrencyLimitV2).where( - orm_models.ConcurrencyLimitV2.name.in_(names) + existing_query = sa.select(db.ConcurrencyLimitV2).where( + db.ConcurrencyLimitV2.name.in_(names) ) existing_limits = list((await session.execute(existing_query)).scalars().all()) @@ -184,7 +192,7 @@ async def bulk_read_or_create_concurrency_limits( if missing_names and create_if_missing: new_limits = [ - orm_models.ConcurrencyLimitV2( + db.ConcurrencyLimitV2( **schemas.core.ConcurrencyLimitV2( name=name, limit=1, active=False ).model_dump() @@ -206,16 +214,16 @@ async def bulk_increment_active_slots( concurrency_limit_ids: List[UUID], slots: int, ) -> bool: - active_slots = active_slots_after_decay() - denied_slots = denied_slots_after_decay() + active_slots = active_slots_after_decay(db) + denied_slots = denied_slots_after_decay(db) query = ( - sa.update(orm_models.ConcurrencyLimitV2) + sa.update(db.ConcurrencyLimitV2) .where( sa.and_( - orm_models.ConcurrencyLimitV2.id.in_(concurrency_limit_ids), - orm_models.ConcurrencyLimitV2.active == True, # noqa - active_slots + slots <= orm_models.ConcurrencyLimitV2.limit, + db.ConcurrencyLimitV2.id.in_(concurrency_limit_ids), + db.ConcurrencyLimitV2.active == True, # noqa + active_slots + slots <= db.ConcurrencyLimitV2.limit, ) ) .values( @@ -237,19 +245,19 @@ async def bulk_decrement_active_slots( occupancy_seconds: Optional[float] = None, ) -> bool: query = ( - sa.update(orm_models.ConcurrencyLimitV2) + sa.update(db.ConcurrencyLimitV2) .where( sa.and_( - orm_models.ConcurrencyLimitV2.id.in_(concurrency_limit_ids), - orm_models.ConcurrencyLimitV2.active == True, # noqa + db.ConcurrencyLimitV2.id.in_(concurrency_limit_ids), + db.ConcurrencyLimitV2.active == True, # noqa ) ) .values( active_slots=sa.case( - (active_slots_after_decay() - slots < 0, 0), - else_=active_slots_after_decay() - slots, + (active_slots_after_decay(db) - slots < 0, 0), + else_=active_slots_after_decay(db) - slots, ), - denied_slots=denied_slots_after_decay(), + denied_slots=denied_slots_after_decay(db), ) ) @@ -261,14 +269,14 @@ async def bulk_decrement_active_slots( query = query.values( # Update the average occupancy seconds per slot as a weighted # average over the last `limit * OCCUPANCY_SAMPLE_MULTIPLIER` samples. - avg_slot_occupancy_seconds=orm_models.ConcurrencyLimitV2.avg_slot_occupancy_seconds + avg_slot_occupancy_seconds=db.ConcurrencyLimitV2.avg_slot_occupancy_seconds + ( occupancy_seconds_per_slot - / (orm_models.ConcurrencyLimitV2.limit * OCCUPANCY_SAMPLES_MULTIPLIER) + / (db.ConcurrencyLimitV2.limit * OCCUPANCY_SAMPLES_MULTIPLIER) ) - ( - orm_models.ConcurrencyLimitV2.avg_slot_occupancy_seconds - / (orm_models.ConcurrencyLimitV2.limit * OCCUPANCY_SAMPLES_MULTIPLIER) + db.ConcurrencyLimitV2.avg_slot_occupancy_seconds + / (db.ConcurrencyLimitV2.limit * OCCUPANCY_SAMPLES_MULTIPLIER) ), ) @@ -284,14 +292,14 @@ async def bulk_update_denied_slots( slots: int, ) -> bool: query = ( - sa.update(orm_models.ConcurrencyLimitV2) + sa.update(db.ConcurrencyLimitV2) .where( sa.and_( - orm_models.ConcurrencyLimitV2.id.in_(concurrency_limit_ids), - orm_models.ConcurrencyLimitV2.active == True, # noqa + db.ConcurrencyLimitV2.id.in_(concurrency_limit_ids), + db.ConcurrencyLimitV2.active == True, # noqa ) ) - .values(denied_slots=denied_slots_after_decay() + slots) + .values(denied_slots=denied_slots_after_decay(db) + slots) ) result = await session.execute(query) diff --git a/src/prefect/server/models/configuration.py b/src/prefect/server/models/configuration.py index e2b2704567f3a..681cbdcf39e8b 100644 --- a/src/prefect/server/models/configuration.py +++ b/src/prefect/server/models/configuration.py @@ -4,9 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models @db_injector @@ -16,9 +14,7 @@ async def write_configuration( configuration: schemas.core.Configuration, ) -> orm_models.Configuration: # first see if the key already exists - query = sa.select(orm_models.Configuration).where( - orm_models.Configuration.key == configuration.key - ) + query = sa.select(db.Configuration).where(db.Configuration.key == configuration.key) result = await session.execute(query) # type: ignore existing_configuration = result.scalar() # if it exists, update its value @@ -26,14 +22,14 @@ async def write_configuration( existing_configuration.value = configuration.value # else create a new ORM object else: - existing_configuration = orm_models.Configuration( + existing_configuration = db.Configuration( key=configuration.key, value=configuration.value ) session.add(existing_configuration) await session.flush() # clear the cache for this key after writing a value - db.clear_configuration_value_cache_for_key(key=configuration.key) + db.queries.clear_configuration_value_cache_for_key(key=configuration.key) return existing_configuration @@ -44,7 +40,7 @@ async def read_configuration( session: AsyncSession, key: str, ) -> Optional[schemas.core.Configuration]: - value = await db.read_configuration_value(session=session, key=key) + value = await db.queries.read_configuration_value(session=session, key=key) return ( schemas.core.Configuration(key=key, value=value) if value is not None else None ) diff --git a/src/prefect/server/models/csrf_token.py b/src/prefect/server/models/csrf_token.py index e22c1956694f8..c9d3ab50099fa 100644 --- a/src/prefect/server/models/csrf_token.py +++ b/src/prefect/server/models/csrf_token.py @@ -6,9 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect import settings -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.schemas import core @@ -36,14 +34,14 @@ async def create_or_update_csrf_token( token = secrets.token_hex(32) await session.execute( - db.insert(orm_models.CsrfToken) + db.queries.insert(db.CsrfToken) .values( client=client, token=token, expiration=expiration, ) .on_conflict_do_update( - index_elements=[orm_models.CsrfToken.client], + index_elements=[db.CsrfToken.client], set_={"token": token, "expiration": expiration}, ), ) @@ -55,7 +53,9 @@ async def create_or_update_csrf_token( return csrf_token +@db_injector async def read_token_for_client( + db: PrefectDBInterface, session: AsyncSession, client: str, ) -> Optional[core.CsrfToken]: @@ -71,10 +71,10 @@ async def read_token_for_client( """ token = ( await session.execute( - sa.select(orm_models.CsrfToken).where( + sa.select(db.CsrfToken).where( sa.and_( - orm_models.CsrfToken.expiration > datetime.now(timezone.utc), - orm_models.CsrfToken.client == client, + db.CsrfToken.expiration > datetime.now(timezone.utc), + db.CsrfToken.client == client, ) ) ) @@ -86,7 +86,8 @@ async def read_token_for_client( return core.CsrfToken.model_validate(token, from_attributes=True) -async def delete_expired_tokens(session: AsyncSession) -> int: +@db_injector +async def delete_expired_tokens(db: PrefectDBInterface, session: AsyncSession) -> int: """Delete expired CSRF tokens. Args: @@ -97,8 +98,8 @@ async def delete_expired_tokens(session: AsyncSession) -> int: """ result = await session.execute( - sa.delete(orm_models.CsrfToken).where( - orm_models.CsrfToken.expiration < datetime.now(timezone.utc) + sa.delete(db.CsrfToken).where( + db.CsrfToken.expiration < datetime.now(timezone.utc) ) ) return result.rowcount diff --git a/src/prefect/server/models/deployments.py b/src/prefect/server/models/deployments.py index 87ae7a1e7aff5..e93846ec84f73 100644 --- a/src/prefect/server/models/deployments.py +++ b/src/prefect/server/models/deployments.py @@ -4,7 +4,8 @@ """ import datetime -from typing import Any, Dict, Iterable, List, Optional, Sequence, TypeVar, cast +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast from uuid import UUID, uuid4 import pendulum @@ -15,9 +16,7 @@ from sqlalchemy.sql import Select from prefect.server import models, schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.events.clients import PrefectServerEventsClient from prefect.server.exceptions import ObjectNotFoundError from prefect.server.models.events import deployment_status_event @@ -29,10 +28,12 @@ PREFECT_API_SERVICES_SCHEDULER_MIN_SCHEDULED_TIME, ) -T = TypeVar("T", bound=tuple) +T = TypeVar("T", bound=tuple[Any, ...]) +@db_injector async def _delete_scheduled_runs( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, auto_scheduled_only: bool = False, @@ -46,15 +47,15 @@ async def _delete_scheduled_runs( deployment_id: the deployment for which we should delete runs. auto_scheduled_only: if True, only delete auto scheduled runs. Defaults to `False`. """ - delete_query = sa.delete(orm_models.FlowRun).where( - orm_models.FlowRun.deployment_id == deployment_id, - orm_models.FlowRun.state_type == schemas.states.StateType.SCHEDULED.value, - orm_models.FlowRun.run_count == 0, + delete_query = sa.delete(db.FlowRun).where( + db.FlowRun.deployment_id == deployment_id, + db.FlowRun.state_type == schemas.states.StateType.SCHEDULED.value, + db.FlowRun.run_count == 0, ) if auto_scheduled_only: delete_query = delete_query.where( - orm_models.FlowRun.auto_scheduled.is_(True), + db.FlowRun.auto_scheduled.is_(True), ) await session.execute(delete_query) @@ -112,10 +113,10 @@ async def create_deployment( conflict_update_fields["infra_overrides"] = job_variables insert_stmt = ( - db.insert(orm_models.Deployment) + db.queries.insert(db.Deployment) .values(**insert_values) .on_conflict_do_update( - index_elements=db.deployment_unique_upsert_columns, + index_elements=db.orm.deployment_unique_upsert_columns, set_={**conflict_update_fields}, ) ) @@ -124,10 +125,10 @@ async def create_deployment( # Get the id of the deployment we just created or updated result = await session.execute( - sa.select(orm_models.Deployment.id).where( + sa.select(db.Deployment.id).where( sa.and_( - orm_models.Deployment.flow_id == deployment.flow_id, - orm_models.Deployment.name == deployment.name, + db.Deployment.flow_id == deployment.flow_id, + db.Deployment.name == deployment.name, ) ) ) @@ -160,15 +161,15 @@ async def create_deployment( if requested_concurrency_limit != "unset": await _create_or_update_deployment_concurrency_limit( - session, deployment_id, deployment.concurrency_limit + db, session, deployment_id, deployment.concurrency_limit ) query = ( - sa.select(orm_models.Deployment) + sa.select(db.Deployment) .where( sa.and_( - orm_models.Deployment.flow_id == deployment.flow_id, - orm_models.Deployment.name == deployment.name, + db.Deployment.flow_id == deployment.flow_id, + db.Deployment.name == deployment.name, ) ) .execution_options(populate_existing=True) @@ -177,7 +178,9 @@ async def create_deployment( return refreshed_result.scalar() +@db_injector async def update_deployment( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, deployment: schemas.actions.DeploymentUpdate, @@ -244,8 +247,8 @@ async def update_deployment( update_data["work_queue_id"] = work_queue.id update_stmt = ( - sa.update(orm_models.Deployment) - .where(orm_models.Deployment.id == deployment_id) + sa.update(db.Deployment) + .where(db.Deployment.id == deployment_id) .values(**update_data) ) result = await session.execute(update_stmt) @@ -275,16 +278,19 @@ async def update_deployment( if requested_concurrency_limit_update != "unset": await _create_or_update_deployment_concurrency_limit( - session, deployment_id, deployment.concurrency_limit + db, session, deployment_id, deployment.concurrency_limit ) return result.rowcount > 0 async def _create_or_update_deployment_concurrency_limit( - session: AsyncSession, deployment_id: UUID, limit: Optional[int] + db: PrefectDBInterface, + session: AsyncSession, + deployment_id: UUID, + limit: Optional[int], ): - deployment = await session.get(orm_models.Deployment, deployment_id) + deployment = await session.get(db.Deployment, deployment_id) assert deployment is not None if ( @@ -296,21 +302,22 @@ async def _create_or_update_deployment_concurrency_limit( deployment._concurrency_limit = limit if limit is None: await _delete_related_concurrency_limit( - session=session, deployment_id=deployment_id + db, session=session, deployment_id=deployment_id ) await session.refresh(deployment) elif deployment.global_concurrency_limit: deployment.global_concurrency_limit.limit = limit else: limit_name = f"deployment:{deployment_id}" - new_limit = orm_models.ConcurrencyLimitV2(name=limit_name, limit=limit) + new_limit = db.ConcurrencyLimitV2(name=limit_name, limit=limit) deployment.global_concurrency_limit = new_limit session.add(deployment) +@db_injector async def read_deployment( - session: AsyncSession, deployment_id: UUID + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID ) -> Optional[orm_models.Deployment]: """Reads a deployment by id. @@ -322,11 +329,12 @@ async def read_deployment( orm_models.Deployment: the deployment """ - return await session.get(orm_models.Deployment, deployment_id) + return await session.get(db.Deployment, deployment_id) +@db_injector async def read_deployment_by_name( - session: AsyncSession, name: str, flow_name: str + db: PrefectDBInterface, session: AsyncSession, name: str, flow_name: str ) -> Optional[orm_models.Deployment]: """Reads a deployment by name. @@ -340,12 +348,12 @@ async def read_deployment_by_name( """ result = await session.execute( - select(orm_models.Deployment) - .join(orm_models.Flow, orm_models.Deployment.flow_id == orm_models.Flow.id) + select(db.Deployment) + .join(db.Flow, db.Deployment.flow_id == db.Flow.id) .where( sa.and_( - orm_models.Flow.name == flow_name, - orm_models.Deployment.name == name, + db.Flow.name == flow_name, + db.Deployment.name == name, ) ) .limit(1) @@ -354,6 +362,7 @@ async def read_deployment_by_name( async def _apply_deployment_filters( + db: PrefectDBInterface, query: Select[T], flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -370,16 +379,16 @@ async def _apply_deployment_filters( query = query.where(deployment_filter.as_sql_filter()) if flow_filter: - flow_exists_clause = select(orm_models.Deployment.id).where( - orm_models.Deployment.flow_id == orm_models.Flow.id, + flow_exists_clause = select(db.Deployment.id).where( + db.Deployment.flow_id == db.Flow.id, flow_filter.as_sql_filter(), ) query = query.where(flow_exists_clause.exists()) if flow_run_filter or task_run_filter: - flow_run_exists_clause = select(orm_models.FlowRun).where( - orm_models.Deployment.id == orm_models.FlowRun.deployment_id + flow_run_exists_clause = select(db.FlowRun).where( + db.Deployment.id == db.FlowRun.deployment_id ) if flow_run_filter: @@ -388,15 +397,15 @@ async def _apply_deployment_filters( ) if task_run_filter: flow_run_exists_clause = flow_run_exists_clause.join( - orm_models.TaskRun, - orm_models.TaskRun.flow_run_id == orm_models.FlowRun.id, + db.TaskRun, + db.TaskRun.flow_run_id == db.FlowRun.id, ).where(task_run_filter.as_sql_filter()) query = query.where(flow_run_exists_clause.exists()) if work_pool_filter or work_queue_filter: - work_pool_exists_clause = select(orm_models.WorkQueue).where( - orm_models.Deployment.work_queue_id == orm_models.WorkQueue.id + work_pool_exists_clause = select(db.WorkQueue).where( + db.Deployment.work_queue_id == db.WorkQueue.id ) if work_queue_filter: @@ -406,8 +415,8 @@ async def _apply_deployment_filters( if work_pool_filter: work_pool_exists_clause = work_pool_exists_clause.join( - orm_models.WorkPool, - orm_models.WorkPool.id == orm_models.WorkQueue.work_pool_id, + db.WorkPool, + db.WorkPool.id == db.WorkQueue.work_pool_id, ).where(work_pool_filter.as_sql_filter()) query = query.where(work_pool_exists_clause.exists()) @@ -415,7 +424,9 @@ async def _apply_deployment_filters( return query +@db_injector async def read_deployments( + db: PrefectDBInterface, session: AsyncSession, offset: Optional[int] = None, limit: Optional[int] = None, @@ -443,12 +454,13 @@ async def read_deployments( sort: the sort criteria for selected deployments. Defaults to `name` ASC. Returns: - List[orm_models.Deployment]: deployments + list[orm_models.Deployment]: deployments """ - query = select(orm_models.Deployment).order_by(sort.as_sql_sort()) + query = select(db.Deployment).order_by(*sort.as_sql_sort()) query = await _apply_deployment_filters( + db, query=query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -467,7 +479,9 @@ async def read_deployments( return result.scalars().unique().all() +@db_injector async def count_deployments( + db: PrefectDBInterface, session: AsyncSession, flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -492,9 +506,10 @@ async def count_deployments( int: the number of deployments matching filters """ - query = select(sa.func.count(sa.text("*"))).select_from(orm_models.Deployment) + query = select(sa.func.count(None)).select_from(db.Deployment) query = await _apply_deployment_filters( + db, query=query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -508,7 +523,10 @@ async def count_deployments( return result.scalar_one() -async def delete_deployment(session: AsyncSession, deployment_id: UUID) -> bool: +@db_injector +async def delete_deployment( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID +) -> bool: """ Delete a deployment by id. @@ -526,27 +544,31 @@ async def delete_deployment(session: AsyncSession, deployment_id: UUID) -> bool: ) await _delete_related_concurrency_limit( - session=session, deployment_id=deployment_id + db, session=session, deployment_id=deployment_id ) result = await session.execute( - delete(orm_models.Deployment).where(orm_models.Deployment.id == deployment_id) + delete(db.Deployment).where(db.Deployment.id == deployment_id) ) return result.rowcount > 0 -async def _delete_related_concurrency_limit(session: AsyncSession, deployment_id: UUID): +async def _delete_related_concurrency_limit( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID +): return await session.execute( - delete(orm_models.ConcurrencyLimitV2).where( - orm_models.ConcurrencyLimitV2.id - == sa.select(orm_models.Deployment.concurrency_limit_id) - .where(orm_models.Deployment.id == deployment_id) + delete(db.ConcurrencyLimitV2).where( + db.ConcurrencyLimitV2.id + == sa.select(db.Deployment.concurrency_limit_id) + .where(db.Deployment.id == deployment_id) .scalar_subquery() ) ) +@db_injector async def schedule_runs( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, start_time: Optional[datetime.datetime] = None, @@ -583,8 +605,10 @@ async def schedule_runs( """ if min_runs is None: min_runs = PREFECT_API_SERVICES_SCHEDULER_MIN_RUNS.value() + assert min_runs is not None if max_runs is None: max_runs = PREFECT_API_SERVICES_SCHEDULER_MAX_RUNS.value() + assert max_runs is not None if start_time is None: start_time = pendulum.now("UTC") if end_time is None: @@ -593,11 +617,15 @@ async def schedule_runs( ) if min_time is None: min_time = PREFECT_API_SERVICES_SCHEDULER_MIN_SCHEDULED_TIME.value() + assert min_time is not None actual_start_time = pendulum.instance(start_time) + if TYPE_CHECKING: + assert end_time is not None actual_end_time = pendulum.instance(end_time) runs = await _generate_scheduled_flow_runs( + db, session=session, deployment_id=deployment_id, start_time=actual_start_time, @@ -611,6 +639,7 @@ async def schedule_runs( async def _generate_scheduled_flow_runs( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, start_time: datetime.datetime, @@ -619,7 +648,7 @@ async def _generate_scheduled_flow_runs( min_runs: int, max_runs: int, auto_scheduled: bool = True, -) -> List[Dict]: +) -> list[dict[str, Any]]: """ Given a `deployment_id` and schedule, generates a list of flow run objects and associated scheduled states that represent scheduled flow runs. This method @@ -651,9 +680,9 @@ async def _generate_scheduled_flow_runs( Returns: a list of dictionary representations of the `FlowRun` objects to schedule """ - runs = [] + runs: list[dict[str, Any]] = [] - deployment = await session.get(orm_models.Deployment, deployment_id) + deployment = await session.get(db.Deployment, deployment_id) if not deployment: return [] @@ -667,7 +696,7 @@ async def _generate_scheduled_flow_runs( ) for deployment_schedule in active_deployment_schedules: - dates = [] + dates: list[pendulum.DateTime] = [] # generate up to `n` dates satisfying the min of `max_runs` and `end_time` for dt in deployment_schedule.schedule._get_dates_generator( @@ -713,7 +742,7 @@ async def _generate_scheduled_flow_runs( @db_injector async def _insert_scheduled_flow_runs( - db: PrefectDBInterface, session: AsyncSession, runs: List[Dict] + db: PrefectDBInterface, session: AsyncSession, runs: list[dict[str, Any]] ) -> Sequence[UUID]: """ Given a list of flow runs to schedule, as generated by `_generate_scheduled_flow_runs`, @@ -735,8 +764,8 @@ async def _insert_scheduled_flow_runs( # this syntax (insert statement, values to insert) is most efficient # because it uses a single bind parameter await session.execute( - db.insert(orm_models.FlowRun).on_conflict_do_nothing( - index_elements=db.flow_run_unique_upsert_columns + db.queries.insert(db.FlowRun).on_conflict_do_nothing( + index_elements=db.orm.flow_run_unique_upsert_columns ), runs, ) @@ -752,7 +781,7 @@ async def _insert_scheduled_flow_runs( inserted_flow_run_ids = (await session.execute(inserted_rows)).scalars().all() # insert flow run states that correspond to the newly-insert rows - insert_flow_run_states = [ + insert_flow_run_states: list[dict[str, Any]] = [ {"id": uuid4(), "flow_run_id": r["id"], **r["state"]} for r in runs if r["id"] in inserted_flow_run_ids @@ -761,12 +790,12 @@ async def _insert_scheduled_flow_runs( # this syntax (insert statement, values to insert) is most efficient # because it uses a single bind parameter await session.execute( - orm_models.FlowRunState.__table__.insert(), # type: ignore[attr-defined] + db.FlowRunState.__table__.insert(), # type: ignore[attr-defined] insert_flow_run_states, ) # set the `state_id` on the newly inserted runs - stmt = db.set_state_id_on_inserted_flow_runs_statement( + stmt = db.queries.set_state_id_on_inserted_flow_runs_statement( inserted_flow_run_ids=inserted_flow_run_ids, insert_flow_run_states=insert_flow_run_states, ) @@ -776,8 +805,9 @@ async def _insert_scheduled_flow_runs( return inserted_flow_run_ids +@db_injector async def check_work_queues_for_deployment( - session: AsyncSession, deployment_id: UUID + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID ) -> Sequence[orm_models.WorkQueue]: """ Get work queues that can pick up the specified deployment. @@ -801,7 +831,7 @@ async def check_work_queues_for_deployment( Returns: List[orm_models.WorkQueue]: WorkQueues """ - deployment = await session.get(orm_models.Deployment, deployment_id) + deployment = await session.get(db.Deployment, deployment_id) if not deployment: raise ObjectNotFoundError(f"Deployment with id {deployment_id} not found") @@ -809,24 +839,24 @@ def json_contains(a: Any, b: Any) -> sa.ColumnElement[bool]: return sa.type_coerce(a, type_=JSONB).contains(sa.type_coerce(b, type_=JSONB)) query = ( - select(orm_models.WorkQueue) + select(db.WorkQueue) # work queue tags are a subset of deployment tags .filter( or_( - json_contains(deployment.tags, orm_models.WorkQueue.filter["tags"]), - json_contains([], orm_models.WorkQueue.filter["tags"]), - json_contains(None, orm_models.WorkQueue.filter["tags"]), + json_contains(deployment.tags, db.WorkQueue.filter["tags"]), + json_contains([], db.WorkQueue.filter["tags"]), + json_contains(None, db.WorkQueue.filter["tags"]), ) ) # deployment_ids is null or contains the deployment's ID .filter( or_( json_contains( - orm_models.WorkQueue.filter["deployment_ids"], + db.WorkQueue.filter["deployment_ids"], str(deployment.id), ), - json_contains(None, orm_models.WorkQueue.filter["deployment_ids"]), - json_contains([], orm_models.WorkQueue.filter["deployment_ids"]), + json_contains(None, db.WorkQueue.filter["deployment_ids"]), + json_contains([], db.WorkQueue.filter["deployment_ids"]), ) ) ) @@ -835,11 +865,13 @@ def json_contains(a: Any, b: Any) -> sa.ColumnElement[bool]: return result.scalars().unique().all() +@db_injector async def create_deployment_schedules( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, - schedules: List[schemas.actions.DeploymentScheduleCreate], -) -> List[schemas.core.DeploymentSchedule]: + schedules: list[schemas.actions.DeploymentScheduleCreate], +) -> list[schemas.core.DeploymentSchedule]: """ Creates a deployment's schedules. @@ -849,15 +881,14 @@ async def create_deployment_schedules( schedules: a list of deployment schedule create actions """ - schedules_with_deployment_id = [] + schedules_with_deployment_id: list[dict[str, Any]] = [] for schedule in schedules: data = schedule.model_dump() data["deployment_id"] = deployment_id schedules_with_deployment_id.append(data) models = [ - orm_models.DeploymentSchedule(**schedule) - for schedule in schedules_with_deployment_id + db.DeploymentSchedule(**schedule) for schedule in schedules_with_deployment_id ] session.add_all(models) await session.flush() @@ -868,13 +899,15 @@ async def create_deployment_schedules( ] +@db_injector async def read_deployment_schedules( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, deployment_schedule_filter: Optional[ schemas.filters.DeploymentScheduleFilter ] = None, -) -> List[schemas.core.DeploymentSchedule]: +) -> list[schemas.core.DeploymentSchedule]: """ Reads a deployment's schedules. @@ -887,9 +920,9 @@ async def read_deployment_schedules( """ query = ( - sa.select(orm_models.DeploymentSchedule) - .where(orm_models.DeploymentSchedule.deployment_id == deployment_id) - .order_by(orm_models.DeploymentSchedule.updated.desc()) + sa.select(db.DeploymentSchedule) + .where(db.DeploymentSchedule.deployment_id == deployment_id) + .order_by(db.DeploymentSchedule.updated.desc()) ) if deployment_schedule_filter: @@ -903,7 +936,9 @@ async def read_deployment_schedules( ] +@db_injector async def update_deployment_schedule( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, deployment_schedule_id: UUID, @@ -919,11 +954,11 @@ async def update_deployment_schedule( """ result = await session.execute( - sa.update(orm_models.DeploymentSchedule) + sa.update(db.DeploymentSchedule) .where( sa.and_( - orm_models.DeploymentSchedule.id == deployment_schedule_id, - orm_models.DeploymentSchedule.deployment_id == deployment_id, + db.DeploymentSchedule.id == deployment_schedule_id, + db.DeploymentSchedule.deployment_id == deployment_id, ) ) .values(**schedule.model_dump(exclude_none=True)) @@ -932,8 +967,9 @@ async def update_deployment_schedule( return result.rowcount > 0 +@db_injector async def delete_schedules_for_deployment( - session: AsyncSession, deployment_id: UUID + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID ) -> bool: """ Deletes a deployment schedule. @@ -944,15 +980,17 @@ async def delete_schedules_for_deployment( """ result = await session.execute( - sa.delete(orm_models.DeploymentSchedule).where( - orm_models.DeploymentSchedule.deployment_id == deployment_id + sa.delete(db.DeploymentSchedule).where( + db.DeploymentSchedule.deployment_id == deployment_id ) ) return result.rowcount > 0 +@db_injector async def delete_deployment_schedule( + db: PrefectDBInterface, session: AsyncSession, deployment_id: UUID, deployment_schedule_id: UUID, @@ -966,10 +1004,10 @@ async def delete_deployment_schedule( """ result = await session.execute( - sa.delete(orm_models.DeploymentSchedule).where( + sa.delete(db.DeploymentSchedule).where( sa.and_( - orm_models.DeploymentSchedule.id == deployment_schedule_id, - orm_models.DeploymentSchedule.deployment_id == deployment_id, + db.DeploymentSchedule.id == deployment_schedule_id, + db.DeploymentSchedule.deployment_id == deployment_id, ) ) ) @@ -993,12 +1031,12 @@ async def mark_deployments_ready( begin_transaction=True, ) as session: result = await session.execute( - select(orm_models.Deployment.id).where( + select(db.Deployment.id).where( sa.or_( - orm_models.Deployment.id.in_(deployment_ids), - orm_models.Deployment.work_queue_id.in_(work_queue_ids), + db.Deployment.id.in_(deployment_ids), + db.Deployment.work_queue_id.in_(work_queue_ids), ), - orm_models.Deployment.status == DeploymentStatus.NOT_READY, + db.Deployment.status == DeploymentStatus.NOT_READY, ) ) unready_deployments = list(result.scalars().unique().all()) @@ -1006,11 +1044,11 @@ async def mark_deployments_ready( last_polled = pendulum.now("UTC") await session.execute( - sa.update(orm_models.Deployment) + sa.update(db.Deployment) .where( sa.or_( - orm_models.Deployment.id.in_(deployment_ids), - orm_models.Deployment.work_queue_id.in_(work_queue_ids), + db.Deployment.id.in_(deployment_ids), + db.Deployment.work_queue_id.in_(work_queue_ids), ) ) .values(status=DeploymentStatus.READY, last_polled=last_polled) @@ -1047,22 +1085,22 @@ async def mark_deployments_not_ready( begin_transaction=True, ) as session: result = await session.execute( - select(orm_models.Deployment.id).where( + select(db.Deployment.id).where( sa.or_( - orm_models.Deployment.id.in_(deployment_ids), - orm_models.Deployment.work_queue_id.in_(work_queue_ids), + db.Deployment.id.in_(deployment_ids), + db.Deployment.work_queue_id.in_(work_queue_ids), ), - orm_models.Deployment.status == DeploymentStatus.READY, + db.Deployment.status == DeploymentStatus.READY, ) ) ready_deployments = list(result.scalars().unique().all()) await session.execute( - sa.update(orm_models.Deployment) + sa.update(db.Deployment) .where( sa.or_( - orm_models.Deployment.id.in_(deployment_ids), - orm_models.Deployment.work_queue_id.in_(work_queue_ids), + db.Deployment.id.in_(deployment_ids), + db.Deployment.work_queue_id.in_(work_queue_ids), ) ) .values(status=DeploymentStatus.NOT_READY) diff --git a/src/prefect/server/models/flow_run_input.py b/src/prefect/server/models/flow_run_input.py index 75c37f2642c04..c872664200743 100644 --- a/src/prefect/server/models/flow_run_input.py +++ b/src/prefect/server/models/flow_run_input.py @@ -5,21 +5,25 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.schemas as schemas -from prefect.server.database import orm_models +from prefect.server.database import PrefectDBInterface, db_injector +@db_injector async def create_flow_run_input( + db: PrefectDBInterface, session: AsyncSession, flow_run_input: schemas.core.FlowRunInput, ) -> schemas.core.FlowRunInput: - model = orm_models.FlowRunInput(**flow_run_input.model_dump()) + model = db.FlowRunInput(**flow_run_input.model_dump()) session.add(model) await session.flush() return schemas.core.FlowRunInput.model_validate(model, from_attributes=True) +@db_injector async def filter_flow_run_input( + db: PrefectDBInterface, session: AsyncSession, flow_run_id: uuid.UUID, prefix: str, @@ -27,15 +31,15 @@ async def filter_flow_run_input( exclude_keys: List[str], ) -> List[schemas.core.FlowRunInput]: query = ( - sa.select(orm_models.FlowRunInput) + sa.select(db.FlowRunInput) .where( sa.and_( - orm_models.FlowRunInput.flow_run_id == flow_run_id, - orm_models.FlowRunInput.key.like(prefix + "%"), - orm_models.FlowRunInput.key.not_in(exclude_keys), + db.FlowRunInput.flow_run_id == flow_run_id, + db.FlowRunInput.key.like(prefix + "%"), + db.FlowRunInput.key.not_in(exclude_keys), ) ) - .order_by(orm_models.FlowRunInput.created) + .order_by(db.FlowRunInput.created) .limit(limit) ) @@ -46,15 +50,17 @@ async def filter_flow_run_input( ] +@db_injector async def read_flow_run_input( + db: PrefectDBInterface, session: AsyncSession, flow_run_id: uuid.UUID, key: str, ) -> Optional[schemas.core.FlowRunInput]: - query = sa.select(orm_models.FlowRunInput).where( + query = sa.select(db.FlowRunInput).where( sa.and_( - orm_models.FlowRunInput.flow_run_id == flow_run_id, - orm_models.FlowRunInput.key == key, + db.FlowRunInput.flow_run_id == flow_run_id, + db.FlowRunInput.key == key, ) ) @@ -66,16 +72,18 @@ async def read_flow_run_input( return None +@db_injector async def delete_flow_run_input( + db: PrefectDBInterface, session: AsyncSession, flow_run_id: uuid.UUID, key: str, ) -> bool: result = await session.execute( - sa.delete(orm_models.FlowRunInput).where( + sa.delete(db.FlowRunInput).where( sa.and_( - orm_models.FlowRunInput.flow_run_id == flow_run_id, - orm_models.FlowRunInput.key == key, + db.FlowRunInput.flow_run_id == flow_run_id, + db.FlowRunInput.key == key, ) ) ) diff --git a/src/prefect/server/models/flow_run_notification_policies.py b/src/prefect/server/models/flow_run_notification_policies.py index d5f62e05e4342..352abe1e090c3 100644 --- a/src/prefect/server/models/flow_run_notification_policies.py +++ b/src/prefect/server/models/flow_run_notification_policies.py @@ -12,9 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.schemas as schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models DEFAULT_MESSAGE_TEMPLATE = textwrap.dedent( """ @@ -28,7 +26,9 @@ ) +@db_injector async def create_flow_run_notification_policy( + db: PrefectDBInterface, session: AsyncSession, flow_run_notification_policy: schemas.core.FlowRunNotificationPolicy, ) -> orm_models.FlowRunNotificationPolicy: @@ -43,16 +43,16 @@ async def create_flow_run_notification_policy( orm_models.FlowRunNotificationPolicy: the newly-created FlowRunNotificationPolicy """ - model = orm_models.FlowRunNotificationPolicy( - **flow_run_notification_policy.model_dump() - ) + model = db.FlowRunNotificationPolicy(**flow_run_notification_policy.model_dump()) session.add(model) await session.flush() return model +@db_injector async def read_flow_run_notification_policy( + db: PrefectDBInterface, session: AsyncSession, flow_run_notification_policy_id: UUID, ) -> Union[orm_models.FlowRunNotificationPolicy, None]: @@ -68,11 +68,13 @@ async def read_flow_run_notification_policy( """ return await session.get( - orm_models.FlowRunNotificationPolicy, flow_run_notification_policy_id + db.FlowRunNotificationPolicy, flow_run_notification_policy_id ) +@db_injector async def read_flow_run_notification_policies( + db: PrefectDBInterface, session: AsyncSession, flow_run_notification_policy_filter: Optional[ schemas.filters.FlowRunNotificationPolicyFilter @@ -92,8 +94,8 @@ async def read_flow_run_notification_policies( List[db.FlowRunNotificationPolicy]: Notification policies """ - query = select(orm_models.FlowRunNotificationPolicy).order_by( - orm_models.FlowRunNotificationPolicy.id + query = select(db.FlowRunNotificationPolicy).order_by( + db.FlowRunNotificationPolicy.id ) if flow_run_notification_policy_filter: @@ -108,7 +110,9 @@ async def read_flow_run_notification_policies( return result.scalars().unique().all() +@db_injector async def update_flow_run_notification_policy( + db: PrefectDBInterface, session: AsyncSession, flow_run_notification_policy_id: UUID, flow_run_notification_policy: schemas.actions.FlowRunNotificationPolicyUpdate, @@ -129,17 +133,17 @@ async def update_flow_run_notification_policy( update_data = flow_run_notification_policy.model_dump_for_orm(exclude_unset=True) update_stmt = ( - sa.update(orm_models.FlowRunNotificationPolicy) - .where( - orm_models.FlowRunNotificationPolicy.id == flow_run_notification_policy_id - ) + sa.update(db.FlowRunNotificationPolicy) + .where(db.FlowRunNotificationPolicy.id == flow_run_notification_policy_id) .values(**update_data) ) result = await session.execute(update_stmt) return result.rowcount > 0 +@db_injector async def delete_flow_run_notification_policy( + db: PrefectDBInterface, session: AsyncSession, flow_run_notification_policy_id: UUID, ) -> bool: @@ -155,8 +159,8 @@ async def delete_flow_run_notification_policy( """ result = await session.execute( - delete(orm_models.FlowRunNotificationPolicy).where( - orm_models.FlowRunNotificationPolicy.id == flow_run_notification_policy_id + delete(db.FlowRunNotificationPolicy).where( + db.FlowRunNotificationPolicy.id == flow_run_notification_policy_id ) ) return result.rowcount > 0 @@ -168,6 +172,4 @@ async def queue_flow_run_notifications( session: AsyncSession, flow_run: Union[schemas.core.FlowRun, orm_models.FlowRun], ) -> None: - await db.queries.queue_flow_run_notifications( - session=session, flow_run=flow_run, db=db - ) + await db.queries.queue_flow_run_notifications(session=session, flow_run=flow_run) diff --git a/src/prefect/server/models/flow_run_states.py b/src/prefect/server/models/flow_run_states.py index 9ec181609f11f..433f480cec242 100644 --- a/src/prefect/server/models/flow_run_states.py +++ b/src/prefect/server/models/flow_run_states.py @@ -10,10 +10,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server.database import orm_models +from prefect.server.database.dependencies import db_injector +from prefect.server.database.interface import PrefectDBInterface +@db_injector async def read_flow_run_state( - session: AsyncSession, flow_run_state_id: UUID + db: PrefectDBInterface, session: AsyncSession, flow_run_state_id: UUID ) -> Union[orm_models.FlowRunState, None]: """ Reads a flow run state by id. @@ -26,11 +29,12 @@ async def read_flow_run_state( orm_models.FlowRunState: the flow state """ - return await session.get(orm_models.FlowRunState, flow_run_state_id) + return await session.get(db.FlowRunState, flow_run_state_id) +@db_injector async def read_flow_run_states( - session: AsyncSession, flow_run_id: UUID + db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID ) -> Sequence[orm_models.FlowRunState]: """ Reads flow runs states for a flow run. @@ -44,15 +48,17 @@ async def read_flow_run_states( """ query = ( - select(orm_models.FlowRunState) + select(db.FlowRunState) .filter_by(flow_run_id=flow_run_id) - .order_by(orm_models.FlowRunState.timestamp) + .order_by(db.FlowRunState.timestamp) ) result = await session.execute(query) return result.scalars().unique().all() +@db_injector async def delete_flow_run_state( + db: PrefectDBInterface, session: AsyncSession, flow_run_state_id: UUID, ) -> bool: @@ -68,8 +74,6 @@ async def delete_flow_run_state( """ result = await session.execute( - delete(orm_models.FlowRunState).where( - orm_models.FlowRunState.id == flow_run_state_id - ) + delete(db.FlowRunState).where(db.FlowRunState.id == flow_run_state_id) ) return result.rowcount > 0 diff --git a/src/prefect/server/models/flow_runs.py b/src/prefect/server/models/flow_runs.py index a454fff8f1d47..87d6a92b136bb 100644 --- a/src/prefect/server/models/flow_runs.py +++ b/src/prefect/server/models/flow_runs.py @@ -30,9 +30,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas from prefect.logging.loggers import get_logger -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.exceptions import ObjectNotFoundError from prefect.server.orchestration.core_policy import MinimalFlowPolicy from prefect.server.orchestration.global_policy import GlobalFlowPolicy @@ -63,7 +61,7 @@ async def create_flow_run( db: PrefectDBInterface, session: AsyncSession, flow_run: schemas.core.FlowRun, - orchestration_parameters: Optional[dict] = None, + orchestration_parameters: Optional[dict[str, Any]] = None, ) -> orm_models.FlowRun: """Creates a new flow run. @@ -98,36 +96,34 @@ async def create_flow_run( # if no idempotency key was provided, create the run directly if not flow_run.idempotency_key: - model = orm_models.FlowRun(**flow_run_dict) + model = db.FlowRun(**flow_run_dict) session.add(model) await session.flush() # otherwise let the database take care of enforcing idempotency else: insert_stmt = ( - db.insert(orm_models.FlowRun) + db.queries.insert(db.FlowRun) .values(**flow_run_dict) .on_conflict_do_nothing( - index_elements=db.flow_run_unique_upsert_columns, + index_elements=db.orm.flow_run_unique_upsert_columns, ) ) await session.execute(insert_stmt) # read the run to see if idempotency was applied or not query = ( - sa.select(orm_models.FlowRun) + sa.select(db.FlowRun) .where( sa.and_( - orm_models.FlowRun.flow_id == flow_run.flow_id, - orm_models.FlowRun.idempotency_key == flow_run.idempotency_key, + db.FlowRun.flow_id == flow_run.flow_id, + db.FlowRun.idempotency_key == flow_run.idempotency_key, ) ) .limit(1) .execution_options(populate_existing=True) .options( - selectinload(orm_models.FlowRun.work_queue).selectinload( - orm_models.WorkQueue.work_pool - ) + selectinload(db.FlowRun.work_queue).selectinload(db.WorkQueue.work_pool) ) ) result = await session.execute(query) @@ -146,7 +142,9 @@ async def create_flow_run( return model +@db_injector async def update_flow_run( + db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID, flow_run: schemas.actions.FlowRunUpdate, @@ -163,8 +161,8 @@ async def update_flow_run( bool: whether or not matching rows were found to update """ update_stmt = ( - sa.update(orm_models.FlowRun) - .where(orm_models.FlowRun.id == flow_run_id) + sa.update(db.FlowRun) + .where(db.FlowRun.id == flow_run_id) # exclude_unset=True allows us to only update values provided by # the user, ignoring any defaults on the model .values(**flow_run.model_dump_for_orm(exclude_unset=True)) @@ -173,7 +171,9 @@ async def update_flow_run( return result.rowcount > 0 +@db_injector async def read_flow_run( + db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID, for_update: bool = False, @@ -189,12 +189,10 @@ async def read_flow_run( orm_models.FlowRun: the flow run """ select = ( - sa.select(orm_models.FlowRun) - .where(orm_models.FlowRun.id == flow_run_id) + sa.select(db.FlowRun) + .where(db.FlowRun.id == flow_run_id) .options( - selectinload(orm_models.FlowRun.work_queue).selectinload( - orm_models.WorkQueue.work_pool - ) + selectinload(db.FlowRun.work_queue).selectinload(db.WorkQueue.work_pool) ) ) @@ -206,6 +204,7 @@ async def read_flow_run( async def _apply_flow_run_filters( + db: PrefectDBInterface, query: Select[T], flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -222,52 +221,52 @@ async def _apply_flow_run_filters( query = query.where(flow_run_filter.as_sql_filter()) if deployment_filter: - deployment_exists_clause = select(orm_models.Deployment).where( - orm_models.Deployment.id == orm_models.FlowRun.deployment_id, + deployment_exists_clause = select(db.Deployment).where( + db.Deployment.id == db.FlowRun.deployment_id, deployment_filter.as_sql_filter(), ) query = query.where(deployment_exists_clause.exists()) if work_pool_filter: - work_pool_exists_clause = select(orm_models.WorkPool).where( - orm_models.WorkQueue.id == orm_models.FlowRun.work_queue_id, - orm_models.WorkPool.id == orm_models.WorkQueue.work_pool_id, + work_pool_exists_clause = select(db.WorkPool).where( + db.WorkQueue.id == db.FlowRun.work_queue_id, + db.WorkPool.id == db.WorkQueue.work_pool_id, work_pool_filter.as_sql_filter(), ) query = query.where(work_pool_exists_clause.exists()) if work_queue_filter: - work_queue_exists_clause = select(orm_models.WorkQueue).where( - orm_models.WorkQueue.id == orm_models.FlowRun.work_queue_id, + work_queue_exists_clause = select(db.WorkQueue).where( + db.WorkQueue.id == db.FlowRun.work_queue_id, work_queue_filter.as_sql_filter(), ) query = query.where(work_queue_exists_clause.exists()) if flow_filter or task_run_filter: flow_or_task_run_exists_clause: Union[ - Select[Tuple[orm_models.Flow]], - Select[Tuple[orm_models.TaskRun]], + Select[Tuple[db.Flow]], + Select[Tuple[db.TaskRun]], ] if flow_filter: - flow_or_task_run_exists_clause = select(orm_models.Flow).where( - orm_models.Flow.id == orm_models.FlowRun.flow_id, + flow_or_task_run_exists_clause = select(db.Flow).where( + db.Flow.id == db.FlowRun.flow_id, flow_filter.as_sql_filter(), ) if task_run_filter: if not flow_filter: - flow_or_task_run_exists_clause = select(orm_models.TaskRun).where( - orm_models.TaskRun.flow_run_id == orm_models.FlowRun.id + flow_or_task_run_exists_clause = select(db.TaskRun).where( + db.TaskRun.flow_run_id == db.FlowRun.id ) else: flow_or_task_run_exists_clause = flow_or_task_run_exists_clause.join( - orm_models.TaskRun, - orm_models.TaskRun.flow_run_id == orm_models.FlowRun.id, + db.TaskRun, + db.TaskRun.flow_run_id == db.FlowRun.id, ) flow_or_task_run_exists_clause = flow_or_task_run_exists_clause.where( - orm_models.FlowRun.id == orm_models.TaskRun.flow_run_id, + db.FlowRun.id == db.TaskRun.flow_run_id, task_run_filter.as_sql_filter(), ) @@ -276,7 +275,9 @@ async def _apply_flow_run_filters( return query +@db_injector async def read_flow_runs( + db: PrefectDBInterface, session: AsyncSession, columns: Optional[List] = None, flow_filter: Optional[schemas.filters.FlowFilter] = None, @@ -307,12 +308,10 @@ async def read_flow_runs( List[orm_models.FlowRun]: flow runs """ query = ( - select(orm_models.FlowRun) - .order_by(sort.as_sql_sort()) + select(db.FlowRun) + .order_by(*sort.as_sql_sort()) .options( - selectinload(orm_models.FlowRun.work_queue).selectinload( - orm_models.WorkQueue.work_pool - ) + selectinload(db.FlowRun.work_queue).selectinload(db.WorkQueue.work_pool) ) ) @@ -320,6 +319,7 @@ async def read_flow_runs( query = query.options(load_only(*columns)) query = await _apply_flow_run_filters( + db, query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -428,7 +428,9 @@ async def read_task_run_dependencies( return dependency_graph +@db_injector async def count_flow_runs( + db: PrefectDBInterface, session: AsyncSession, flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -451,9 +453,10 @@ async def count_flow_runs( int: count of flow runs """ - query = select(sa.func.count(sa.text("*"))).select_from(orm_models.FlowRun) + query = select(sa.func.count(None)).select_from(db.FlowRun) query = await _apply_flow_run_filters( + db, query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -467,7 +470,10 @@ async def count_flow_runs( return result.scalar_one() -async def delete_flow_run(session: AsyncSession, flow_run_id: UUID) -> bool: +@db_injector +async def delete_flow_run( + db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID +) -> bool: """ Delete a flow run by flow_run_id, handling concurrency limits if applicable. @@ -489,7 +495,7 @@ async def delete_flow_run(session: AsyncSession, flow_run_id: UUID) -> bool: # Delete the flow run result = await session.execute( - delete(orm_models.FlowRun).where(orm_models.FlowRun.id == flow_run_id) + delete(db.FlowRun).where(db.FlowRun.id == flow_run_id) ) return result.rowcount > 0 @@ -594,7 +600,7 @@ async def read_flow_run_graph( db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID, - since: datetime.datetime = datetime.datetime.min, + since: pendulum.DateTime = pendulum.DateTime.min, ) -> Graph: """Given a flow run, return the graph of it's task and subflow runs. If a `since` datetime is provided, only return items that may have changed since that time.""" @@ -643,7 +649,9 @@ async def with_system_labels_for_flow_run( return parent_labels | default_labels | user_supplied_labels +@db_injector async def update_flow_run_labels( + db: PrefectDBInterface, session: AsyncSession, flow_run_id: UUID, labels: KeyValueLabels, @@ -669,8 +677,8 @@ async def update_flow_run_labels( try: # Update the flow run with merged labels result = await session.execute( - sa.update(orm_models.FlowRun) - .where(orm_models.FlowRun.id == flow_run_id) + sa.update(db.FlowRun) + .where(db.FlowRun.id == flow_run_id) .values(labels=updated_labels) ) success = result.rowcount > 0 diff --git a/src/prefect/server/models/flows.py b/src/prefect/server/models/flows.py index ba3243468b01d..32369cfc897a5 100644 --- a/src/prefect/server/models/flows.py +++ b/src/prefect/server/models/flows.py @@ -12,9 +12,7 @@ from sqlalchemy.sql import Select import prefect.server.schemas as schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models T = TypeVar("T", bound=tuple) @@ -37,19 +35,17 @@ async def create_flow( """ insert_stmt = ( - db.insert(orm_models.Flow) + db.queries.insert(db.Flow) .values(**flow.model_dump_for_orm(exclude_unset=True)) .on_conflict_do_nothing( - index_elements=db.flow_unique_upsert_columns, + index_elements=db.orm.flow_unique_upsert_columns, ) ) await session.execute(insert_stmt) query = ( - sa.select(orm_models.Flow) - .where( - orm_models.Flow.name == flow.name, - ) + sa.select(db.Flow) + .where(db.Flow.name == flow.name) .limit(1) .execution_options(populate_existing=True) ) @@ -58,7 +54,9 @@ async def create_flow( return model +@db_injector async def update_flow( + db: PrefectDBInterface, session: AsyncSession, flow_id: UUID, flow: schemas.actions.FlowUpdate, @@ -75,8 +73,8 @@ async def update_flow( bool: whether or not matching rows were found to update """ update_stmt = ( - sa.update(orm_models.Flow) - .where(orm_models.Flow.id == flow_id) + sa.update(db.Flow) + .where(db.Flow.id == flow_id) # exclude_unset=True allows us to only update values provided by # the user, ignoring any defaults on the model .values(**flow.model_dump_for_orm(exclude_unset=True)) @@ -85,7 +83,10 @@ async def update_flow( return result.rowcount > 0 -async def read_flow(session: AsyncSession, flow_id: UUID) -> Optional[orm_models.Flow]: +@db_injector +async def read_flow( + db: PrefectDBInterface, session: AsyncSession, flow_id: UUID +) -> Optional[orm_models.Flow]: """ Reads a flow by id. @@ -96,11 +97,12 @@ async def read_flow(session: AsyncSession, flow_id: UUID) -> Optional[orm_models Returns: orm_models.Flow: the flow """ - return await session.get(orm_models.Flow, flow_id) + return await session.get(db.Flow, flow_id) +@db_injector async def read_flow_by_name( - session: AsyncSession, name: str + db: PrefectDBInterface, session: AsyncSession, name: str ) -> Optional[orm_models.Flow]: """ Reads a flow by name. @@ -113,11 +115,10 @@ async def read_flow_by_name( orm_models.Flow: the flow """ - result = await session.execute(select(orm_models.Flow).filter_by(name=name)) + result = await session.execute(select(db.Flow).filter_by(name=name)) return result.scalar() -@db_injector async def _apply_flow_filters( db: PrefectDBInterface, query: Select[T], @@ -135,8 +136,8 @@ async def _apply_flow_filters( query = query.where(flow_filter.as_sql_filter()) if deployment_filter or work_pool_filter: - deployment_exists_clause = select(orm_models.Deployment).where( - orm_models.Deployment.flow_id == orm_models.Flow.id + deployment_exists_clause = select(db.Deployment).where( + db.Deployment.flow_id == db.Flow.id ) if deployment_filter: @@ -146,19 +147,19 @@ async def _apply_flow_filters( if work_pool_filter: deployment_exists_clause = deployment_exists_clause.join( - orm_models.WorkQueue, - orm_models.WorkQueue.id == orm_models.Deployment.work_queue_id, + db.WorkQueue, + db.WorkQueue.id == db.Deployment.work_queue_id, ) deployment_exists_clause = deployment_exists_clause.join( - orm_models.WorkPool, - orm_models.WorkPool.id == orm_models.WorkQueue.work_pool_id, + db.WorkPool, + db.WorkPool.id == db.WorkQueue.work_pool_id, ).where(work_pool_filter.as_sql_filter()) query = query.where(deployment_exists_clause.exists()) if flow_run_filter or task_run_filter: - flow_run_exists_clause = select(orm_models.FlowRun).where( - orm_models.FlowRun.flow_id == orm_models.Flow.id + flow_run_exists_clause = select(db.FlowRun).where( + db.FlowRun.flow_id == db.Flow.id ) if flow_run_filter: @@ -168,8 +169,8 @@ async def _apply_flow_filters( if task_run_filter: flow_run_exists_clause = flow_run_exists_clause.join( - orm_models.TaskRun, - orm_models.TaskRun.flow_run_id == orm_models.FlowRun.id, + db.TaskRun, + db.TaskRun.flow_run_id == db.FlowRun.id, ).where(task_run_filter.as_sql_filter()) query = query.where(flow_run_exists_clause.exists()) @@ -177,7 +178,9 @@ async def _apply_flow_filters( return query +@db_injector async def read_flows( + db: PrefectDBInterface, session: AsyncSession, flow_filter: Union[schemas.filters.FlowFilter, None] = None, flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None, @@ -205,9 +208,10 @@ async def read_flows( List[orm_models.Flow]: flows """ - query = select(orm_models.Flow).order_by(sort.as_sql_sort()) + query = select(db.Flow).order_by(*sort.as_sql_sort()) query = await _apply_flow_filters( + db, query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -226,7 +230,9 @@ async def read_flows( return result.scalars().unique().all() +@db_injector async def count_flows( + db: PrefectDBInterface, session: AsyncSession, flow_filter: Union[schemas.filters.FlowFilter, None] = None, flow_run_filter: Union[schemas.filters.FlowRunFilter, None] = None, @@ -249,9 +255,10 @@ async def count_flows( int: count of flows """ - query = select(sa.func.count(sa.text("*"))).select_from(orm_models.Flow) + query = select(sa.func.count(None)).select_from(db.Flow) query = await _apply_flow_filters( + db, query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -264,7 +271,10 @@ async def count_flows( return result.scalar_one() -async def delete_flow(session: AsyncSession, flow_id: UUID) -> bool: +@db_injector +async def delete_flow( + db: PrefectDBInterface, session: AsyncSession, flow_id: UUID +) -> bool: """ Delete a flow by id. @@ -276,18 +286,16 @@ async def delete_flow(session: AsyncSession, flow_id: UUID) -> bool: bool: whether or not the flow was deleted """ - result = await session.execute( - delete(orm_models.Flow).where(orm_models.Flow.id == flow_id) - ) + result = await session.execute(delete(db.Flow).where(db.Flow.id == flow_id)) return result.rowcount > 0 +@db_injector async def read_flow_labels( + db: PrefectDBInterface, session: AsyncSession, flow_id: UUID, ) -> Union[schemas.core.KeyValueLabels, None]: - result = await session.execute( - select(orm_models.Flow.labels).where(orm_models.Flow.id == flow_id) - ) + result = await session.execute(select(db.Flow.labels).where(db.Flow.id == flow_id)) return result.scalar() diff --git a/src/prefect/server/models/logs.py b/src/prefect/server/models/logs.py index ab407617fd59d..ce5eac18eace5 100644 --- a/src/prefect/server/models/logs.py +++ b/src/prefect/server/models/logs.py @@ -10,9 +10,7 @@ import prefect.server.schemas as schemas from prefect.logging import get_logger -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.schemas.actions import LogCreate from prefect.utilities.collections import batched_iterable @@ -51,7 +49,7 @@ async def create_logs( """ try: await session.execute( - db.insert(orm_models.Log).values([log.model_dump() for log in logs]) + db.queries.insert(db.Log).values([log.model_dump() for log in logs]) ) except RuntimeError as exc: if "can't create new thread at interpreter shutdown" in str(exc): @@ -63,7 +61,9 @@ async def create_logs( raise +@db_injector async def read_logs( + db: PrefectDBInterface, session: AsyncSession, log_filter: schemas.filters.LogFilter, offset: Optional[int] = None, @@ -84,9 +84,7 @@ async def read_logs( Returns: List[orm_models.Log]: the matching logs """ - query = ( - select(orm_models.Log).order_by(sort.as_sql_sort()).offset(offset).limit(limit) - ) + query = select(db.Log).order_by(*sort.as_sql_sort()).offset(offset).limit(limit) if log_filter: query = query.where(log_filter.as_sql_filter()) diff --git a/src/prefect/server/models/saved_searches.py b/src/prefect/server/models/saved_searches.py index 31a20647e88ea..35d8f7394a575 100644 --- a/src/prefect/server/models/saved_searches.py +++ b/src/prefect/server/models/saved_searches.py @@ -11,9 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.schemas as schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models @db_injector @@ -36,10 +34,10 @@ async def create_saved_search( """ insert_stmt = ( - db.insert(orm_models.SavedSearch) + db.queries.insert(db.SavedSearch) .values(**saved_search.model_dump_for_orm(exclude_unset=True)) .on_conflict_do_update( - index_elements=db.saved_search_unique_upsert_columns, + index_elements=db.orm.saved_search_unique_upsert_columns, set_=saved_search.model_dump_for_orm(include={"filters"}), ) ) @@ -47,9 +45,9 @@ async def create_saved_search( await session.execute(insert_stmt) query = ( - sa.select(orm_models.SavedSearch) + sa.select(db.SavedSearch) .where( - orm_models.SavedSearch.name == saved_search.name, + db.SavedSearch.name == saved_search.name, ) .execution_options(populate_existing=True) ) @@ -59,8 +57,9 @@ async def create_saved_search( return model +@db_injector async def read_saved_search( - session: AsyncSession, saved_search_id: UUID + db: PrefectDBInterface, session: AsyncSession, saved_search_id: UUID ) -> Union[orm_models.SavedSearch, None]: """ Reads a SavedSearch by id. @@ -73,11 +72,12 @@ async def read_saved_search( orm_models.SavedSearch: the SavedSearch """ - return await session.get(orm_models.SavedSearch, saved_search_id) + return await session.get(db.SavedSearch, saved_search_id) +@db_injector async def read_saved_search_by_name( - session: AsyncSession, name: str + db: PrefectDBInterface, session: AsyncSession, name: str ) -> Union[orm_models.SavedSearch, None]: """ Reads a SavedSearch by name. @@ -90,14 +90,14 @@ async def read_saved_search_by_name( orm_models.SavedSearch: the SavedSearch """ result = await session.execute( - select(orm_models.SavedSearch) - .where(orm_models.SavedSearch.name == name) - .limit(1) + select(db.SavedSearch).where(db.SavedSearch.name == name).limit(1) ) return result.scalar() +@db_injector async def read_saved_searches( + db: PrefectDBInterface, session: AsyncSession, offset: Optional[int] = None, limit: Optional[int] = None, @@ -114,7 +114,7 @@ async def read_saved_searches( List[orm_models.SavedSearch]: SavedSearches """ - query = select(orm_models.SavedSearch).order_by(orm_models.SavedSearch.name) + query = select(db.SavedSearch).order_by(db.SavedSearch.name) if offset is not None: query = query.offset(offset) @@ -125,7 +125,10 @@ async def read_saved_searches( return result.scalars().unique().all() -async def delete_saved_search(session: AsyncSession, saved_search_id: UUID) -> bool: +@db_injector +async def delete_saved_search( + db: PrefectDBInterface, session: AsyncSession, saved_search_id: UUID +) -> bool: """ Delete a SavedSearch by id. @@ -138,8 +141,6 @@ async def delete_saved_search(session: AsyncSession, saved_search_id: UUID) -> b """ result = await session.execute( - delete(orm_models.SavedSearch).where( - orm_models.SavedSearch.id == saved_search_id - ) + delete(db.SavedSearch).where(db.SavedSearch.id == saved_search_id) ) return result.rowcount > 0 diff --git a/src/prefect/server/models/task_run_states.py b/src/prefect/server/models/task_run_states.py index ef55826db21c1..68c81e3980c43 100644 --- a/src/prefect/server/models/task_run_states.py +++ b/src/prefect/server/models/task_run_states.py @@ -9,11 +9,12 @@ from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database import orm_models +from prefect.server.database import PrefectDBInterface, db_injector, orm_models +@db_injector async def read_task_run_state( - session: AsyncSession, task_run_state_id: UUID + db: PrefectDBInterface, session: AsyncSession, task_run_state_id: UUID ) -> Union[orm_models.TaskRunState, None]: """ Reads a task run state by id. @@ -26,11 +27,12 @@ async def read_task_run_state( orm_models.TaskRunState: the task state """ - return await session.get(orm_models.TaskRunState, task_run_state_id) + return await session.get(db.TaskRunState, task_run_state_id) +@db_injector async def read_task_run_states( - session: AsyncSession, task_run_id: UUID + db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID ) -> Sequence[orm_models.TaskRunState]: """ Reads task runs states for a task run. @@ -44,15 +46,17 @@ async def read_task_run_states( """ query = ( - select(orm_models.TaskRunState) + select(db.TaskRunState) .filter_by(task_run_id=task_run_id) - .order_by(orm_models.TaskRunState.timestamp) + .order_by(db.TaskRunState.timestamp) ) result = await session.execute(query) return result.scalars().unique().all() -async def delete_task_run_state(session: AsyncSession, task_run_state_id: UUID) -> bool: +async def delete_task_run_state( + db: PrefectDBInterface, session: AsyncSession, task_run_state_id: UUID +) -> bool: """ Delete a task run state by id. @@ -65,8 +69,6 @@ async def delete_task_run_state(session: AsyncSession, task_run_state_id: UUID) """ result = await session.execute( - delete(orm_models.TaskRunState).where( - orm_models.TaskRunState.id == task_run_state_id - ) + delete(db.TaskRunState).where(db.TaskRunState.id == task_run_state_id) ) return result.rowcount > 0 diff --git a/src/prefect/server/models/task_runs.py b/src/prefect/server/models/task_runs.py index 9f1be084612ac..cbeeead3e04c7 100644 --- a/src/prefect/server/models/task_runs.py +++ b/src/prefect/server/models/task_runs.py @@ -16,9 +16,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas from prefect.logging import get_logger -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.exceptions import ObjectNotFoundError from prefect.server.orchestration.core_policy import ( BackgroundTaskPolicy, @@ -66,7 +64,7 @@ async def create_task_run( # if a dynamic key exists, we need to guard against conflicts if task_run.flow_run_id: insert_stmt = ( - db.insert(orm_models.TaskRun) + db.queries.insert(db.TaskRun) .values( created=now, **task_run.model_dump_for_orm( @@ -74,18 +72,18 @@ async def create_task_run( ), ) .on_conflict_do_nothing( - index_elements=db.task_run_unique_upsert_columns, + index_elements=db.orm.task_run_unique_upsert_columns, ) ) await session.execute(insert_stmt) query = ( - sa.select(orm_models.TaskRun) + sa.select(db.TaskRun) .where( sa.and_( - orm_models.TaskRun.flow_run_id == task_run.flow_run_id, - orm_models.TaskRun.task_key == task_run.task_key, - orm_models.TaskRun.dynamic_key == task_run.dynamic_key, + db.TaskRun.flow_run_id == task_run.flow_run_id, + db.TaskRun.task_key == task_run.task_key, + db.TaskRun.dynamic_key == task_run.dynamic_key, ) ) .limit(1) @@ -96,12 +94,12 @@ async def create_task_run( else: # Upsert on (task_key, dynamic_key) application logic. query = ( - sa.select(orm_models.TaskRun) + sa.select(db.TaskRun) .where( sa.and_( - orm_models.TaskRun.flow_run_id.is_(None), - orm_models.TaskRun.task_key == task_run.task_key, - orm_models.TaskRun.dynamic_key == task_run.dynamic_key, + db.TaskRun.flow_run_id.is_(None), + db.TaskRun.task_key == task_run.task_key, + db.TaskRun.dynamic_key == task_run.dynamic_key, ) ) .limit(1) @@ -112,7 +110,7 @@ async def create_task_run( model = result.scalar() if model is None: - model = orm_models.TaskRun( + model = db.TaskRun( created=now, **task_run.model_dump_for_orm( exclude={"state", "created"}, exclude_unset=True @@ -134,7 +132,9 @@ async def create_task_run( return model +@db_injector async def update_task_run( + db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID, task_run: schemas.actions.TaskRunUpdate, @@ -151,8 +151,8 @@ async def update_task_run( bool: whether or not matching rows were found to update """ update_stmt = ( - sa.update(orm_models.TaskRun) - .where(orm_models.TaskRun.id == task_run_id) + sa.update(db.TaskRun) + .where(db.TaskRun.id == task_run_id) # exclude_unset=True allows us to only update values provided by # the user, ignoring any defaults on the model .values(**task_run.model_dump_for_orm(exclude_unset=True)) @@ -161,8 +161,9 @@ async def update_task_run( return result.rowcount > 0 +@db_injector async def read_task_run( - session: AsyncSession, task_run_id: UUID + db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID ) -> Union[orm_models.TaskRun, None]: """ Read a task run by id. @@ -175,11 +176,12 @@ async def read_task_run( orm_models.TaskRun: the task run """ - model = await session.get(orm_models.TaskRun, task_run_id) + model = await session.get(db.TaskRun, task_run_id) return model async def _apply_task_run_filters( + db: PrefectDBInterface, query: Select[T], flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -206,7 +208,7 @@ async def _apply_task_run_filters( [flow_filter, deployment_filter, work_pool_filter, work_queue_filter] ) ): - query = query.where(orm_models.TaskRun.flow_run_id.in_(flow_run_filter.id.any_)) + query = query.where(db.TaskRun.flow_run_id.in_(flow_run_filter.id.any_)) return query @@ -217,8 +219,8 @@ async def _apply_task_run_filters( or work_pool_filter or work_queue_filter ): - exists_clause = select(orm_models.FlowRun).where( - orm_models.FlowRun.id == orm_models.TaskRun.flow_run_id + exists_clause = select(db.FlowRun).where( + db.FlowRun.id == db.TaskRun.flow_run_id ) if flow_run_filter: @@ -226,28 +228,28 @@ async def _apply_task_run_filters( if flow_filter: exists_clause = exists_clause.join( - orm_models.Flow, - orm_models.Flow.id == orm_models.FlowRun.flow_id, + db.Flow, + db.Flow.id == db.FlowRun.flow_id, ).where(flow_filter.as_sql_filter()) if deployment_filter: exists_clause = exists_clause.join( - orm_models.Deployment, - orm_models.Deployment.id == orm_models.FlowRun.deployment_id, + db.Deployment, + db.Deployment.id == db.FlowRun.deployment_id, ).where(deployment_filter.as_sql_filter()) if work_queue_filter: exists_clause = exists_clause.join( - orm_models.WorkQueue, - orm_models.WorkQueue.id == orm_models.FlowRun.work_queue_id, + db.WorkQueue, + db.WorkQueue.id == db.FlowRun.work_queue_id, ).where(work_queue_filter.as_sql_filter()) if work_pool_filter: exists_clause = exists_clause.join( - orm_models.WorkPool, + db.WorkPool, sa.and_( - orm_models.WorkPool.id == orm_models.WorkQueue.work_pool_id, - orm_models.WorkQueue.id == orm_models.FlowRun.work_queue_id, + db.WorkPool.id == db.WorkQueue.work_pool_id, + db.WorkQueue.id == db.FlowRun.work_queue_id, ), ).where(work_pool_filter.as_sql_filter()) @@ -256,7 +258,9 @@ async def _apply_task_run_filters( return query +@db_injector async def read_task_runs( + db: PrefectDBInterface, session: AsyncSession, flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -283,9 +287,10 @@ async def read_task_runs( List[orm_models.TaskRun]: the task runs """ - query = select(orm_models.TaskRun).order_by(sort.as_sql_sort()) + query = select(db.TaskRun).order_by(*sort.as_sql_sort()) query = await _apply_task_run_filters( + db, query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -304,7 +309,9 @@ async def read_task_runs( return result.scalars().unique().all() +@db_injector async def count_task_runs( + db: PrefectDBInterface, session: AsyncSession, flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -324,9 +331,10 @@ async def count_task_runs( int: count of task runs """ - query = select(sa.func.count(sa.text("*"))).select_from(orm_models.TaskRun) + query = select(sa.func.count(None)).select_from(db.TaskRun) query = await _apply_task_run_filters( + db, query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -338,7 +346,9 @@ async def count_task_runs( return result.scalar_one() +@db_injector async def count_task_runs_by_state( + db: PrefectDBInterface, session: AsyncSession, flow_filter: Optional[schemas.filters.FlowFilter] = None, flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None, @@ -359,15 +369,13 @@ async def count_task_runs_by_state( """ base_query = ( - select( - orm_models.TaskRun.state_type, - sa.func.count(sa.text("*")).label("count"), - ) - .select_from(orm_models.TaskRun) - .group_by(orm_models.TaskRun.state_type) + select(db.TaskRun.state_type, sa.func.count(None).label("count")) + .select_from(db.TaskRun) + .group_by(db.TaskRun.state_type) ) query = await _apply_task_run_filters( + db, base_query, flow_filter=flow_filter, flow_run_filter=flow_run_filter, @@ -385,7 +393,10 @@ async def count_task_runs_by_state( return counts -async def delete_task_run(session: AsyncSession, task_run_id: UUID) -> bool: +@db_injector +async def delete_task_run( + db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID +) -> bool: """ Delete a task run by id. @@ -398,7 +409,7 @@ async def delete_task_run(session: AsyncSession, task_run_id: UUID) -> bool: """ result = await session.execute( - delete(orm_models.TaskRun).where(orm_models.TaskRun.id == task_run_id) + delete(db.TaskRun).where(db.TaskRun.id == task_run_id) ) return result.rowcount > 0 diff --git a/src/prefect/server/models/variables.py b/src/prefect/server/models/variables.py index 64022b1cd5cb2..fa4dca8a40e2b 100644 --- a/src/prefect/server/models/variables.py +++ b/src/prefect/server/models/variables.py @@ -4,13 +4,14 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database import orm_models +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.schemas import filters, sorting from prefect.server.schemas.actions import VariableCreate, VariableUpdate +@db_injector async def create_variable( - session: AsyncSession, variable: VariableCreate + db: PrefectDBInterface, session: AsyncSession, variable: VariableCreate ) -> orm_models.Variable: """ Create a variable @@ -22,40 +23,44 @@ async def create_variable( Returns: orm_models.Variable """ - model = orm_models.Variable(**variable.model_dump()) + model = db.Variable(**variable.model_dump()) session.add(model) await session.flush() return model +@db_injector async def read_variable( - session: AsyncSession, variable_id: UUID + db: PrefectDBInterface, session: AsyncSession, variable_id: UUID ) -> Optional[orm_models.Variable]: """ Reads a variable by id. """ - query = sa.select(orm_models.Variable).where(orm_models.Variable.id == variable_id) + query = sa.select(db.Variable).where(db.Variable.id == variable_id) result = await session.execute(query) return result.scalar() +@db_injector async def read_variable_by_name( - session: AsyncSession, name: str + db: PrefectDBInterface, session: AsyncSession, name: str ) -> Optional[orm_models.Variable]: """ Reads a variable by name. """ - query = sa.select(orm_models.Variable).where(orm_models.Variable.name == name) + query = sa.select(db.Variable).where(db.Variable.name == name) result = await session.execute(query) return result.scalar() +@db_injector async def read_variables( + db: PrefectDBInterface, session: AsyncSession, variable_filter: Optional[filters.VariableFilter] = None, sort: sorting.VariableSort = sorting.VariableSort.NAME_ASC, @@ -65,7 +70,7 @@ async def read_variables( """ Read variables, applying filers. """ - query = sa.select(orm_models.Variable).order_by(sort.as_sql_sort()) + query = sa.select(db.Variable).order_by(*sort.as_sql_sort()) if variable_filter: query = query.where(variable_filter.as_sql_filter()) @@ -79,14 +84,17 @@ async def read_variables( return result.scalars().unique().all() +@db_injector async def count_variables( - session: AsyncSession, variable_filter: Optional[filters.VariableFilter] = None + db: PrefectDBInterface, + session: AsyncSession, + variable_filter: Optional[filters.VariableFilter] = None, ) -> int: """ Count variables, applying filters. """ - query = sa.select(sa.func.count()).select_from(orm_models.Variable) + query = sa.select(sa.func.count()).select_from(db.Variable) if variable_filter: query = query.where(variable_filter.as_sql_filter()) @@ -95,15 +103,19 @@ async def count_variables( return result.scalar_one() +@db_injector async def update_variable( - session: AsyncSession, variable_id: UUID, variable: VariableUpdate + db: PrefectDBInterface, + session: AsyncSession, + variable_id: UUID, + variable: VariableUpdate, ) -> bool: """ Updates a variable by id. """ query = ( - sa.update(orm_models.Variable) - .where(orm_models.Variable.id == variable_id) + sa.update(db.Variable) + .where(db.Variable.id == variable_id) .values(**variable.model_dump_for_orm(exclude_unset=True)) ) @@ -111,15 +123,16 @@ async def update_variable( return result.rowcount > 0 +@db_injector async def update_variable_by_name( - session: AsyncSession, name: str, variable: VariableUpdate + db: PrefectDBInterface, session: AsyncSession, name: str, variable: VariableUpdate ) -> bool: """ Updates a variable by name. """ query = ( - sa.update(orm_models.Variable) - .where(orm_models.Variable.name == name) + sa.update(db.Variable) + .where(db.Variable.name == name) .values(**variable.model_dump_for_orm(exclude_unset=True)) ) @@ -127,23 +140,29 @@ async def update_variable_by_name( return result.rowcount > 0 -async def delete_variable(session: AsyncSession, variable_id: UUID) -> bool: +@db_injector +async def delete_variable( + db: PrefectDBInterface, session: AsyncSession, variable_id: UUID +) -> bool: """ Delete a variable by id. """ - query = sa.delete(orm_models.Variable).where(orm_models.Variable.id == variable_id) + query = sa.delete(db.Variable).where(db.Variable.id == variable_id) result = await session.execute(query) return result.rowcount > 0 -async def delete_variable_by_name(session: AsyncSession, name: str) -> bool: +@db_injector +async def delete_variable_by_name( + db: PrefectDBInterface, session: AsyncSession, name: str +) -> bool: """ Delete a variable by name. """ - query = sa.delete(orm_models.Variable).where(orm_models.Variable.name == name) + query = sa.delete(db.Variable).where(db.Variable.name == name) result = await session.execute(query) return result.rowcount > 0 diff --git a/src/prefect/server/models/work_queues.py b/src/prefect/server/models/work_queues.py index 4c1345083455e..8a0d2ccbbd3ea 100644 --- a/src/prefect/server/models/work_queues.py +++ b/src/prefect/server/models/work_queues.py @@ -24,9 +24,7 @@ import prefect.server.models as models import prefect.server.schemas as schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.events.clients import PrefectServerEventsClient from prefect.server.exceptions import ObjectNotFoundError from prefect.server.models.events import work_queue_status_event @@ -41,7 +39,9 @@ WORK_QUEUE_LAST_POLLED_TIMEOUT = datetime.timedelta(seconds=60) +@db_injector async def create_work_queue( + db: PrefectDBInterface, session: AsyncSession, work_queue: Union[schemas.core.WorkQueue, schemas.actions.WorkQueueCreate], ) -> orm_models.WorkQueue: @@ -89,8 +89,8 @@ async def create_work_queue( # This will make the new queue the lowest priority if data["priority"] is None: # Set the priority to be the first priority value that isn't already taken - priorities_query = sa.select(orm_models.WorkQueue.priority).where( - orm_models.WorkQueue.work_pool_id == data["work_pool_id"] + priorities_query = sa.select(db.WorkQueue.priority).where( + db.WorkQueue.work_pool_id == data["work_pool_id"] ) priorities = (await session.execute(priorities_query)).scalars().all() @@ -109,7 +109,7 @@ async def create_work_queue( data["priority"] = priority - model = orm_models.WorkQueue(**data) + model = db.WorkQueue(**data) session.add(model) await session.flush() @@ -125,8 +125,11 @@ async def create_work_queue( return model +@db_injector async def read_work_queue( - session: AsyncSession, work_queue_id: Union[UUID, PrefectUUID] + db: PrefectDBInterface, + session: AsyncSession, + work_queue_id: Union[UUID, PrefectUUID], ) -> Optional[orm_models.WorkQueue]: """ Reads a WorkQueue by id. @@ -139,11 +142,12 @@ async def read_work_queue( orm_models.WorkQueue: the WorkQueue """ - return await session.get(orm_models.WorkQueue, work_queue_id) + return await session.get(db.WorkQueue, work_queue_id) +@db_injector async def read_work_queue_by_name( - session: AsyncSession, name: str + db: PrefectDBInterface, session: AsyncSession, name: str ) -> Optional[orm_models.WorkQueue]: """ Reads a WorkQueue by id. @@ -160,16 +164,18 @@ async def read_work_queue_by_name( ) # Logic to make sure this functionality doesn't break during migration if default_work_pool is not None: - query = select(orm_models.WorkQueue).filter_by( + query = select(db.WorkQueue).filter_by( name=name, work_pool_id=default_work_pool.id ) else: - query = select(orm_models.WorkQueue).filter_by(name=name) + query = select(db.WorkQueue).filter_by(name=name) result = await session.execute(query) return result.scalar() +@db_injector async def read_work_queues( + db: PrefectDBInterface, session: AsyncSession, offset: Optional[int] = None, limit: Optional[int] = None, @@ -187,7 +193,7 @@ async def read_work_queues( Sequence[orm_models.WorkQueue]: WorkQueues """ - query = select(orm_models.WorkQueue).order_by(orm_models.WorkQueue.name) + query = select(db.WorkQueue).order_by(db.WorkQueue.name) if offset is not None: query = query.offset(offset) @@ -206,7 +212,9 @@ def is_last_polled_recent(last_polled: Optional[pendulum.DateTime]) -> bool: return (pendulum.now("UTC") - last_polled) <= WORK_QUEUE_LAST_POLLED_TIMEOUT +@db_injector async def update_work_queue( + db: PrefectDBInterface, session: AsyncSession, work_queue_id: UUID, work_queue: schemas.actions.WorkQueueUpdate, @@ -260,8 +268,8 @@ async def update_work_queue( update_data["status"] = schemas.statuses.WorkQueueStatus.READY update_stmt = ( - sa.update(orm_models.WorkQueue) - .where(orm_models.WorkQueue.id == work_queue_id) + sa.update(db.WorkQueue) + .where(db.WorkQueue.id == work_queue_id) .values(**update_data) ) result = await session.execute(update_stmt) @@ -276,7 +284,10 @@ async def update_work_queue( return updated -async def delete_work_queue(session: AsyncSession, work_queue_id: UUID) -> bool: +@db_injector +async def delete_work_queue( + db: PrefectDBInterface, session: AsyncSession, work_queue_id: UUID +) -> bool: """ Delete a WorkQueue by id. @@ -288,7 +299,7 @@ async def delete_work_queue(session: AsyncSession, work_queue_id: UUID) -> bool: bool: whether or not the WorkQueue was deleted """ result = await session.execute( - delete(orm_models.WorkQueue).where(orm_models.WorkQueue.id == work_queue_id) + delete(db.WorkQueue).where(db.WorkQueue.id == work_queue_id) ) return result.rowcount > 0 @@ -496,7 +507,9 @@ async def read_work_queue_status( ) +@db_injector async def record_work_queue_polls( + db: PrefectDBInterface, session: AsyncSession, polled_work_queue_ids: Sequence[UUID], ready_work_queue_ids: Sequence[UUID], @@ -507,15 +520,15 @@ async def record_work_queue_polls( if polled_work_queue_ids: await session.execute( - sa.update(orm_models.WorkQueue) - .where(orm_models.WorkQueue.id.in_(polled_work_queue_ids)) + sa.update(db.WorkQueue) + .where(db.WorkQueue.id.in_(polled_work_queue_ids)) .values(last_polled=polled) ) if ready_work_queue_ids: await session.execute( - sa.update(orm_models.WorkQueue) - .where(orm_models.WorkQueue.id.in_(ready_work_queue_ids)) + sa.update(db.WorkQueue) + .where(db.WorkQueue.id.in_(ready_work_queue_ids)) .values(last_polled=polled, status=WorkQueueStatus.READY) ) @@ -541,9 +554,7 @@ async def mark_work_queues_ready( async with db.session_context(begin_transaction=True) as session: newly_ready_work_queues = await session.execute( - sa.select(orm_models.WorkQueue).where( - orm_models.WorkQueue.id.in_(ready_work_queue_ids) - ) + sa.select(db.WorkQueue).where(db.WorkQueue.id.in_(ready_work_queue_ids)) ) events = [ @@ -570,8 +581,8 @@ async def mark_work_queues_not_ready( async with db.session_context(begin_transaction=True) as session: await session.execute( - sa.update(orm_models.WorkQueue) - .where(orm_models.WorkQueue.id.in_(work_queue_ids)) + sa.update(db.WorkQueue) + .where(db.WorkQueue.id.in_(work_queue_ids)) .values(status=WorkQueueStatus.NOT_READY) ) @@ -581,9 +592,7 @@ async def mark_work_queues_not_ready( async with db.session_context(begin_transaction=True) as session: newly_unready_work_queues = await session.execute( - sa.select(orm_models.WorkQueue).where( - orm_models.WorkQueue.id.in_(work_queue_ids) - ) + sa.select(db.WorkQueue).where(db.WorkQueue.id.in_(work_queue_ids)) ) events = [ diff --git a/src/prefect/server/models/workers.py b/src/prefect/server/models/workers.py index b2d78044a7eb0..fbce8befe3997 100644 --- a/src/prefect/server/models/workers.py +++ b/src/prefect/server/models/workers.py @@ -21,9 +21,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.schemas as schemas -from prefect.server.database import orm_models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector, orm_models from prefect.server.events.clients import PrefectServerEventsClient from prefect.server.exceptions import ObjectNotFoundError from prefect.server.models.events import work_pool_status_event @@ -41,7 +39,9 @@ # ----------------------------------------------------- +@db_injector async def create_work_pool( + db: PrefectDBInterface, session: AsyncSession, work_pool: Union[schemas.core.WorkPool, schemas.actions.WorkPoolCreate], ) -> orm_models.WorkPool: @@ -59,7 +59,7 @@ async def create_work_pool( """ - pool = orm_models.WorkPool(**work_pool.model_dump()) + pool = db.WorkPool(**work_pool.model_dump()) if pool.type != "prefect-agent": if pool.is_paused: @@ -84,8 +84,9 @@ async def create_work_pool( return pool +@db_injector async def read_work_pool( - session: AsyncSession, work_pool_id: UUID + db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID ) -> Optional[orm_models.WorkPool]: """ Reads a WorkPool by id. @@ -97,17 +98,14 @@ async def read_work_pool( Returns: orm_models.WorkPool: the WorkPool """ - query = ( - sa.select(orm_models.WorkPool) - .where(orm_models.WorkPool.id == work_pool_id) - .limit(1) - ) + query = sa.select(db.WorkPool).where(db.WorkPool.id == work_pool_id).limit(1) result = await session.execute(query) return result.scalar() +@db_injector async def read_work_pool_by_name( - session: AsyncSession, work_pool_name: str + db: PrefectDBInterface, session: AsyncSession, work_pool_name: str ) -> Optional[orm_models.WorkPool]: """ Reads a WorkPool by name. @@ -119,16 +117,14 @@ async def read_work_pool_by_name( Returns: orm_models.WorkPool: the WorkPool """ - query = ( - sa.select(orm_models.WorkPool) - .where(orm_models.WorkPool.name == work_pool_name) - .limit(1) - ) + query = sa.select(db.WorkPool).where(db.WorkPool.name == work_pool_name).limit(1) result = await session.execute(query) return result.scalar() +@db_injector async def read_work_pools( + db: PrefectDBInterface, session: AsyncSession, work_pool_filter: Optional[schemas.filters.WorkPoolFilter] = None, offset: Optional[int] = None, @@ -145,7 +141,7 @@ async def read_work_pools( List[orm_models.WorkPool]: worker configs """ - query = select(orm_models.WorkPool).order_by(orm_models.WorkPool.name) + query = select(db.WorkPool).order_by(db.WorkPool.name) if work_pool_filter is not None: query = query.where(work_pool_filter.as_sql_filter()) @@ -158,7 +154,9 @@ async def read_work_pools( return result.scalars().unique().all() +@db_injector async def count_work_pools( + db: PrefectDBInterface, session: AsyncSession, work_pool_filter: Optional[schemas.filters.WorkPoolFilter] = None, ) -> int: @@ -172,7 +170,7 @@ async def count_work_pools( int: the count of work pools matching the criteria """ - query = select(sa.func.count()).select_from(orm_models.WorkPool) + query = select(sa.func.count()).select_from(db.WorkPool) if work_pool_filter is not None: query = query.where(work_pool_filter.as_sql_filter()) @@ -181,7 +179,9 @@ async def count_work_pools( return result.scalar_one() +@db_injector async def update_work_pool( + db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID, work_pool: schemas.actions.WorkPoolUpdate, @@ -243,8 +243,8 @@ async def update_work_pool( update_data["last_transitioned_status_at"] = pendulum.now("UTC") update_stmt = ( - sa.update(orm_models.WorkPool) - .where(orm_models.WorkPool.id == work_pool_id) + sa.update(db.WorkPool) + .where(db.WorkPool.id == work_pool_id) .values(**update_data) ) result = await session.execute(update_stmt) @@ -267,7 +267,10 @@ async def update_work_pool( return updated -async def delete_work_pool(session: AsyncSession, work_pool_id: UUID) -> bool: +@db_injector +async def delete_work_pool( + db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID +) -> bool: """ Delete a WorkPool by id. @@ -280,7 +283,7 @@ async def delete_work_pool(session: AsyncSession, work_pool_id: UUID) -> bool: """ result = await session.execute( - delete(orm_models.WorkPool).where(orm_models.WorkPool.id == work_pool_id) + delete(db.WorkPool).where(db.WorkPool.id == work_pool_id) ) return result.rowcount > 0 @@ -337,7 +340,9 @@ async def get_scheduled_flow_runs( # ----------------------------------------------------- +@db_injector async def create_work_queue( + db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID, work_queue: schemas.actions.WorkQueueCreate, @@ -357,8 +362,8 @@ async def create_work_queue( data = work_queue.model_dump(exclude={"work_pool_id"}) if work_queue.priority is None: # Set the priority to be the first priority value that isn't already taken - priorities_query = sa.select(orm_models.WorkQueue.priority).where( - orm_models.WorkQueue.work_pool_id == work_pool_id + priorities_query = sa.select(db.WorkQueue.priority).where( + db.WorkQueue.work_pool_id == work_pool_id ) priorities = (await session.execute(priorities_query)).scalars().all() @@ -377,7 +382,7 @@ async def create_work_queue( data["priority"] = priority - model = orm_models.WorkQueue(**data, work_pool_id=work_pool_id) + model = db.WorkQueue(**data, work_pool_id=work_pool_id) session.add(model) await session.flush() @@ -392,7 +397,9 @@ async def create_work_queue( return model +@db_injector async def bulk_update_work_queue_priorities( + db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID, new_priorities: Dict[UUID, int], @@ -420,9 +427,9 @@ async def bulk_update_work_queue_priorities( # get all the work queues, sorted by priority work_queues_query = ( - sa.select(orm_models.WorkQueue) - .where(orm_models.WorkQueue.work_pool_id == work_pool_id) - .order_by(orm_models.WorkQueue.priority.asc()) + sa.select(db.WorkQueue) + .where(db.WorkQueue.work_pool_id == work_pool_id) + .order_by(db.WorkQueue.priority.asc()) ) result = await session.execute(work_queues_query) all_work_queues = result.scalars().all() @@ -456,7 +463,9 @@ async def bulk_update_work_queue_priorities( await session.flush() +@db_injector async def read_work_queues( + db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID, work_queue_filter: Optional[schemas.filters.WorkQueueFilter] = None, @@ -479,9 +488,9 @@ async def read_work_queues( """ query = ( - sa.select(orm_models.WorkQueue) - .where(orm_models.WorkQueue.work_pool_id == work_pool_id) - .order_by(orm_models.WorkQueue.priority.asc()) + sa.select(db.WorkQueue) + .where(db.WorkQueue.work_pool_id == work_pool_id) + .order_by(db.WorkQueue.priority.asc()) ) if work_queue_filter is not None: @@ -495,7 +504,9 @@ async def read_work_queues( return result.scalars().unique().all() +@db_injector async def read_work_queue( + db: PrefectDBInterface, session: AsyncSession, work_queue_id: Union[UUID, PrefectUUID], ) -> Optional[orm_models.WorkQueue]: @@ -510,10 +521,12 @@ async def read_work_queue( orm_models.WorkQueue: the WorkQueue """ - return await session.get(orm_models.WorkQueue, work_queue_id) + return await session.get(db.WorkQueue, work_queue_id) +@db_injector async def read_work_queue_by_name( + db: PrefectDBInterface, session: AsyncSession, work_pool_name: str, work_queue_name: str, @@ -530,14 +543,14 @@ async def read_work_queue_by_name( orm_models.WorkQueue: the WorkQueue """ query = ( - sa.select(orm_models.WorkQueue) + sa.select(db.WorkQueue) .join( - orm_models.WorkPool, - orm_models.WorkPool.id == orm_models.WorkQueue.work_pool_id, + db.WorkPool, + db.WorkPool.id == db.WorkQueue.work_pool_id, ) .where( - orm_models.WorkPool.name == work_pool_name, - orm_models.WorkQueue.name == work_queue_name, + db.WorkPool.name == work_pool_name, + db.WorkQueue.name == work_queue_name, ) .limit(1) ) @@ -545,7 +558,9 @@ async def read_work_queue_by_name( return result.scalar() +@db_injector async def update_work_queue( + db: PrefectDBInterface, session: AsyncSession, work_queue_id: UUID, work_queue: schemas.actions.WorkQueueUpdate, @@ -573,7 +588,7 @@ async def update_work_queue( update_values = work_queue.model_dump_for_orm(exclude_unset=True) if "is_paused" in update_values: - if (wq := await session.get(orm_models.WorkQueue, work_queue_id)) is None: + if (wq := await session.get(db.WorkQueue, work_queue_id)) is None: return False # Only update the status to paused if it's not already paused. This ensures a work queue that is already @@ -601,8 +616,8 @@ async def update_work_queue( update_values["status"] = schemas.statuses.WorkQueueStatus.READY update_stmt = ( - sa.update(orm_models.WorkQueue) - .where(orm_models.WorkQueue.id == work_queue_id) + sa.update(db.WorkQueue) + .where(db.WorkQueue.id == work_queue_id) .values(update_values) ) result = await session.execute(update_stmt) @@ -611,7 +626,7 @@ async def update_work_queue( if updated: if "priority" in update_values or "status" in update_values: - updated_work_queue = await session.get(orm_models.WorkQueue, work_queue_id) + updated_work_queue = await session.get(db.WorkQueue, work_queue_id) assert updated_work_queue if "priority" in update_values: @@ -627,7 +642,9 @@ async def update_work_queue( return updated +@db_injector async def delete_work_queue( + db: PrefectDBInterface, session: AsyncSession, work_queue_id: UUID, ) -> bool: @@ -642,7 +659,7 @@ async def delete_work_queue( bool: whether or not the WorkQueue was deleted """ - work_queue = await session.get(orm_models.WorkQueue, work_queue_id) + work_queue = await session.get(db.WorkQueue, work_queue_id) if work_queue is None: return False @@ -673,7 +690,9 @@ async def delete_work_queue( # ----------------------------------------------------- +@db_injector async def read_workers( + db: PrefectDBInterface, session: AsyncSession, work_pool_id: UUID, worker_filter: Optional[schemas.filters.WorkerFilter] = None, @@ -681,9 +700,9 @@ async def read_workers( offset: Optional[int] = None, ) -> Sequence[orm_models.Worker]: query = ( - sa.select(orm_models.Worker) - .where(orm_models.Worker.work_pool_id == work_pool_id) - .order_by(orm_models.Worker.last_heartbeat_time.desc()) + sa.select(db.Worker) + .where(db.Worker.work_pool_id == work_pool_id) + .order_by(db.Worker.last_heartbeat_time.desc()) .limit(limit) ) @@ -735,12 +754,12 @@ async def worker_heartbeat( update_values["heartbeat_interval_seconds"] = heartbeat_interval_seconds insert_stmt = ( - db.insert(orm_models.Worker) + db.queries.insert(db.Worker) .values(**base_values, **update_values) .on_conflict_do_update( index_elements=[ - orm_models.Worker.work_pool_id, - orm_models.Worker.name, + db.Worker.work_pool_id, + db.Worker.name, ], set_=update_values, ) @@ -770,9 +789,9 @@ async def delete_worker( """ result = await session.execute( - delete(orm_models.Worker).where( - orm_models.Worker.work_pool_id == work_pool_id, - orm_models.Worker.name == worker_name, + delete(db.Worker).where( + db.Worker.work_pool_id == work_pool_id, + db.Worker.name == worker_name, ) ) diff --git a/src/prefect/server/orchestration/core_policy.py b/src/prefect/server/orchestration/core_policy.py index ac9281597d610..0856dfc7f2de6 100644 --- a/src/prefect/server/orchestration/core_policy.py +++ b/src/prefect/server/orchestration/core_policy.py @@ -15,8 +15,7 @@ from prefect.logging import get_logger from prefect.server import models -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.server.exceptions import ObjectNotFoundError from prefect.server.models import concurrency_limits, concurrency_limits_v2, deployments from prefect.server.orchestration.policies import BaseOrchestrationPolicy diff --git a/src/prefect/server/orchestration/rules.py b/src/prefect/server/orchestration/rules.py index 421e7f8fca2e3..b6610daaf1124 100644 --- a/src/prefect/server/orchestration/rules.py +++ b/src/prefect/server/orchestration/rules.py @@ -26,8 +26,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.logging import get_logger -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.server.exceptions import OrchestrationError from prefect.server.models import artifacts, flow_runs from prefect.server.schemas import core, states diff --git a/src/prefect/server/schemas/filters.py b/src/prefect/server/schemas/filters.py index f1633387e885f..e797e0cd1f736 100644 --- a/src/prefect/server/schemas/filters.py +++ b/src/prefect/server/schemas/filters.py @@ -4,14 +4,14 @@ Each filter schema includes logic for transforming itself into a SQL `where` clause. """ -from collections.abc import Sequence -from typing import TYPE_CHECKING, List, Optional +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Optional from uuid import UUID from pydantic import ConfigDict, Field import prefect.server.schemas as schemas -from prefect.server.database import orm_models +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.utilities.schemas.bases import PrefectBaseModel from prefect.types import DateTime from prefect.utilities.collections import AutoEnum @@ -20,7 +20,6 @@ if TYPE_CHECKING: import sqlalchemy as sa from sqlalchemy.dialects import postgresql - from sqlalchemy.sql.elements import BooleanClauseList else: sa = lazy_import("sqlalchemy") postgresql = lazy_import("sqlalchemy.dialects.postgresql") @@ -46,14 +45,17 @@ class PrefectFilterBaseModel(PrefectBaseModel): model_config = ConfigDict(extra="forbid") - def as_sql_filter(self) -> "BooleanClauseList": + @db_injector + def as_sql_filter(self, db: PrefectDBInterface) -> sa.ColumnElement[bool]: """Generate SQL filter from provided filter parameters. If no filters parameters are available, return a TRUE filter.""" - filters = self._get_filter_list() + filters = self._get_filter_list(db) if not filters: - return True + return sa.true() return sa.and_(*filters) - def _get_filter_list(self) -> List: + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: """Return a list of boolean filter statements based on filter parameters""" raise NotImplementedError("_get_filter_list must be implemented") @@ -66,24 +68,27 @@ class PrefectOperatorFilterBaseModel(PrefectFilterBaseModel): description="Operator for combining filter criteria. Defaults to 'and_'.", ) - def as_sql_filter(self) -> "BooleanClauseList": - filters = self._get_filter_list() + @db_injector + def as_sql_filter(self, db: PrefectDBInterface) -> sa.ColumnElement[bool]: + filters = self._get_filter_list(db) if not filters: - return True + return sa.true() return sa.and_(*filters) if self.operator == Operator.and_ else sa.or_(*filters) class FlowFilterId(PrefectFilterBaseModel): """Filter by `Flow.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of flow ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Flow.id.in_(self.any_)) + filters.append(db.Flow.id.in_(self.any_)) return filters @@ -95,21 +100,23 @@ class FlowFilterDeployment(PrefectOperatorFilterBaseModel): description="If true, only include flows without deployments", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.is_null_ is not None: deployments_subquery = ( - sa.select(orm_models.Deployment.flow_id).distinct().subquery() + sa.select(db.Deployment.flow_id).distinct().subquery() ) if self.is_null_: filters.append( - orm_models.Flow.id.not_in(sa.select(deployments_subquery.c.flow_id)) + db.Flow.id.not_in(sa.select(deployments_subquery.c.flow_id)) ) else: filters.append( - orm_models.Flow.id.in_(sa.select(deployments_subquery.c.flow_id)) + db.Flow.id.in_(sa.select(deployments_subquery.c.flow_id)) ) return filters @@ -118,7 +125,7 @@ def _get_filter_list(self) -> List: class FlowFilterName(PrefectFilterBaseModel): """Filter by `Flow.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of flow names to include", examples=[["my-flow-1", "my-flow-2"]], @@ -134,19 +141,21 @@ class FlowFilterName(PrefectFilterBaseModel): examples=["marvin"], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Flow.name.in_(self.any_)) + filters.append(db.Flow.name.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.Flow.name.ilike(f"%{self.like_}%")) + filters.append(db.Flow.name.ilike(f"%{self.like_}%")) return filters class FlowFilterTags(PrefectOperatorFilterBaseModel): """Filter by `Flow.tags`.""" - all_: Optional[List[str]] = Field( + all_: Optional[list[str]] = Field( default=None, examples=[["tag-1", "tag-2"]], description=( @@ -158,16 +167,14 @@ class FlowFilterTags(PrefectOperatorFilterBaseModel): default=None, description="If true, only include flows without tags" ) - def _get_filter_list(self) -> list[sa.ColumnElement[bool]]: - filters: list[sa.ColumnElement[bool]] = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.all_ is not None: - filters.append(orm_models.Flow.tags.has_all(_as_array(self.all_))) + filters.append(db.Flow.tags.has_all(_as_array(self.all_))) if self.is_null_ is not None: - filters.append( - orm_models.Flow.tags == [] - if self.is_null_ - else orm_models.Flow.tags != [] - ) + filters.append(db.Flow.tags == [] if self.is_null_ else db.Flow.tags != []) return filters @@ -187,8 +194,10 @@ class FlowFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `Flow.tags`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) @@ -205,26 +214,28 @@ def _get_filter_list(self) -> List: class FlowRunFilterId(PrefectFilterBaseModel): """Filter by `FlowRun.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of flow run ids to include" ) - not_any_: Optional[List[UUID]] = Field( + not_any_: Optional[list[UUID]] = Field( default=None, description="A list of flow run ids to exclude" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.id.in_(self.any_)) + filters.append(db.FlowRun.id.in_(self.any_)) if self.not_any_ is not None: - filters.append(orm_models.FlowRun.id.not_in(self.not_any_)) + filters.append(db.FlowRun.id.not_in(self.not_any_)) return filters class FlowRunFilterName(PrefectFilterBaseModel): """Filter by `FlowRun.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of flow run names to include", examples=[["my-flow-run-1", "my-flow-run-2"]], @@ -240,19 +251,21 @@ class FlowRunFilterName(PrefectFilterBaseModel): examples=["marvin"], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.name.in_(self.any_)) + filters.append(db.FlowRun.name.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.FlowRun.name.ilike(f"%{self.like_}%")) + filters.append(db.FlowRun.name.ilike(f"%{self.like_}%")) return filters class FlowRunFilterTags(PrefectOperatorFilterBaseModel): """Filter by `FlowRun.tags`.""" - all_: Optional[List[str]] = Field( + all_: Optional[list[str]] = Field( default=None, examples=[["tag-1", "tag-2"]], description=( @@ -261,7 +274,7 @@ class FlowRunFilterTags(PrefectOperatorFilterBaseModel): ), ) - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, examples=[["tag-1", "tag-2"]], description="A list of tags to include", @@ -271,20 +284,20 @@ class FlowRunFilterTags(PrefectOperatorFilterBaseModel): default=None, description="If true, only include flow runs without tags" ) - def _get_filter_list(self) -> list[sa.ColumnElement[bool]]: + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: def as_array(elems: Sequence[str]) -> sa.ColumnElement[Sequence[str]]: return sa.cast(postgresql.array(elems), type_=postgresql.ARRAY(sa.String())) filters: list[sa.ColumnElement[bool]] = [] if self.all_ is not None: - filters.append(orm_models.FlowRun.tags.has_all(as_array(self.all_))) + filters.append(db.FlowRun.tags.has_all(as_array(self.all_))) if self.any_ is not None: - filters.append(orm_models.FlowRun.tags.has_any(as_array(self.any_))) + filters.append(db.FlowRun.tags.has_any(as_array(self.any_))) if self.is_null_ is not None: filters.append( - orm_models.FlowRun.tags == [] - if self.is_null_ - else orm_models.FlowRun.tags != [] + db.FlowRun.tags == [] if self.is_null_ else db.FlowRun.tags != [] ) return filters @@ -292,7 +305,7 @@ def as_array(elems: Sequence[str]) -> sa.ColumnElement[Sequence[str]]: class FlowRunFilterDeploymentId(PrefectOperatorFilterBaseModel): """Filter by `FlowRun.deployment_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of flow run deployment ids to include" ) is_null_: Optional[bool] = Field( @@ -300,15 +313,17 @@ class FlowRunFilterDeploymentId(PrefectOperatorFilterBaseModel): description="If true, only include flow runs without deployment ids", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.deployment_id.in_(self.any_)) + filters.append(db.FlowRun.deployment_id.in_(self.any_)) if self.is_null_ is not None: filters.append( - orm_models.FlowRun.deployment_id.is_(None) + db.FlowRun.deployment_id.is_(None) if self.is_null_ - else orm_models.FlowRun.deployment_id.is_not(None) + else db.FlowRun.deployment_id.is_not(None) ) return filters @@ -316,7 +331,7 @@ def _get_filter_list(self) -> List: class FlowRunFilterWorkQueueName(PrefectOperatorFilterBaseModel): """Filter by `FlowRun.work_queue_name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of work queue names to include", examples=[["work_queue_1", "work_queue_2"]], @@ -326,15 +341,17 @@ class FlowRunFilterWorkQueueName(PrefectOperatorFilterBaseModel): description="If true, only include flow runs without work queue names", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.work_queue_name.in_(self.any_)) + filters.append(db.FlowRun.work_queue_name.in_(self.any_)) if self.is_null_ is not None: filters.append( - orm_models.FlowRun.work_queue_name.is_(None) + db.FlowRun.work_queue_name.is_(None) if self.is_null_ - else orm_models.FlowRun.work_queue_name.is_not(None) + else db.FlowRun.work_queue_name.is_not(None) ) return filters @@ -342,38 +359,42 @@ def _get_filter_list(self) -> List: class FlowRunFilterStateType(PrefectFilterBaseModel): """Filter by `FlowRun.state_type`.""" - any_: Optional[List[schemas.states.StateType]] = Field( + any_: Optional[list[schemas.states.StateType]] = Field( default=None, description="A list of flow run state types to include" ) - not_any_: Optional[List[schemas.states.StateType]] = Field( + not_any_: Optional[list[schemas.states.StateType]] = Field( default=None, description="A list of flow run state types to exclude" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.state_type.in_(self.any_)) + filters.append(db.FlowRun.state_type.in_(self.any_)) if self.not_any_ is not None: - filters.append(orm_models.FlowRun.state_type.not_in(self.not_any_)) + filters.append(db.FlowRun.state_type.not_in(self.not_any_)) return filters class FlowRunFilterStateName(PrefectFilterBaseModel): """Filter by `FlowRun.state_name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of flow run state names to include" ) - not_any_: Optional[List[str]] = Field( + not_any_: Optional[list[str]] = Field( default=None, description="A list of flow run state names to exclude" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.state_name.in_(self.any_)) + filters.append(db.FlowRun.state_name.in_(self.any_)) if self.not_any_ is not None: - filters.append(orm_models.FlowRun.state_name.not_in(self.not_any_)) + filters.append(db.FlowRun.state_name.not_in(self.not_any_)) return filters @@ -387,26 +408,34 @@ class FlowRunFilterState(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `FlowRun.state_name`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.type is not None: - filters.extend(self.type._get_filter_list()) + filter = self.type.as_sql_filter() + if isinstance(filter, sa.BinaryExpression): + filters.append(filter) if self.name is not None: - filters.extend(self.name._get_filter_list()) + filter = self.name.as_sql_filter() + if isinstance(filter, sa.BinaryExpression): + filters.append(filter) return filters class FlowRunFilterFlowVersion(PrefectFilterBaseModel): """Filter by `FlowRun.flow_version`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of flow run flow_versions to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.flow_version.in_(self.any_)) + filters.append(db.FlowRun.flow_version.in_(self.any_)) return filters @@ -425,17 +454,19 @@ class FlowRunFilterStartTime(PrefectFilterBaseModel): default=None, description="If true, only return flow runs without a start time" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.FlowRun.start_time <= self.before_) + filters.append(db.FlowRun.start_time <= self.before_) if self.after_ is not None: - filters.append(orm_models.FlowRun.start_time >= self.after_) + filters.append(db.FlowRun.start_time >= self.after_) if self.is_null_ is not None: filters.append( - orm_models.FlowRun.start_time.is_(None) + db.FlowRun.start_time.is_(None) if self.is_null_ - else orm_models.FlowRun.start_time.is_not(None) + else db.FlowRun.start_time.is_not(None) ) return filters @@ -455,17 +486,19 @@ class FlowRunFilterEndTime(PrefectFilterBaseModel): default=None, description="If true, only return flow runs without an end time" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.FlowRun.end_time <= self.before_) + filters.append(db.FlowRun.end_time <= self.before_) if self.after_ is not None: - filters.append(orm_models.FlowRun.end_time >= self.after_) + filters.append(db.FlowRun.end_time >= self.after_) if self.is_null_ is not None: filters.append( - orm_models.FlowRun.end_time.is_(None) + db.FlowRun.end_time.is_(None) if self.is_null_ - else orm_models.FlowRun.end_time.is_not(None) + else db.FlowRun.end_time.is_not(None) ) return filters @@ -482,12 +515,14 @@ class FlowRunFilterExpectedStartTime(PrefectFilterBaseModel): description="Only include flow runs scheduled to start at or after this time", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.FlowRun.expected_start_time <= self.before_) + filters.append(db.FlowRun.expected_start_time <= self.before_) if self.after_ is not None: - filters.append(orm_models.FlowRun.expected_start_time >= self.after_) + filters.append(db.FlowRun.expected_start_time >= self.after_) return filters @@ -509,36 +544,39 @@ class FlowRunFilterNextScheduledStartTime(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.FlowRun.next_scheduled_start_time <= self.before_) + filters.append(db.FlowRun.next_scheduled_start_time <= self.before_) if self.after_ is not None: - filters.append(orm_models.FlowRun.next_scheduled_start_time >= self.after_) + filters.append(db.FlowRun.next_scheduled_start_time >= self.after_) return filters class FlowRunFilterParentFlowRunId(PrefectOperatorFilterBaseModel): """Filter for subflows of a given flow run""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of parent flow run ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: filters.append( - orm_models.FlowRun.id.in_( - sa.select(orm_models.FlowRun.id) + db.FlowRun.id.in_( + sa.select(db.FlowRun.id) .join( - orm_models.TaskRun, + db.TaskRun, sa.and_( - orm_models.TaskRun.id - == orm_models.FlowRun.parent_task_run_id, + db.TaskRun.id == db.FlowRun.parent_task_run_id, ), ) - .where(orm_models.TaskRun.flow_run_id.in_(self.any_)) + .where(db.TaskRun.flow_run_id.in_(self.any_)) ) ) return filters @@ -547,7 +585,7 @@ def _get_filter_list(self) -> List: class FlowRunFilterParentTaskRunId(PrefectOperatorFilterBaseModel): """Filter by `FlowRun.parent_task_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of flow run parent_task_run_ids to include" ) is_null_: Optional[bool] = Field( @@ -555,15 +593,17 @@ class FlowRunFilterParentTaskRunId(PrefectOperatorFilterBaseModel): description="If true, only include flow runs without parent_task_run_id", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.parent_task_run_id.in_(self.any_)) + filters.append(db.FlowRun.parent_task_run_id.in_(self.any_)) if self.is_null_ is not None: filters.append( - orm_models.FlowRun.parent_task_run_id.is_(None) + db.FlowRun.parent_task_run_id.is_(None) if self.is_null_ - else orm_models.FlowRun.parent_task_run_id.is_not(None) + else db.FlowRun.parent_task_run_id.is_not(None) ) return filters @@ -571,19 +611,21 @@ def _get_filter_list(self) -> List: class FlowRunFilterIdempotencyKey(PrefectFilterBaseModel): """Filter by FlowRun.idempotency_key.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of flow run idempotency keys to include" ) - not_any_: Optional[List[str]] = Field( + not_any_: Optional[list[str]] = Field( default=None, description="A list of flow run idempotency keys to exclude" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.FlowRun.idempotency_key.in_(self.any_)) + filters.append(db.FlowRun.idempotency_key.in_(self.any_)) if self.not_any_ is not None: - filters.append(orm_models.FlowRun.idempotency_key.not_in(self.not_any_)) + filters.append(db.FlowRun.idempotency_key.not_in(self.not_any_)) return filters @@ -653,8 +695,10 @@ def only_filters_on_id(self): and self.idempotency_key is None ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) @@ -691,7 +735,7 @@ def _get_filter_list(self) -> List: class TaskRunFilterFlowRunId(PrefectOperatorFilterBaseModel): """Filter by `TaskRun.flow_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of task run flow run ids to include" ) @@ -699,36 +743,40 @@ class TaskRunFilterFlowRunId(PrefectOperatorFilterBaseModel): default=False, description="Filter for task runs with None as their flow run id" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.is_null_ is True: - filters.append(orm_models.TaskRun.flow_run_id.is_(None)) + filters.append(db.TaskRun.flow_run_id.is_(None)) elif self.is_null_ is False and self.any_ is None: - filters.append(orm_models.TaskRun.flow_run_id.is_not(None)) + filters.append(db.TaskRun.flow_run_id.is_not(None)) else: if self.any_ is not None: - filters.append(orm_models.TaskRun.flow_run_id.in_(self.any_)) + filters.append(db.TaskRun.flow_run_id.in_(self.any_)) return filters class TaskRunFilterId(PrefectFilterBaseModel): """Filter by `TaskRun.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of task run ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.TaskRun.id.in_(self.any_)) + filters.append(db.TaskRun.id.in_(self.any_)) return filters class TaskRunFilterName(PrefectFilterBaseModel): """Filter by `TaskRun.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of task run names to include", examples=[["my-task-run-1", "my-task-run-2"]], @@ -744,19 +792,21 @@ class TaskRunFilterName(PrefectFilterBaseModel): examples=["marvin"], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.TaskRun.name.in_(self.any_)) + filters.append(db.TaskRun.name.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.TaskRun.name.ilike(f"%{self.like_}%")) + filters.append(db.TaskRun.name.ilike(f"%{self.like_}%")) return filters class TaskRunFilterTags(PrefectOperatorFilterBaseModel): """Filter by `TaskRun.tags`.""" - all_: Optional[List[str]] = Field( + all_: Optional[list[str]] = Field( default=None, examples=[["tag-1", "tag-2"]], description=( @@ -768,15 +818,15 @@ class TaskRunFilterTags(PrefectOperatorFilterBaseModel): default=None, description="If true, only include task runs without tags" ) - def _get_filter_list(self) -> list[sa.ColumnElement[bool]]: + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: filters: list[sa.ColumnElement[bool]] = [] if self.all_ is not None: - filters.append(orm_models.TaskRun.tags.has_all(_as_array(self.all_))) + filters.append(db.TaskRun.tags.has_all(_as_array(self.all_))) if self.is_null_ is not None: filters.append( - orm_models.TaskRun.tags == [] - if self.is_null_ - else orm_models.TaskRun.tags != [] + db.TaskRun.tags == [] if self.is_null_ else db.TaskRun.tags != [] ) return filters @@ -784,28 +834,32 @@ def _get_filter_list(self) -> list[sa.ColumnElement[bool]]: class TaskRunFilterStateType(PrefectFilterBaseModel): """Filter by `TaskRun.state_type`.""" - any_: Optional[List[schemas.states.StateType]] = Field( + any_: Optional[list[schemas.states.StateType]] = Field( default=None, description="A list of task run state types to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.TaskRun.state_type.in_(self.any_)) + filters.append(db.TaskRun.state_type.in_(self.any_)) return filters class TaskRunFilterStateName(PrefectFilterBaseModel): """Filter by `TaskRun.state_name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of task run state names to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.TaskRun.state_name.in_(self.any_)) + filters.append(db.TaskRun.state_name.in_(self.any_)) return filters @@ -819,12 +873,18 @@ class TaskRunFilterState(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `TaskRun.state_name`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.type is not None: - filters.extend(self.type._get_filter_list()) + filter = self.type.as_sql_filter() + if isinstance(filter, sa.BinaryExpression): + filters.append(filter) if self.name is not None: - filters.extend(self.name._get_filter_list()) + filter = self.name.as_sql_filter() + if isinstance(filter, sa.BinaryExpression): + filters.append(filter) return filters @@ -839,12 +899,14 @@ class TaskRunFilterSubFlowRuns(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.exists_ is True: - filters.append(orm_models.TaskRun.subflow_run.has()) + filters.append(db.TaskRun.subflow_run.has()) elif self.exists_ is False: - filters.append(sa.not_(orm_models.TaskRun.subflow_run.has())) + filters.append(sa.not_(db.TaskRun.subflow_run.has())) return filters @@ -863,17 +925,19 @@ class TaskRunFilterStartTime(PrefectFilterBaseModel): default=None, description="If true, only return task runs without a start time" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.TaskRun.start_time <= self.before_) + filters.append(db.TaskRun.start_time <= self.before_) if self.after_ is not None: - filters.append(orm_models.TaskRun.start_time >= self.after_) + filters.append(db.TaskRun.start_time >= self.after_) if self.is_null_ is not None: filters.append( - orm_models.TaskRun.start_time.is_(None) + db.TaskRun.start_time.is_(None) if self.is_null_ - else orm_models.TaskRun.start_time.is_not(None) + else db.TaskRun.start_time.is_not(None) ) return filters @@ -890,12 +954,14 @@ class TaskRunFilterExpectedStartTime(PrefectFilterBaseModel): description="Only include task runs expected to start at or after this time", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.TaskRun.expected_start_time <= self.before_) + filters.append(db.TaskRun.expected_start_time <= self.before_) if self.after_ is not None: - filters.append(orm_models.TaskRun.expected_start_time >= self.after_) + filters.append(db.TaskRun.expected_start_time >= self.after_) return filters @@ -927,8 +993,10 @@ class TaskRunFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `TaskRun.flow_run_id`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) @@ -953,21 +1021,23 @@ def _get_filter_list(self) -> List: class DeploymentFilterId(PrefectFilterBaseModel): """Filter by `Deployment.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of deployment ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Deployment.id.in_(self.any_)) + filters.append(db.Deployment.id.in_(self.any_)) return filters class DeploymentFilterName(PrefectFilterBaseModel): """Filter by `Deployment.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of deployment names to include", examples=[["my-deployment-1", "my-deployment-2"]], @@ -983,12 +1053,14 @@ class DeploymentFilterName(PrefectFilterBaseModel): examples=["marvin"], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Deployment.name.in_(self.any_)) + filters.append(db.Deployment.name.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.Deployment.name.ilike(f"%{self.like_}%")) + filters.append(db.Deployment.name.ilike(f"%{self.like_}%")) return filters @@ -1003,13 +1075,15 @@ class DeploymentOrFlowNameFilter(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.like_ is not None: - deployment_name_filter = orm_models.Deployment.name.ilike(f"%{self.like_}%") + deployment_name_filter = db.Deployment.name.ilike(f"%{self.like_}%") - flow_name_filter = orm_models.Deployment.flow.has( - orm_models.Flow.name.ilike(f"%{self.like_}%") + flow_name_filter = db.Deployment.flow.has( + db.Flow.name.ilike(f"%{self.like_}%") ) filters.append(sa.or_(deployment_name_filter, flow_name_filter)) return filters @@ -1023,26 +1097,30 @@ class DeploymentFilterPaused(PrefectFilterBaseModel): description="Only returns where deployment is/is not paused", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.eq_ is not None: - filters.append(orm_models.Deployment.paused.is_(self.eq_)) + filters.append(db.Deployment.paused.is_(self.eq_)) return filters class DeploymentFilterWorkQueueName(PrefectFilterBaseModel): """Filter by `Deployment.work_queue_name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of work queue names to include", examples=[["work_queue_1", "work_queue_2"]], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Deployment.work_queue_name.in_(self.any_)) + filters.append(db.Deployment.work_queue_name.in_(self.any_)) return filters @@ -1063,7 +1141,7 @@ class DeploymentFilterConcurrencyLimit(PrefectFilterBaseModel): description="If true, only include deployments without a concurrency limit", ) - def _get_filter_list(self) -> List: + def _get_filter_list(self, db: PrefectDBInterface) -> list[sa.ColumnElement[bool]]: # This used to filter on an `int` column that was moved to a `ForeignKey` relationship # This filter is now deprecated rather than support filtering on the new relationship return [] @@ -1072,7 +1150,7 @@ def _get_filter_list(self) -> List: class DeploymentFilterTags(PrefectOperatorFilterBaseModel): """Filter by `Deployment.tags`.""" - all_: Optional[List[str]] = Field( + all_: Optional[list[str]] = Field( default=None, examples=[["tag-1", "tag-2"]], description=( @@ -1084,15 +1162,15 @@ class DeploymentFilterTags(PrefectOperatorFilterBaseModel): default=None, description="If true, only include deployments without tags" ) - def _get_filter_list(self) -> list[sa.ColumnElement[bool]]: + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: filters: list[sa.ColumnElement[bool]] = [] if self.all_ is not None: - filters.append(orm_models.Deployment.tags.has_all(_as_array(self.all_))) + filters.append(db.Deployment.tags.has_all(_as_array(self.all_))) if self.is_null_ is not None: filters.append( - orm_models.Deployment.tags == [] - if self.is_null_ - else orm_models.Deployment.tags != [] + db.Deployment.tags == [] if self.is_null_ else db.Deployment.tags != [] ) return filters @@ -1124,8 +1202,10 @@ class DeploymentFilter(PrefectOperatorFilterBaseModel): deprecated=True, ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) if self.name is not None: @@ -1150,10 +1230,12 @@ class DeploymentScheduleFilterActive(PrefectFilterBaseModel): description="Only returns where deployment schedule is/is not active", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.eq_ is not None: - filters.append(orm_models.DeploymentSchedule.active.is_(self.eq_)) + filters.append(db.DeploymentSchedule.active.is_(self.eq_)) return filters @@ -1164,8 +1246,10 @@ class DeploymentScheduleFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `DeploymentSchedule.active`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.active is not None: filters.append(self.active.as_sql_filter()) @@ -1176,16 +1260,18 @@ def _get_filter_list(self) -> List: class LogFilterName(PrefectFilterBaseModel): """Filter by `Log.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of log names to include", examples=[["prefect.logger.flow_runs", "prefect.logger.task_runs"]], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Log.name.in_(self.any_)) + filters.append(db.Log.name.in_(self.any_)) return filters @@ -1204,12 +1290,14 @@ class LogFilterLevel(PrefectFilterBaseModel): examples=[50], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.ge_ is not None: - filters.append(orm_models.Log.level >= self.ge_) + filters.append(db.Log.level >= self.ge_) if self.le_ is not None: - filters.append(orm_models.Log.level <= self.le_) + filters.append(db.Log.level <= self.le_) return filters @@ -1225,33 +1313,37 @@ class LogFilterTimestamp(PrefectFilterBaseModel): description="Only include logs with a timestamp at or after this time", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.Log.timestamp <= self.before_) + filters.append(db.Log.timestamp <= self.before_) if self.after_ is not None: - filters.append(orm_models.Log.timestamp >= self.after_) + filters.append(db.Log.timestamp >= self.after_) return filters class LogFilterFlowRunId(PrefectFilterBaseModel): """Filter by `Log.flow_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of flow run IDs to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Log.flow_run_id.in_(self.any_)) + filters.append(db.Log.flow_run_id.in_(self.any_)) return filters class LogFilterTaskRunId(PrefectFilterBaseModel): """Filter by `Log.task_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of task run IDs to include" ) @@ -1260,15 +1352,17 @@ class LogFilterTaskRunId(PrefectFilterBaseModel): description="If true, only include logs without a task run id", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Log.task_run_id.in_(self.any_)) + filters.append(db.Log.task_run_id.in_(self.any_)) if self.is_null_ is not None: filters.append( - orm_models.Log.task_run_id.is_(None) + db.Log.task_run_id.is_(None) if self.is_null_ - else orm_models.Log.task_run_id.is_not(None) + else db.Log.task_run_id.is_not(None) ) return filters @@ -1289,8 +1383,10 @@ class LogFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `Log.task_run_id`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.level is not None: filters.append(self.level.as_sql_filter()) @@ -1335,24 +1431,28 @@ class BlockTypeFilterName(PrefectFilterBaseModel): examples=["marvin"], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.like_ is not None: - filters.append(orm_models.BlockType.name.ilike(f"%{self.like_}%")) + filters.append(db.BlockType.name.ilike(f"%{self.like_}%")) return filters class BlockTypeFilterSlug(PrefectFilterBaseModel): """Filter by `BlockType.slug`""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of slugs to match" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.BlockType.slug.in_(self.any_)) + filters.append(db.BlockType.slug.in_(self.any_)) return filters @@ -1368,8 +1468,10 @@ class BlockTypeFilter(PrefectFilterBaseModel): default=None, description="Filter criteria for `BlockType.slug`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.name is not None: filters.append(self.name.as_sql_filter()) @@ -1382,35 +1484,39 @@ def _get_filter_list(self) -> List: class BlockSchemaFilterBlockTypeId(PrefectFilterBaseModel): """Filter by `BlockSchema.block_type_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of block type ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.BlockSchema.block_type_id.in_(self.any_)) + filters.append(db.BlockSchema.block_type_id.in_(self.any_)) return filters class BlockSchemaFilterId(PrefectFilterBaseModel): """Filter by BlockSchema.id""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of IDs to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.BlockSchema.id.in_(self.any_)) + filters.append(db.BlockSchema.id.in_(self.any_)) return filters class BlockSchemaFilterCapabilities(PrefectFilterBaseModel): """Filter by `BlockSchema.capabilities`""" - all_: Optional[List[str]] = Field( + all_: Optional[list[str]] = Field( default=None, examples=[["write-storage", "read-storage"]], description=( @@ -1419,30 +1525,30 @@ class BlockSchemaFilterCapabilities(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> list[sa.ColumnElement[bool]]: + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: filters: list[sa.ColumnElement[bool]] = [] if self.all_ is not None: - filters.append( - orm_models.BlockSchema.capabilities.has_all(_as_array(self.all_)) - ) + filters.append(db.BlockSchema.capabilities.has_all(_as_array(self.all_))) return filters class BlockSchemaFilterVersion(PrefectFilterBaseModel): """Filter by `BlockSchema.capabilities`""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, examples=[["2.0.0", "2.1.0"]], description="A list of block schema versions.", ) - def _get_filter_list(self) -> List: - pass - - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnElement[bool]] = [] if self.any_ is not None: - filters.append(orm_models.BlockSchema.version.in_(self.any_)) + filters.append(db.BlockSchema.version.in_(self.any_)) return filters @@ -1462,8 +1568,10 @@ class BlockSchemaFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `BlockSchema.version`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.block_type_id is not None: filters.append(self.block_type_id.as_sql_filter()) @@ -1487,45 +1595,51 @@ class BlockDocumentFilterIsAnonymous(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.eq_ is not None: - filters.append(orm_models.BlockDocument.is_anonymous.is_(self.eq_)) + filters.append(db.BlockDocument.is_anonymous.is_(self.eq_)) return filters class BlockDocumentFilterBlockTypeId(PrefectFilterBaseModel): """Filter by `BlockDocument.block_type_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of block type ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.BlockDocument.block_type_id.in_(self.any_)) + filters.append(db.BlockDocument.block_type_id.in_(self.any_)) return filters class BlockDocumentFilterId(PrefectFilterBaseModel): """Filter by `BlockDocument.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of block ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.BlockDocument.id.in_(self.any_)) + filters.append(db.BlockDocument.id.in_(self.any_)) return filters class BlockDocumentFilterName(PrefectFilterBaseModel): """Filter by `BlockDocument.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of block names to include" ) like_: Optional[str] = Field( @@ -1537,12 +1651,14 @@ class BlockDocumentFilterName(PrefectFilterBaseModel): examples=["my-block%"], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.BlockDocument.name.in_(self.any_)) + filters.append(db.BlockDocument.name.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.BlockDocument.name.ilike(f"%{self.like_}%")) + filters.append(db.BlockDocument.name.ilike(f"%{self.like_}%")) return filters @@ -1567,8 +1683,10 @@ class BlockDocumentFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `BlockDocument.name`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) if self.is_anonymous is not None: @@ -1590,10 +1708,12 @@ class FlowRunNotificationPolicyFilterIsActive(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.eq_ is not None: - filters.append(orm_models.FlowRunNotificationPolicy.is_active.is_(self.eq_)) + filters.append(db.FlowRunNotificationPolicy.is_active.is_(self.eq_)) return filters @@ -1605,8 +1725,10 @@ class FlowRunNotificationPolicyFilter(PrefectFilterBaseModel): description="Filter criteria for `FlowRunNotificationPolicy.is_active`. ", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.is_active is not None: filters.append(self.is_active.as_sql_filter()) @@ -1616,28 +1738,30 @@ def _get_filter_list(self) -> List: class WorkQueueFilterId(PrefectFilterBaseModel): """Filter by `WorkQueue.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of work queue ids to include", ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.WorkQueue.id.in_(self.any_)) + filters.append(db.WorkQueue.id.in_(self.any_)) return filters class WorkQueueFilterName(PrefectFilterBaseModel): """Filter by `WorkQueue.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of work queue names to include", examples=[["wq-1", "wq-2"]], ) - startswith_: Optional[List[str]] = Field( + startswith_: Optional[list[str]] = Field( default=None, description=( "A list of case-insensitive starts-with matches. For example, " @@ -1647,17 +1771,16 @@ class WorkQueueFilterName(PrefectFilterBaseModel): examples=[["marvin", "Marvin-robot"]], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.WorkQueue.name.in_(self.any_)) + filters.append(db.WorkQueue.name.in_(self.any_)) if self.startswith_ is not None: filters.append( sa.or_( - *[ - orm_models.WorkQueue.name.ilike(f"{item}%") - for item in self.startswith_ - ] + *[db.WorkQueue.name.ilike(f"{item}%") for item in self.startswith_] ) ) return filters @@ -1675,8 +1798,10 @@ class WorkQueueFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `WorkQueue.name`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) @@ -1689,42 +1814,48 @@ def _get_filter_list(self) -> List: class WorkPoolFilterId(PrefectFilterBaseModel): """Filter by `WorkPool.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of work pool ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.WorkPool.id.in_(self.any_)) + filters.append(db.WorkPool.id.in_(self.any_)) return filters class WorkPoolFilterName(PrefectFilterBaseModel): """Filter by `WorkPool.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of work pool names to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.WorkPool.name.in_(self.any_)) + filters.append(db.WorkPool.name.in_(self.any_)) return filters class WorkPoolFilterType(PrefectFilterBaseModel): """Filter by `WorkPool.type`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of work pool types to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.WorkPool.type.in_(self.any_)) + filters.append(db.WorkPool.type.in_(self.any_)) return filters @@ -1741,8 +1872,10 @@ class WorkPoolFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `WorkPool.type`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) @@ -1757,33 +1890,37 @@ def _get_filter_list(self) -> List: class WorkerFilterWorkPoolId(PrefectFilterBaseModel): """Filter by `Worker.worker_config_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of work pool ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Worker.worker_config_id.in_(self.any_)) + filters.append(db.Worker.work_pool_id.in_(self.any_)) return filters class WorkerFilterStatus(PrefectFilterBaseModel): """Filter by `Worker.status`.""" - any_: Optional[List[schemas.statuses.WorkerStatus]] = Field( + any_: Optional[list[schemas.statuses.WorkerStatus]] = Field( default=None, description="A list of worker statuses to include" ) - not_any_: Optional[List[schemas.statuses.WorkerStatus]] = Field( + not_any_: Optional[list[schemas.statuses.WorkerStatus]] = Field( default=None, description="A list of worker statuses to exclude" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Worker.status.in_(self.any_)) + filters.append(db.Worker.status.in_(self.any_)) if self.not_any_ is not None: - filters.append(orm_models.Worker.status.notin_(self.not_any_)) + filters.append(db.Worker.status.notin_(self.not_any_)) return filters @@ -1803,12 +1940,14 @@ class WorkerFilterLastHeartbeatTime(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.before_ is not None: - filters.append(orm_models.Worker.last_heartbeat_time <= self.before_) + filters.append(db.Worker.last_heartbeat_time <= self.before_) if self.after_ is not None: - filters.append(orm_models.Worker.last_heartbeat_time >= self.after_) + filters.append(db.Worker.last_heartbeat_time >= self.after_) return filters @@ -1828,8 +1967,10 @@ class WorkerFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `Worker.status`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.last_heartbeat_time is not None: filters.append(self.last_heartbeat_time.as_sql_filter()) @@ -1843,21 +1984,23 @@ def _get_filter_list(self) -> List: class ArtifactFilterId(PrefectFilterBaseModel): """Filter by `Artifact.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of artifact ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Artifact.id.in_(self.any_)) + filters.append(db.Artifact.id.in_(self.any_)) return filters class ArtifactFilterKey(PrefectFilterBaseModel): """Filter by `Artifact.key`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of artifact keys to include" ) @@ -1878,17 +2021,19 @@ class ArtifactFilterKey(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Artifact.key.in_(self.any_)) + filters.append(db.Artifact.key.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.Artifact.key.ilike(f"%{self.like_}%")) + filters.append(db.Artifact.key.ilike(f"%{self.like_}%")) if self.exists_ is not None: filters.append( - orm_models.Artifact.key.isnot(None) + db.Artifact.key.isnot(None) if self.exists_ - else orm_models.Artifact.key.is_(None) + else db.Artifact.key.is_(None) ) return filters @@ -1896,47 +2041,53 @@ def _get_filter_list(self) -> List: class ArtifactFilterFlowRunId(PrefectFilterBaseModel): """Filter by `Artifact.flow_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of flow run IDs to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Artifact.flow_run_id.in_(self.any_)) + filters.append(db.Artifact.flow_run_id.in_(self.any_)) return filters class ArtifactFilterTaskRunId(PrefectFilterBaseModel): """Filter by `Artifact.task_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of task run IDs to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Artifact.task_run_id.in_(self.any_)) + filters.append(db.Artifact.task_run_id.in_(self.any_)) return filters class ArtifactFilterType(PrefectFilterBaseModel): """Filter by `Artifact.type`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of artifact types to include" ) - not_any_: Optional[List[str]] = Field( + not_any_: Optional[list[str]] = Field( default=None, description="A list of artifact types to exclude" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Artifact.type.in_(self.any_)) + filters.append(db.Artifact.type.in_(self.any_)) if self.not_any_ is not None: - filters.append(orm_models.Artifact.type.notin_(self.not_any_)) + filters.append(db.Artifact.type.notin_(self.not_any_)) return filters @@ -1959,8 +2110,10 @@ class ArtifactFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `Artifact.type`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) @@ -1979,21 +2132,23 @@ def _get_filter_list(self) -> List: class ArtifactCollectionFilterLatestId(PrefectFilterBaseModel): """Filter by `ArtifactCollection.latest_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of artifact ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.ArtifactCollection.latest_id.in_(self.any_)) + filters.append(db.ArtifactCollection.latest_id.in_(self.any_)) return filters class ArtifactCollectionFilterKey(PrefectFilterBaseModel): """Filter by `ArtifactCollection.key`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of artifact keys to include" ) @@ -2015,17 +2170,19 @@ class ArtifactCollectionFilterKey(PrefectFilterBaseModel): ), ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.ArtifactCollection.key.in_(self.any_)) + filters.append(db.ArtifactCollection.key.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.ArtifactCollection.key.ilike(f"%{self.like_}%")) + filters.append(db.ArtifactCollection.key.ilike(f"%{self.like_}%")) if self.exists_ is not None: filters.append( - orm_models.ArtifactCollection.key.isnot(None) + db.ArtifactCollection.key.isnot(None) if self.exists_ - else orm_models.ArtifactCollection.key.is_(None) + else db.ArtifactCollection.key.is_(None) ) return filters @@ -2033,47 +2190,53 @@ def _get_filter_list(self) -> List: class ArtifactCollectionFilterFlowRunId(PrefectFilterBaseModel): """Filter by `ArtifactCollection.flow_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of flow run IDs to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.ArtifactCollection.flow_run_id.in_(self.any_)) + filters.append(db.ArtifactCollection.flow_run_id.in_(self.any_)) return filters class ArtifactCollectionFilterTaskRunId(PrefectFilterBaseModel): """Filter by `ArtifactCollection.task_run_id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of task run IDs to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.ArtifactCollection.task_run_id.in_(self.any_)) + filters.append(db.ArtifactCollection.task_run_id.in_(self.any_)) return filters class ArtifactCollectionFilterType(PrefectFilterBaseModel): """Filter by `ArtifactCollection.type`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of artifact types to include" ) - not_any_: Optional[List[str]] = Field( + not_any_: Optional[list[str]] = Field( default=None, description="A list of artifact types to exclude" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.ArtifactCollection.type.in_(self.any_)) + filters.append(db.ArtifactCollection.type.in_(self.any_)) if self.not_any_ is not None: - filters.append(orm_models.ArtifactCollection.type.notin_(self.not_any_)) + filters.append(db.ArtifactCollection.type.notin_(self.not_any_)) return filters @@ -2096,8 +2259,10 @@ class ArtifactCollectionFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `Artifact.type`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.latest_id is not None: filters.append(self.latest_id.as_sql_filter()) @@ -2116,21 +2281,23 @@ def _get_filter_list(self) -> List: class VariableFilterId(PrefectFilterBaseModel): """Filter by `Variable.id`.""" - any_: Optional[List[UUID]] = Field( + any_: Optional[list[UUID]] = Field( default=None, description="A list of variable ids to include" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Variable.id.in_(self.any_)) + filters.append(db.Variable.id.in_(self.any_)) return filters class VariableFilterName(PrefectFilterBaseModel): """Filter by `Variable.name`.""" - any_: Optional[List[str]] = Field( + any_: Optional[list[str]] = Field( default=None, description="A list of variables names to include" ) like_: Optional[str] = Field( @@ -2142,19 +2309,21 @@ class VariableFilterName(PrefectFilterBaseModel): examples=["my_variable_%"], ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.any_ is not None: - filters.append(orm_models.Variable.name.in_(self.any_)) + filters.append(db.Variable.name.in_(self.any_)) if self.like_ is not None: - filters.append(orm_models.Variable.name.ilike(f"%{self.like_}%")) + filters.append(db.Variable.name.ilike(f"%{self.like_}%")) return filters class VariableFilterTags(PrefectOperatorFilterBaseModel): """Filter by `Variable.tags`.""" - all_: Optional[List[str]] = Field( + all_: Optional[list[str]] = Field( default=None, examples=[["tag-1", "tag-2"]], description=( @@ -2166,15 +2335,15 @@ class VariableFilterTags(PrefectOperatorFilterBaseModel): default=None, description="If true, only include Variables without tags" ) - def _get_filter_list(self) -> list[sa.ColumnElement[bool]]: + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: filters: list[sa.ColumnElement[bool]] = [] if self.all_ is not None: - filters.append(orm_models.Variable.tags.has_all(_as_array(self.all_))) + filters.append(db.Variable.tags.has_all(_as_array(self.all_))) if self.is_null_ is not None: filters.append( - orm_models.Variable.tags == [] - if self.is_null_ - else orm_models.Variable.tags != [] + db.Variable.tags == [] if self.is_null_ else db.Variable.tags != [] ) return filters @@ -2192,8 +2361,10 @@ class VariableFilter(PrefectOperatorFilterBaseModel): default=None, description="Filter criteria for `Variable.tags`" ) - def _get_filter_list(self) -> List: - filters = [] + def _get_filter_list( + self, db: PrefectDBInterface + ) -> Iterable[sa.ColumnExpressionArgument[bool]]: + filters: list[sa.ColumnExpressionArgument[bool]] = [] if self.id is not None: filters.append(self.id.as_sql_filter()) diff --git a/src/prefect/server/schemas/graph.py b/src/prefect/server/schemas/graph.py index 0d72ceff8a6a6..b49007d6d5bb6 100644 --- a/src/prefect/server/schemas/graph.py +++ b/src/prefect/server/schemas/graph.py @@ -1,23 +1,29 @@ -from datetime import datetime -from typing import Any, List, Literal, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple from uuid import UUID +import pendulum + from prefect.server.schemas.states import StateType from prefect.server.utilities.schemas import PrefectBaseModel +if TYPE_CHECKING: + DateTime = pendulum.DateTime +else: + from prefect.types import DateTime + class GraphState(PrefectBaseModel): id: UUID - timestamp: datetime + timestamp: DateTime type: StateType name: str class GraphArtifact(PrefectBaseModel): id: UUID - created: datetime + created: DateTime key: Optional[str] - type: str + type: Optional[str] is_latest: bool data: Optional[Any] # we only return data for progress artifacts for now @@ -31,8 +37,8 @@ class Node(PrefectBaseModel): id: UUID label: str state_type: StateType - start_time: datetime - end_time: Optional[datetime] + start_time: DateTime + end_time: Optional[DateTime] parents: List[Edge] children: List[Edge] encapsulating: List[Edge] @@ -40,8 +46,8 @@ class Node(PrefectBaseModel): class Graph(PrefectBaseModel): - start_time: datetime - end_time: Optional[datetime] + start_time: Optional[DateTime] + end_time: Optional[DateTime] root_node_ids: List[UUID] nodes: List[Tuple[UUID, Node]] artifacts: List[GraphArtifact] diff --git a/src/prefect/server/schemas/sorting.py b/src/prefect/server/schemas/sorting.py index b7fc7f4a552f7..d2249f61fa821 100644 --- a/src/prefect/server/schemas/sorting.py +++ b/src/prefect/server/schemas/sorting.py @@ -2,16 +2,14 @@ Schemas for sorting Prefect REST API objects. """ -from typing import TYPE_CHECKING +from collections.abc import Iterable +from typing import Any import sqlalchemy as sa -from prefect.server.database import orm_models +from prefect.server.database import PrefectDBInterface, db_injector from prefect.utilities.collections import AutoEnum -if TYPE_CHECKING: - from sqlalchemy.sql.expression import ColumnElement - # TODO: Consider moving the `as_sql_sort` functions out of here since they are a # database model level function and do not properly separate concerns when # present in the schemas module @@ -30,24 +28,29 @@ class FlowRunSort(AutoEnum): NEXT_SCHEDULED_START_TIME_ASC = AutoEnum.auto() END_TIME_DESC = AutoEnum.auto() - def as_sql_sort(self) -> "ColumnElement": - from sqlalchemy.sql.functions import coalesce - + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: """Return an expression used to sort flow runs""" - sort_mapping = { - "ID_DESC": orm_models.FlowRun.id.desc(), - "START_TIME_ASC": coalesce( - orm_models.FlowRun.start_time, orm_models.FlowRun.expected_start_time - ).asc(), - "START_TIME_DESC": coalesce( - orm_models.FlowRun.start_time, orm_models.FlowRun.expected_start_time - ).desc(), - "EXPECTED_START_TIME_ASC": orm_models.FlowRun.expected_start_time.asc(), - "EXPECTED_START_TIME_DESC": orm_models.FlowRun.expected_start_time.desc(), - "NAME_ASC": orm_models.FlowRun.name.asc(), - "NAME_DESC": orm_models.FlowRun.name.desc(), - "NEXT_SCHEDULED_START_TIME_ASC": orm_models.FlowRun.next_scheduled_start_time.asc(), - "END_TIME_DESC": orm_models.FlowRun.end_time.desc(), + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "ID_DESC": [db.FlowRun.id.desc()], + "START_TIME_ASC": [ + sa.func.coalesce( + db.FlowRun.start_time, db.FlowRun.expected_start_time + ).asc() + ], + "START_TIME_DESC": [ + sa.func.coalesce( + db.FlowRun.start_time, db.FlowRun.expected_start_time + ).desc() + ], + "EXPECTED_START_TIME_ASC": [db.FlowRun.expected_start_time.asc()], + "EXPECTED_START_TIME_DESC": [db.FlowRun.expected_start_time.desc()], + "NAME_ASC": [db.FlowRun.name.asc()], + "NAME_DESC": [db.FlowRun.name.desc()], + "NEXT_SCHEDULED_START_TIME_ASC": [ + db.FlowRun.next_scheduled_start_time.asc() + ], + "END_TIME_DESC": [db.FlowRun.end_time.desc()], } return sort_mapping[self.value] @@ -63,16 +66,19 @@ class TaskRunSort(AutoEnum): NEXT_SCHEDULED_START_TIME_ASC = AutoEnum.auto() END_TIME_DESC = AutoEnum.auto() - def as_sql_sort(self) -> "ColumnElement": + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: """Return an expression used to sort task runs""" - sort_mapping = { - "ID_DESC": orm_models.TaskRun.id.desc(), - "EXPECTED_START_TIME_ASC": orm_models.TaskRun.expected_start_time.asc(), - "EXPECTED_START_TIME_DESC": orm_models.TaskRun.expected_start_time.desc(), - "NAME_ASC": orm_models.TaskRun.name.asc(), - "NAME_DESC": orm_models.TaskRun.name.desc(), - "NEXT_SCHEDULED_START_TIME_ASC": orm_models.TaskRun.next_scheduled_start_time.asc(), - "END_TIME_DESC": orm_models.TaskRun.end_time.desc(), + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "ID_DESC": [db.TaskRun.id.desc()], + "EXPECTED_START_TIME_ASC": [db.TaskRun.expected_start_time.asc()], + "EXPECTED_START_TIME_DESC": [db.TaskRun.expected_start_time.desc()], + "NAME_ASC": [db.TaskRun.name.asc()], + "NAME_DESC": [db.TaskRun.name.desc()], + "NEXT_SCHEDULED_START_TIME_ASC": [ + db.TaskRun.next_scheduled_start_time.asc() + ], + "END_TIME_DESC": [db.TaskRun.end_time.desc()], } return sort_mapping[self.value] @@ -83,11 +89,12 @@ class LogSort(AutoEnum): TIMESTAMP_ASC = AutoEnum.auto() TIMESTAMP_DESC = AutoEnum.auto() - def as_sql_sort(self) -> "ColumnElement": + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: """Return an expression used to sort task runs""" - sort_mapping = { - "TIMESTAMP_ASC": orm_models.Log.timestamp.asc(), - "TIMESTAMP_DESC": orm_models.Log.timestamp.desc(), + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "TIMESTAMP_ASC": [db.Log.timestamp.asc()], + "TIMESTAMP_DESC": [db.Log.timestamp.desc()], } return sort_mapping[self.value] @@ -100,13 +107,14 @@ class FlowSort(AutoEnum): NAME_ASC = AutoEnum.auto() NAME_DESC = AutoEnum.auto() - def as_sql_sort(self) -> "ColumnElement": - """Return an expression used to sort flows""" - sort_mapping = { - "CREATED_DESC": orm_models.Flow.created.desc(), - "UPDATED_DESC": orm_models.Flow.updated.desc(), - "NAME_ASC": orm_models.Flow.name.asc(), - "NAME_DESC": orm_models.Flow.name.desc(), + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: + """Return an expression used to sort task runs""" + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "CREATED_DESC": [db.Flow.created.desc()], + "UPDATED_DESC": [db.Flow.updated.desc()], + "NAME_ASC": [db.Flow.name.asc()], + "NAME_DESC": [db.Flow.name.desc()], } return sort_mapping[self.value] @@ -119,13 +127,14 @@ class DeploymentSort(AutoEnum): NAME_ASC = AutoEnum.auto() NAME_DESC = AutoEnum.auto() - def as_sql_sort(self) -> "ColumnElement": - """Return an expression used to sort deployments""" - sort_mapping = { - "CREATED_DESC": orm_models.Deployment.created.desc(), - "UPDATED_DESC": orm_models.Deployment.updated.desc(), - "NAME_ASC": orm_models.Deployment.name.asc(), - "NAME_DESC": orm_models.Deployment.name.desc(), + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: + """Return an expression used to sort task runs""" + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "CREATED_DESC": [db.Deployment.created.desc()], + "UPDATED_DESC": [db.Deployment.updated.desc()], + "NAME_ASC": [db.Deployment.name.asc()], + "NAME_DESC": [db.Deployment.name.desc()], } return sort_mapping[self.value] @@ -139,14 +148,15 @@ class ArtifactSort(AutoEnum): KEY_DESC = AutoEnum.auto() KEY_ASC = AutoEnum.auto() - def as_sql_sort(self) -> "ColumnElement": - """Return an expression used to sort artifacts""" - sort_mapping = { - "CREATED_DESC": orm_models.Artifact.created.desc(), - "UPDATED_DESC": orm_models.Artifact.updated.desc(), - "ID_DESC": orm_models.Artifact.id.desc(), - "KEY_DESC": orm_models.Artifact.key.desc(), - "KEY_ASC": orm_models.Artifact.key.asc(), + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: + """Return an expression used to sort task runs""" + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "CREATED_DESC": [db.Artifact.created.desc()], + "UPDATED_DESC": [db.Artifact.updated.desc()], + "ID_DESC": [db.Artifact.id.desc()], + "KEY_DESC": [db.Artifact.key.desc()], + "KEY_ASC": [db.Artifact.key.asc()], } return sort_mapping[self.value] @@ -160,14 +170,15 @@ class ArtifactCollectionSort(AutoEnum): KEY_DESC = AutoEnum.auto() KEY_ASC = AutoEnum.auto() - def as_sql_sort(self) -> "ColumnElement": - """Return an expression used to sort artifact collections""" - sort_mapping = { - "CREATED_DESC": orm_models.ArtifactCollection.created.desc(), - "UPDATED_DESC": orm_models.ArtifactCollection.updated.desc(), - "ID_DESC": orm_models.ArtifactCollection.id.desc(), - "KEY_DESC": orm_models.ArtifactCollection.key.desc(), - "KEY_ASC": orm_models.ArtifactCollection.key.asc(), + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: + """Return an expression used to sort task runs""" + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "CREATED_DESC": [db.ArtifactCollection.created.desc()], + "UPDATED_DESC": [db.ArtifactCollection.updated.desc()], + "ID_DESC": [db.ArtifactCollection.id.desc()], + "KEY_DESC": [db.ArtifactCollection.key.desc()], + "KEY_ASC": [db.ArtifactCollection.key.asc()], } return sort_mapping[self.value] @@ -180,13 +191,14 @@ class VariableSort(AutoEnum): NAME_DESC = "NAME_DESC" NAME_ASC = "NAME_ASC" - def as_sql_sort(self) -> "ColumnElement": - """Return an expression used to sort variables""" - sort_mapping = { - "CREATED_DESC": orm_models.Variable.created.desc(), - "UPDATED_DESC": orm_models.Variable.updated.desc(), - "NAME_DESC": orm_models.Variable.name.desc(), - "NAME_ASC": orm_models.Variable.name.asc(), + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: + """Return an expression used to sort task runs""" + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "CREATED_DESC": [db.Variable.created.desc()], + "UPDATED_DESC": [db.Variable.updated.desc()], + "NAME_DESC": [db.Variable.name.desc()], + "NAME_ASC": [db.Variable.name.asc()], } return sort_mapping[self.value] @@ -198,11 +210,15 @@ class BlockDocumentSort(AutoEnum): NAME_ASC = "NAME_ASC" BLOCK_TYPE_AND_NAME_ASC = "BLOCK_TYPE_AND_NAME_ASC" - def as_sql_sort(self) -> "ColumnElement": - """Return an expression used to sort block documents""" - sort_mapping = { - "NAME_DESC": orm_models.BlockDocument.name.desc(), - "NAME_ASC": orm_models.BlockDocument.name.asc(), - "BLOCK_TYPE_AND_NAME_ASC": sa.text("block_type_name asc, name asc"), + @db_injector + def as_sql_sort(self, db: PrefectDBInterface) -> Iterable[sa.ColumnElement[Any]]: + """Return an expression used to sort task runs""" + sort_mapping: dict[str, Iterable[sa.ColumnElement[Any]]] = { + "NAME_DESC": [db.BlockDocument.name.desc()], + "NAME_ASC": [db.BlockDocument.name.asc()], + "BLOCK_TYPE_AND_NAME_ASC": [ + db.BlockDocument.block_type_name.asc(), + db.BlockDocument.name.asc(), + ], } return sort_mapping[self.value] diff --git a/src/prefect/server/schemas/states.py b/src/prefect/server/schemas/states.py index da41805b19566..56b8538a35ec7 100644 --- a/src/prefect/server/schemas/states.py +++ b/src/prefect/server/schemas/states.py @@ -2,27 +2,38 @@ State schemas. """ -import datetime import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Optional, + TypeVar, + Union, + overload, +) from uuid import UUID, uuid4 import pendulum -from pydantic import ConfigDict, Field, field_validator, model_validator +from pydantic import ConfigDict, Field, ValidationInfo, field_validator, model_validator from typing_extensions import Self -from prefect.server.utilities.schemas.bases import ( - IDBaseModel, - PrefectBaseModel, -) -from prefect.types import DateTime +from prefect.client.schemas import objects +from prefect.server.utilities.schemas.bases import IDBaseModel, PrefectBaseModel from prefect.utilities.collections import AutoEnum if TYPE_CHECKING: from prefect.server.database.orm_models import ORMFlowRunState, ORMTaskRunState from prefect.server.schemas.actions import StateCreate + DateTime = pendulum.DateTime +else: + from prefect.types import DateTime + + R = TypeVar("R") +_State = TypeVar("_State", bound="State") class StateType(AutoEnum): @@ -50,13 +61,11 @@ class CountByState(PrefectBaseModel): CANCELLING: int = Field(default=0) SCHEDULED: int = Field(default=0) - # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. @field_validator("*") @classmethod - def check_key(cls, value, info): - if info.name not in StateType.__members__: - raise ValueError(f"{info.name} is not a valid StateType") + def check_key(cls, value: Optional[Any], info: ValidationInfo): + if info.field_name not in StateType.__members__: + raise ValueError(f"{info.field_name} is not a valid StateType") return value @@ -81,7 +90,7 @@ class StateDetails(PrefectBaseModel): pause_timeout: Optional[DateTime] = None pause_reschedule: bool = False pause_key: Optional[str] = None - run_input_keyset: Optional[Dict[str, str]] = None + run_input_keyset: Optional[dict[str, str]] = None refresh_cache: Optional[bool] = None retriable: Optional[bool] = None transition_id: Optional[UUID] = None @@ -89,7 +98,7 @@ class StateDetails(PrefectBaseModel): class StateBaseModel(IDBaseModel): - def orm_dict(self, *args, **kwargs) -> dict: + def orm_dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """ This method is used as a convenience method for constructing fixtues by first building a `State` schema object and converting it into an ORM-compatible @@ -107,7 +116,7 @@ def orm_dict(self, *args, **kwargs) -> dict: class State(StateBaseModel): """Represents the state of a run.""" - model_config = ConfigDict(from_attributes=True) + model_config: ClassVar[ConfigDict] = ConfigDict(from_attributes=True) type: StateType name: Optional[str] = Field(default=None) @@ -127,7 +136,7 @@ def from_orm_without_result( cls, orm_state: Union["ORMFlowRunState", "ORMTaskRunState"], with_data: Optional[Any] = None, - ): + ) -> Self: """ During orchestration, ORM states can be instantiated prior to inserting results into the artifact table and the `data` field will not be eagerly loaded. In @@ -140,7 +149,7 @@ def from_orm_without_result( """ field_keys = cls.model_json_schema()["properties"].keys() - state_data = { + state_data: dict[str, Any] = { field: getattr(orm_state, field, None) for field in field_keys if field != "data" @@ -149,7 +158,7 @@ def from_orm_without_result( return cls(**state_data) @model_validator(mode="after") - def default_name_from_type(self): + def default_name_from_type(self) -> Self: """If a name is not provided, use the type""" # if `type` is not in `values` it means the `type` didn't pass its own # validation check and an error will be raised after this function is called @@ -159,7 +168,7 @@ def default_name_from_type(self): return self @model_validator(mode="after") - def default_scheduled_start_time(self): + def default_scheduled_start_time(self) -> Self: from prefect.server.schemas.states import StateType if self.type == StateType.SCHEDULED: @@ -198,7 +207,7 @@ def is_final(self) -> bool: def is_paused(self) -> bool: return self.type == StateType.PAUSED - def fresh_copy(self, **kwargs) -> Self: + def fresh_copy(self, **kwargs: Any) -> Self: """ Return a fresh copy of the state with a new ID. """ @@ -213,7 +222,25 @@ def fresh_copy(self, **kwargs) -> Self: **kwargs, ) - def result(self, raise_on_failure: bool = True, fetch: Optional[bool] = None): + @overload + def result(self, raise_on_failure: Literal[True] = ..., fetch: bool = ...) -> Any: + ... + + @overload + def result( + self, raise_on_failure: Literal[False] = False, fetch: bool = ... + ) -> Union[Any, Exception]: + ... + + @overload + def result( + self, raise_on_failure: bool = ..., fetch: bool = ... + ) -> Union[Any, Exception]: + ... + + def result( + self, raise_on_failure: bool = True, fetch: bool = True + ) -> Union[Any, Exception]: # Backwards compatible `result` handling on the server-side schema from prefect.states import State @@ -227,7 +254,7 @@ def result(self, raise_on_failure: bool = True, fetch: Optional[bool] = None): stacklevel=2, ) - state = State.model_validate(self) + state: State[Any] = objects.State.model_validate(self) return state.result(raise_on_failure=raise_on_failure, fetch=fetch) def to_state_create(self) -> "StateCreate": @@ -265,12 +292,12 @@ def __str__(self) -> str: `MyCompletedState("my message", type=COMPLETED)` """ - display = [] + display: list[str] = [] if self.message: display.append(repr(self.message)) - if self.type.value.lower() != self.name.lower(): + if self.type.value.lower() != (self.name or "").lower(): display.append(f"type={self.type.value}") return f"{self.name}({', '.join(display)})" @@ -287,10 +314,10 @@ def __hash__(self) -> int: def Scheduled( - scheduled_time: Optional[datetime.datetime] = None, - cls: Type[State] = State, - **kwargs, -) -> State: + scheduled_time: Optional[pendulum.DateTime] = None, + cls: type[_State] = State, + **kwargs: Any, +) -> _State: """Convenience function for creating `Scheduled` states. Returns: @@ -308,7 +335,7 @@ def Scheduled( return cls(type=StateType.SCHEDULED, state_details=state_details, **kwargs) -def Completed(cls: Type[State] = State, **kwargs) -> State: +def Completed(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Completed` states. Returns: @@ -317,7 +344,7 @@ def Completed(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.COMPLETED, **kwargs) -def Running(cls: Type[State] = State, **kwargs) -> State: +def Running(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Running` states. Returns: @@ -326,7 +353,7 @@ def Running(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.RUNNING, **kwargs) -def Failed(cls: Type[State] = State, **kwargs) -> State: +def Failed(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Failed` states. Returns: @@ -335,7 +362,7 @@ def Failed(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.FAILED, **kwargs) -def Crashed(cls: Type[State] = State, **kwargs) -> State: +def Crashed(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Crashed` states. Returns: @@ -344,7 +371,7 @@ def Crashed(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.CRASHED, **kwargs) -def Cancelling(cls: Type[State] = State, **kwargs) -> State: +def Cancelling(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Cancelling` states. Returns: @@ -353,7 +380,7 @@ def Cancelling(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.CANCELLING, **kwargs) -def Cancelled(cls: Type[State] = State, **kwargs) -> State: +def Cancelled(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Cancelled` states. Returns: @@ -362,7 +389,7 @@ def Cancelled(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.CANCELLED, **kwargs) -def Pending(cls: Type[State] = State, **kwargs) -> State: +def Pending(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Pending` states. Returns: @@ -372,13 +399,13 @@ def Pending(cls: Type[State] = State, **kwargs) -> State: def Paused( - cls: Type[State] = State, + cls: type[_State] = State, timeout_seconds: Optional[int] = None, - pause_expiration_time: Optional[datetime.datetime] = None, - reschedule: Optional[bool] = False, + pause_expiration_time: Optional[pendulum.DateTime] = None, + reschedule: bool = False, pause_key: Optional[str] = None, - **kwargs, -) -> State: + **kwargs: Any, +) -> _State: """Convenience function for creating `Paused` states. Returns: @@ -394,11 +421,11 @@ def Paused( "Cannot supply both a pause_expiration_time and timeout_seconds" ) - if pause_expiration_time is None and timeout_seconds is None: - pass - else: - state_details.pause_timeout = pause_expiration_time or ( - pendulum.now("UTC") + pendulum.Duration(seconds=timeout_seconds) + if pause_expiration_time: + state_details.pause_timeout = pause_expiration_time + elif timeout_seconds is not None: + state_details.pause_timeout = pendulum.now("UTC") + pendulum.Duration( + seconds=timeout_seconds ) state_details.pause_reschedule = reschedule @@ -408,12 +435,12 @@ def Paused( def Suspended( - cls: Type[State] = State, + cls: type[_State] = State, timeout_seconds: Optional[int] = None, - pause_expiration_time: Optional[datetime.datetime] = None, + pause_expiration_time: Optional[pendulum.DateTime] = None, pause_key: Optional[str] = None, - **kwargs, -): + **kwargs: Any, +) -> _State: """Convenience function for creating `Suspended` states. Returns: @@ -431,8 +458,10 @@ def Suspended( def AwaitingRetry( - scheduled_time: datetime.datetime = None, cls: Type[State] = State, **kwargs -) -> State: + cls: type[_State] = State, + scheduled_time: Optional[pendulum.DateTime] = None, + **kwargs: Any, +) -> _State: """Convenience function for creating `AwaitingRetry` states. Returns: @@ -443,7 +472,7 @@ def AwaitingRetry( ) -def Retrying(cls: Type[State] = State, **kwargs) -> State: +def Retrying(cls: type[_State] = State, **kwargs: Any) -> _State: """Convenience function for creating `Retrying` states. Returns: @@ -453,8 +482,10 @@ def Retrying(cls: Type[State] = State, **kwargs) -> State: def Late( - scheduled_time: datetime.datetime = None, cls: Type[State] = State, **kwargs -) -> State: + cls: type[_State] = State, + scheduled_time: Optional[pendulum.DateTime] = None, + **kwargs: Any, +) -> _State: """Convenience function for creating `Late` states. Returns: diff --git a/src/prefect/server/services/cancellation_cleanup.py b/src/prefect/server/services/cancellation_cleanup.py index d2ae5bdcbb7de..221201d57a8bb 100644 --- a/src/prefect/server/services/cancellation_cleanup.py +++ b/src/prefect/server/services/cancellation_cleanup.py @@ -11,9 +11,7 @@ from sqlalchemy.sql.expression import or_ import prefect.server.models as models -from prefect.server.database import orm_models -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db, orm_models from prefect.server.schemas import filters, states from prefect.server.services.loop_service import LoopService from prefect.settings import PREFECT_API_SERVICES_CANCELLATION_CLEANUP_LOOP_SECONDS @@ -51,15 +49,14 @@ async def run_once(self, db: PrefectDBInterface): self.logger.info("Finished cleaning up cancelled flow runs.") - async def clean_up_cancelled_flow_run_task_runs(self, db): + async def clean_up_cancelled_flow_run_task_runs(self, db: PrefectDBInterface): while True: cancelled_flow_query = ( - sa.select(orm_models.FlowRun) + sa.select(db.FlowRun) .where( - orm_models.FlowRun.state_type == states.StateType.CANCELLED, - orm_models.FlowRun.end_time.is_not(None), - orm_models.FlowRun.end_time - >= (pendulum.now("UTC").subtract(days=1)), + db.FlowRun.state_type == states.StateType.CANCELLED, + db.FlowRun.end_time.is_not(None), + db.FlowRun.end_time >= (pendulum.now("UTC").subtract(days=1)), ) .limit(self.batch_size) ) @@ -75,23 +72,23 @@ async def clean_up_cancelled_flow_run_task_runs(self, db): if len(flow_runs) < self.batch_size: break - async def clean_up_cancelled_subflow_runs(self, db): + async def clean_up_cancelled_subflow_runs(self, db: PrefectDBInterface): high_water_mark = UUID(int=0) while True: subflow_query = ( - sa.select(orm_models.FlowRun) + sa.select(db.FlowRun) .where( or_( - orm_models.FlowRun.state_type == states.StateType.PENDING, - orm_models.FlowRun.state_type == states.StateType.SCHEDULED, - orm_models.FlowRun.state_type == states.StateType.RUNNING, - orm_models.FlowRun.state_type == states.StateType.PAUSED, - orm_models.FlowRun.state_type == states.StateType.CANCELLING, + db.FlowRun.state_type == states.StateType.PENDING, + db.FlowRun.state_type == states.StateType.SCHEDULED, + db.FlowRun.state_type == states.StateType.RUNNING, + db.FlowRun.state_type == states.StateType.PAUSED, + db.FlowRun.state_type == states.StateType.CANCELLING, ), - orm_models.FlowRun.id > high_water_mark, - orm_models.FlowRun.parent_task_run_id.is_not(None), + db.FlowRun.id > high_water_mark, + db.FlowRun.parent_task_run_id.is_not(None), ) - .order_by(orm_models.FlowRun.id) + .order_by(db.FlowRun.id) .limit(self.batch_size) ) diff --git a/src/prefect/server/services/flow_run_notifications.py b/src/prefect/server/services/flow_run_notifications.py index e1e658f3110e5..132c81c4c9faf 100644 --- a/src/prefect/server/services/flow_run_notifications.py +++ b/src/prefect/server/services/flow_run_notifications.py @@ -8,8 +8,7 @@ import sqlalchemy as sa from prefect.server import models, schemas -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.server.services.loop_service import LoopService from prefect.utilities import urls @@ -36,9 +35,8 @@ async def run_once(self, db: PrefectDBInterface): # notifications that we pulled here. If we drain in batches larger # than 1, we risk double-sending earlier notifications when a # transient error occurs. - notifications = await db.get_flow_run_notifications_from_queue( - session=session, - limit=1, + notifications = await db.queries.get_flow_run_notifications_from_queue( + session=session, limit=1 ) self.logger.debug(f"Got {len(notifications)} notifications from queue.") diff --git a/src/prefect/server/services/foreman.py b/src/prefect/server/services/foreman.py index ae4736da9356d..5d8b79b033bdb 100644 --- a/src/prefect/server/services/foreman.py +++ b/src/prefect/server/services/foreman.py @@ -9,13 +9,16 @@ import sqlalchemy as sa from prefect.server import models -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.models.deployments import mark_deployments_not_ready from prefect.server.models.work_queues import mark_work_queues_not_ready from prefect.server.models.workers import emit_work_pool_status_event from prefect.server.schemas.internal import InternalWorkPoolUpdate -from prefect.server.schemas.statuses import DeploymentStatus, WorkPoolStatus +from prefect.server.schemas.statuses import ( + DeploymentStatus, + WorkerStatus, + WorkPoolStatus, +) from prefect.server.services.loop_service import LoopService from prefect.settings import ( PREFECT_API_SERVICES_FOREMAN_DEPLOYMENT_LAST_POLLED_TIMEOUT_SECONDS, @@ -98,40 +101,21 @@ async def _mark_online_workers_without_a_recent_heartbeat_as_offline( session (AsyncSession): The session to use for the database operation. """ async with db.session_context(begin_transaction=True) as session: - if db.dialect.name == "postgresql": - worker_update_stmt = sa.text( - """ - UPDATE worker - SET status = 'OFFLINE' - WHERE ( - last_heartbeat_time < - CURRENT_TIMESTAMP - ( - interval '1 second' * :multiplier * - COALESCE(heartbeat_interval_seconds, :default_interval) + worker_update_stmt = ( + sa.update(db.Worker) + .values(status=WorkerStatus.OFFLINE) + .where( + sa.func.date_diff_seconds(db.Worker.last_heartbeat_time) + > ( + sa.func.coalesce( + db.Worker.heartbeat_interval_seconds, + sa.bindparam("default_interval", sa.Integer), ) - ) - AND status = 'ONLINE'; - """ - ) - elif db.dialect.name == "sqlite": - worker_update_stmt = sa.text( - """ - UPDATE worker - SET status = 'OFFLINE' - WHERE ( - julianday(last_heartbeat_time) < - julianday('now') - (( - :multiplier * - COALESCE(heartbeat_interval_seconds, :default_interval) - ) / 86400.0) - ) - AND status = 'ONLINE'; - """ - ) - else: - raise NotImplementedError( - f"No implementation for database dialect {db.dialect.name}" + * sa.bindparam("multiplier", sa.Integer) + ), + db.Worker.status == WorkerStatus.ONLINE, ) + ) result = await session.execute( worker_update_stmt, diff --git a/src/prefect/server/services/late_runs.py b/src/prefect/server/services/late_runs.py index b1c400930475d..dfc2abf6dd83b 100644 --- a/src/prefect/server/services/late_runs.py +++ b/src/prefect/server/services/late_runs.py @@ -12,8 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.models as models -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.server.exceptions import ObjectNotFoundError from prefect.server.orchestration.core_policy import MarkLateRunsPolicy from prefect.server.schemas import states diff --git a/src/prefect/server/services/loop_service.py b/src/prefect/server/services/loop_service.py index 55078ff1a624c..d800c968807ae 100644 --- a/src/prefect/server/services/loop_service.py +++ b/src/prefect/server/services/loop_service.py @@ -10,8 +10,7 @@ import pendulum from prefect.logging import get_logger -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.settings import PREFECT_API_LOG_RETRYABLE_ERRORS from prefect.utilities.processutils import _register_signal diff --git a/src/prefect/server/services/pause_expirations.py b/src/prefect/server/services/pause_expirations.py index a6cf227ce6331..390955bf7f9cc 100644 --- a/src/prefect/server/services/pause_expirations.py +++ b/src/prefect/server/services/pause_expirations.py @@ -10,8 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import prefect.server.models as models -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.server.schemas import states from prefect.server.services.loop_service import LoopService from prefect.settings import PREFECT_API_SERVICES_PAUSE_EXPIRATIONS_LOOP_SECONDS diff --git a/src/prefect/server/services/scheduler.py b/src/prefect/server/services/scheduler.py index f27a14bd64c80..69013ce7dbe5f 100644 --- a/src/prefect/server/services/scheduler.py +++ b/src/prefect/server/services/scheduler.py @@ -11,8 +11,7 @@ import sqlalchemy as sa import prefect.server.models as models -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.server.schemas.states import StateType from prefect.server.services.loop_service import LoopService, run_multiple_services from prefect.settings import ( @@ -265,6 +264,7 @@ async def _generate_scheduled_flow_runs( """ return await models.deployments._generate_scheduled_flow_runs( + db, session=session, deployment_id=deployment_id, start_time=start_time, diff --git a/src/prefect/server/services/task_run_recorder.py b/src/prefect/server/services/task_run_recorder.py index ab19dcfa8c411..ed51953132a2d 100644 --- a/src/prefect/server/services/task_run_recorder.py +++ b/src/prefect/server/services/task_run_recorder.py @@ -8,8 +8,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.logging import get_logger -from prefect.server.database.dependencies import db_injector, provide_database_interface -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import ( + PrefectDBInterface, + db_injector, + provide_database_interface, +) from prefect.server.events.ordering import CausalOrdering, EventArrivedEarly from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.schemas.core import TaskRun @@ -33,7 +36,7 @@ async def _insert_task_run( task_run_attributes: Dict[str, Any], ): await session.execute( - db.insert(db.TaskRun) + db.queries.insert(db.TaskRun) .values( created=pendulum.now("UTC"), **task_run_attributes, @@ -56,7 +59,7 @@ async def _insert_task_run_state( db: PrefectDBInterface, session: AsyncSession, task_run: TaskRun ): await session.execute( - db.insert(db.TaskRunState) + db.queries.insert(db.TaskRunState) .values( created=pendulum.now("UTC"), task_run_id=task_run.id, diff --git a/src/prefect/server/services/telemetry.py b/src/prefect/server/services/telemetry.py index 3adf82bbd7b40..0fced956a9fcf 100644 --- a/src/prefect/server/services/telemetry.py +++ b/src/prefect/server/services/telemetry.py @@ -12,8 +12,7 @@ import pendulum import prefect -from prefect.server.database.dependencies import inject_db -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, inject_db from prefect.server.models import configuration from prefect.server.schemas.core import Configuration from prefect.server.services.loop_service import LoopService diff --git a/src/prefect/server/utilities/database.py b/src/prefect/server/utilities/database.py index 776d32eb2caa7..80ff10db658e6 100644 --- a/src/prefect/server/utilities/database.py +++ b/src/prefect/server/utilities/database.py @@ -286,6 +286,22 @@ def process_result_value( return pydantic.TypeAdapter(self._pydantic_type).validate_python(value) +def bindparams_from_clause( + query: sa.ClauseElement, +) -> dict[str, sa.BindParameter[Any]]: + """Retrieve all non-anonymous bind parameters defined in a SQL clause""" + # we could use `traverse(query, {}, {"bindparam": some_list.append})` too, + # but this private method builds on the SQLA query caching infrastructure + # and so is more efficient. + return { + bp.key: bp + for bp in query._get_embedded_bindparams() # pyright: ignore[reportPrivateUsage] + # Anonymous keys are always a printf-style template that starts with '%([seed]' + # the seed is the id() of the bind parameter itself. + if not bp.key.startswith(f"%({id(bp)}") + } + + # Platform-independent datetime and timedelta arithmetic functions @@ -648,7 +664,6 @@ def sqlite_json_operators( class greatest(functions.ReturnTypeFromArgs[T]): - name = "greatest" inherit_cache = True diff --git a/tests/events/server/actions/test_pausing_resuming_work_pool.py b/tests/events/server/actions/test_pausing_resuming_work_pool.py index d9cf20cc24713..b2c31b1a57cbf 100644 --- a/tests/events/server/actions/test_pausing_resuming_work_pool.py +++ b/tests/events/server/actions/test_pausing_resuming_work_pool.py @@ -6,7 +6,7 @@ from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.events import actions from prefect.server.events.clients import AssertingEventsClient from prefect.server.events.schemas.automations import ( diff --git a/tests/events/server/conftest.py b/tests/events/server/conftest.py index e66541d695dd5..938545d21709c 100644 --- a/tests/events/server/conftest.py +++ b/tests/events/server/conftest.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.events import ResourceSpecification, actions, messaging from prefect.server.events.schemas.automations import ( Automation, diff --git a/tests/events/server/models/test_automations.py b/tests/events/server/models/test_automations.py index 1101011b21981..4b7464f3b3295 100644 --- a/tests/events/server/models/test_automations.py +++ b/tests/events/server/models/test_automations.py @@ -7,7 +7,7 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.events import actions, filters from prefect.server.events.models import automations from prefect.server.events.schemas.automations import ( diff --git a/tests/events/server/storage/test_database.py b/tests/events/server/storage/test_database.py index 0006d6a64dbea..cfb3c50498249 100644 --- a/tests/events/server/storage/test_database.py +++ b/tests/events/server/storage/test_database.py @@ -6,7 +6,7 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.events.filters import ( EventFilter, EventIDFilter, diff --git a/tests/events/server/storage/test_event_persister.py b/tests/events/server/storage/test_event_persister.py index a7ea13d018bac..53058e59b2a11 100644 --- a/tests/events/server/storage/test_event_persister.py +++ b/tests/events/server/storage/test_event_persister.py @@ -8,8 +8,7 @@ from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.events.filters import EventFilter from prefect.server.events.schemas.events import ReceivedEvent from prefect.server.events.services import event_persister diff --git a/tests/events/server/test_automations_api.py b/tests/events/server/test_automations_api.py index d02cf6df92a40..430f067bc7f19 100644 --- a/tests/events/server/test_automations_api.py +++ b/tests/events/server/test_automations_api.py @@ -16,7 +16,7 @@ from prefect.server import models as server_models from prefect.server import schemas as server_schemas from prefect.server.api.validation import ValidationError -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.events import actions, filters from prefect.server.events.models.automations import ( create_automation, diff --git a/tests/events/server/triggers/test_composite_triggers.py b/tests/events/server/triggers/test_composite_triggers.py index 728bc1fa88f31..cb5d690da3d3c 100644 --- a/tests/events/server/triggers/test_composite_triggers.py +++ b/tests/events/server/triggers/test_composite_triggers.py @@ -7,7 +7,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.events import actions, triggers from prefect.server.events.models import automations from prefect.server.events.schemas.automations import ( diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index 6721bedc9647b..6f121f7573df9 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -14,12 +14,12 @@ from prefect.blocks.notifications import NotificationBlock from prefect.filesystems import LocalFileSystem from prefect.server import models, schemas -from prefect.server.database import orm_models -from prefect.server.database.configurations import ENGINES, TRACKER -from prefect.server.database.dependencies import ( +from prefect.server.database import ( PrefectDBInterface, + orm_models, provide_database_interface, ) +from prefect.server.database.configurations import ENGINES, TRACKER from prefect.server.models.block_registration import run_block_auto_registration from prefect.server.models.concurrency_limits_v2 import create_concurrency_limit from prefect.server.orchestration.rules import ( diff --git a/tests/server/api/test_csrf_token.py b/tests/server/api/test_csrf_token.py index 640e590206281..a08b89c23349c 100644 --- a/tests/server/api/test_csrf_token.py +++ b/tests/server/api/test_csrf_token.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import models, schemas -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.settings import PREFECT_SERVER_CSRF_PROTECTION_ENABLED, temporary_settings diff --git a/tests/server/api/test_middleware.py b/tests/server/api/test_middleware.py index 243b4b4eec933..efc28602de6eb 100644 --- a/tests/server/api/test_middleware.py +++ b/tests/server/api/test_middleware.py @@ -9,7 +9,7 @@ from prefect.server import models, schemas from prefect.server.api.middleware import CsrfMiddleware -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.settings import ( PREFECT_SERVER_CSRF_PROTECTION_ENABLED, temporary_settings, diff --git a/tests/server/database/test_dependencies.py b/tests/server/database/test_dependencies.py index 9d8b23f9a19da..22804778e7e86 100644 --- a/tests/server/database/test_dependencies.py +++ b/tests/server/database/test_dependencies.py @@ -1,17 +1,16 @@ import asyncio -import datetime from uuid import UUID +import pendulum import pytest from sqlalchemy.ext.asyncio import AsyncSession -from prefect.server.database import dependencies +from prefect.server.database import PrefectDBInterface, dependencies from prefect.server.database.configurations import ( AioSqliteConfiguration, AsyncPostgresConfiguration, BaseDatabaseConfiguration, ) -from prefect.server.database.interface import PrefectDBInterface from prefect.server.database.orm_models import ( AioSqliteORMConfiguration, AsyncPostgresORMConfiguration, @@ -127,11 +126,14 @@ def get_scheduled_flow_runs_from_work_queues( def _get_scheduled_flow_runs_from_work_pool_template_path(self): ... + def _build_flow_run_graph_v2_query(self): + ... + async def flow_run_graph_v2( self, session: AsyncSession, flow_run_id: UUID, - since: datetime, + since: pendulum.DateTime, max_nodes: int, ) -> Graph: raise NotImplementedError() diff --git a/tests/server/database/test_queries.py b/tests/server/database/test_queries.py index b17938fb637ce..53f0842f4f8c8 100644 --- a/tests/server/database/test_queries.py +++ b/tests/server/database/test_queries.py @@ -3,7 +3,7 @@ import sqlalchemy as sa from prefect.server import models, schemas -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface class TestGetRunsInQueueQuery: diff --git a/tests/server/models/test_concurrency_limits_v2.py b/tests/server/models/test_concurrency_limits_v2.py index 05379c046d40f..ea91a63d71077 100644 --- a/tests/server/models/test_concurrency_limits_v2.py +++ b/tests/server/models/test_concurrency_limits_v2.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import models, schemas -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.models.concurrency_limits_v2 import ( MINIMUM_OCCUPANCY_SECONDS_PER_SLOT, bulk_decrement_active_slots, diff --git a/tests/server/models/test_task_run_states.py b/tests/server/models/test_task_run_states.py index 9854f5216b144..d8f21f7b2426a 100644 --- a/tests/server/models/test_task_run_states.py +++ b/tests/server/models/test_task_run_states.py @@ -272,7 +272,7 @@ async def test_task_run_states_filters_by_task_run_id(self, session): class TestDeleteTaskRunState: - async def test_delete_task_run_state(self, task_run, session): + async def test_delete_task_run_state(self, db, task_run, session): # create a task run to read task_run_state = ( @@ -284,7 +284,7 @@ async def test_delete_task_run_state(self, task_run, session): ).state assert await models.task_run_states.delete_task_run_state( - session=session, task_run_state_id=task_run_state.id + db, session=session, task_run_state_id=task_run_state.id ) # make sure the task run state is deleted @@ -293,8 +293,10 @@ async def test_delete_task_run_state(self, task_run, session): ) assert result is None - async def test_delete_task_run_state_returns_false_if_does_not_exist(self, session): + async def test_delete_task_run_state_returns_false_if_does_not_exist( + self, db, session + ): result = await models.task_run_states.delete_task_run_state( - session=session, task_run_state_id=uuid4() + db, session=session, task_run_state_id=uuid4() ) assert not result diff --git a/tests/server/orchestration/api/test_concurrency_limits_v2.py b/tests/server/orchestration/api/test_concurrency_limits_v2.py index ca53f402f4d14..d701bd7115982 100644 --- a/tests/server/orchestration/api/test_concurrency_limits_v2.py +++ b/tests/server/orchestration/api/test_concurrency_limits_v2.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.client import schemas as client_schemas -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.models.concurrency_limits_v2 import ( bulk_update_denied_slots, create_concurrency_limit, diff --git a/tests/server/orchestration/api/test_deployment_schedules.py b/tests/server/orchestration/api/test_deployment_schedules.py index ba450367ea5aa..ba72f60159d04 100644 --- a/tests/server/orchestration/api/test_deployment_schedules.py +++ b/tests/server/orchestration/api/test_deployment_schedules.py @@ -9,7 +9,7 @@ from httpx import AsyncClient from prefect.server import models, schemas -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from tests.server import AsyncSessionGetter diff --git a/tests/server/orchestration/api/test_flow_run_graph_v2.py b/tests/server/orchestration/api/test_flow_run_graph_v2.py index ec3487489596d..38b4baf51ff06 100644 --- a/tests/server/orchestration/api/test_flow_run_graph_v2.py +++ b/tests/server/orchestration/api/test_flow_run_graph_v2.py @@ -1,5 +1,4 @@ from collections import defaultdict -from datetime import datetime from operator import attrgetter from typing import Iterable, List, Union from unittest import mock @@ -13,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import models, schemas -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface from prefect.server.exceptions import FlowRunGraphTooLarge, ObjectNotFoundError from prefect.server.models.flow_runs import read_flow_run_graph from prefect.server.schemas.graph import Edge, Graph, GraphArtifact, GraphState, Node @@ -1438,9 +1437,7 @@ async def test_missing_flow_run_returns_404( assert response.status_code == 404, response.text model_method_mock.assert_awaited_once_with( - session=mock.ANY, - flow_run_id=flow_run_id, - since=datetime.min, + session=mock.ANY, flow_run_id=flow_run_id, since=pendulum.DateTime.min ) @@ -1455,9 +1452,7 @@ async def test_api_full( assert response.status_code == 200, response.text model_method_mock.assert_awaited_once_with( - session=mock.ANY, - flow_run_id=flow_run_id, - since=datetime.min, + session=mock.ANY, flow_run_id=flow_run_id, since=pendulum.DateTime.min ) assert response.json() == graph.model_dump(mode="json") diff --git a/tests/server/orchestration/api/test_workers.py b/tests/server/orchestration/api/test_workers.py index 7a3895db364e8..ead9fd6980040 100644 --- a/tests/server/orchestration/api/test_workers.py +++ b/tests/server/orchestration/api/test_workers.py @@ -1268,7 +1268,7 @@ async def test_work_pool_status_with_offline_worker( """Work pools with only offline workers should have a status of NOT_READY.""" now = pendulum.now("UTC") - insert_stmt = db.insert(db.Worker).values( + insert_stmt = db.queries.insert(db.Worker).values( name="old-worker", work_pool_id=work_pool.id, last_heartbeat_time=now.subtract(minutes=5), @@ -1458,7 +1458,7 @@ async def test_worker_with_old_heartbeat_has_offline_status( ): now = pendulum.now("UTC") - insert_stmt = db.insert(db.Worker).values( + insert_stmt = db.queries.insert(db.Worker).values( name="old-worker", work_pool_id=work_pool.id, last_heartbeat_time=now.subtract(minutes=5), @@ -1485,7 +1485,7 @@ async def test_worker_status_accounts_for_heartbeat_interval( """ now = pendulum.now("UTC") - insert_stmt = db.insert(db.Worker).values( + insert_stmt = db.queries.insert(db.Worker).values( name="old-worker", work_pool_id=work_pool.id, last_heartbeat_time=now.subtract(seconds=10), @@ -1512,7 +1512,7 @@ async def setup_workers(self, session, db, work_pool): worker_name="online-worker", ) - insert_stmt = db.insert(db.Worker).values( + insert_stmt = db.queries.insert(db.Worker).values( name="offline-worker", work_pool_id=work_pool.id, status="OFFLINE", @@ -1544,7 +1544,7 @@ async def test_delete_worker(self, client, work_pool, session, db): work_pool_id = work_pool.id deleted_worker_name = "worker1" for i in range(2): - insert_stmt = (db.insert(db.Worker)).values( + insert_stmt = (db.queries.insert(db.Worker)).values( name=f"worker{i}", work_pool_id=work_pool_id, last_heartbeat_time=pendulum.now(), @@ -1568,7 +1568,7 @@ async def test_nonexistent_worker(self, client, session, db): session=session, work_pool=schemas.actions.WorkPoolCreate(name="A"), ) - insert_stmt = (db.insert(db.Worker)).values( + insert_stmt = (db.queries.insert(db.Worker)).values( name=worker_name, work_pool_id=wp.id, last_heartbeat_time=pendulum.now(), diff --git a/tests/server/orchestration/test_rules.py b/tests/server/orchestration/test_rules.py index f8316d6b84c08..5830dba38273d 100644 --- a/tests/server/orchestration/test_rules.py +++ b/tests/server/orchestration/test_rules.py @@ -9,7 +9,7 @@ import sqlalchemy.exc from prefect.server import models, schemas -from prefect.server.database.dependencies import provide_database_interface +from prefect.server.database import provide_database_interface from prefect.server.exceptions import OrchestrationError from prefect.server.orchestration.rules import ( ALL_ORCHESTRATION_STATES, diff --git a/tests/server/services/test_foreman.py b/tests/server/services/test_foreman.py index 8d5283643eaa5..dd92237a9cb7a 100644 --- a/tests/server/services/test_foreman.py +++ b/tests/server/services/test_foreman.py @@ -9,8 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from prefect.server import models, schemas -from prefect.server.database.dependencies import db_injector -from prefect.server.database.interface import PrefectDBInterface +from prefect.server.database import PrefectDBInterface, db_injector from prefect.server.events.clients import AssertingEventsClient from prefect.server.schemas.statuses import DeploymentStatus from prefect.server.services.foreman import Foreman diff --git a/tests/server/utilities/test_database.py b/tests/server/utilities/test_database.py index eb3de114640bf..02ff6cf61a94e 100644 --- a/tests/server/utilities/test_database.py +++ b/tests/server/utilities/test_database.py @@ -14,11 +14,16 @@ from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import Mapped, declarative_base, mapped_column +from prefect.server.database import PrefectDBInterface from prefect.server.database.configurations import AioSqliteConfiguration -from prefect.server.database.interface import PrefectDBInterface from prefect.server.database.orm_models import AioSqliteORMConfiguration from prefect.server.database.query_components import AioSqliteQueryComponents -from prefect.server.utilities.database import JSON, Pydantic, Timestamp +from prefect.server.utilities.database import ( + JSON, + Pydantic, + Timestamp, + bindparams_from_clause, +) DBBase = declarative_base(type_annotation_map={pendulum.DateTime: Timestamp}) @@ -615,3 +620,9 @@ async def test_error_thrown_if_sqlite_version_is_below_minimum(): orm=AioSqliteORMConfiguration(), ) await db.engine() + + +def test_bind_params_from_clause() -> None: + bp = sa.bindparam("foo", 42, sa.Integer) + statement = 17 < bp + assert bindparams_from_clause(statement) == {"foo": bp} diff --git a/ui-v2/src/api/prefect.ts b/ui-v2/src/api/prefect.ts index bcbff9e8b8411..5da3bf85b1970 100644 --- a/ui-v2/src/api/prefect.ts +++ b/ui-v2/src/api/prefect.ts @@ -7212,11 +7212,8 @@ export interface components { }; /** Graph */ Graph: { - /** - * Start Time - * Format: date-time - */ - start_time: string; + /** Start Time */ + start_time: string | null; /** End Time */ end_time: string | null; /** Root Node Ids */ @@ -7246,7 +7243,7 @@ export interface components { /** Key */ key: string | null; /** Type */ - type: string; + type: string | null; /** Is Latest */ is_latest: boolean; /** Data */