diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 4c48dc5e..02ffb996 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -5,6 +5,13 @@ Change log
 
 .. Headline template: X.Y.Z (YYYY-MM-DD)
 
+3.2.0 (2022-05-13)
+==================
+
+- New support for `arq <https://arq-docs.helpmanual.io>`__, the Redis-based asyncio distributed queue package.
+  The ``safir.arq`` module provides an arq client and metadata/result data classes, with a mock implementation for testing.
+  The FastAPI dependency, ``safir.dependencies.arq.arq_dependency``, provides a convenient way to use the arq client from HTTP handlers.
+
 3.1.0 (2022-06-01)
 ==================
 
diff --git a/Makefile b/Makefile
index b5033d75..0ef0f753 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,6 @@
 .PHONY: init
 init:
 	pip install --upgrade pip tox tox-docker pre-commit
-	pip install --upgrade -e ".[db,dev,kubernetes]"
+	pip install --upgrade -e ".[arq,db,dev,kubernetes]"
 	pre-commit install
+	rm -rf .tox
diff --git a/README.rst b/README.rst
index 91223679..227def31 100644
--- a/README.rst
+++ b/README.rst
@@ -23,6 +23,7 @@ Features
 - Middleware for attaching request context to the logger to include a request UUID, method, and route in all log messages.
 - Process ``X-Forwarded-*`` headers to determine the source IP and related information of the request.
 - Gather and structure standard metadata about your application.
+- Operate a distributed Redis job queue with arq_ using convenient clients, testing mocks, and a FastAPI dependency.
 
 Developing Safir
 ================
@@ -40,3 +41,4 @@ For details, see https://safir.lsst.io/dev/development.html.
 .. _Roundtable: https://roundtable.lsst.io
 .. _FastAPI: https://fastapi.tiangolo.com/
 .. _fastapi_safir_app: https://github.com/lsst/templates/tree/master/project_templates/fastapi_safir_app
+.. _arq: https://arq-docs.helpmanual.io
diff --git a/docs/api.rst b/docs/api.rst
index 5e100473..6eb7e4be 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -5,8 +5,14 @@ API reference
 .. automodapi:: safir
    :include-all-objects:
 
+.. automodapi:: safir.arq
+   :include-all-objects:
+
 .. automodapi:: safir.database
 
+.. automodapi:: safir.dependencies.arq
+   :include-all-objects:
+
 .. automodapi:: safir.dependencies.db_session
    :include-all-objects:
 
diff --git a/docs/arq.rst b/docs/arq.rst
new file mode 100644
index 00000000..2cba1ca9
--- /dev/null
+++ b/docs/arq.rst
@@ -0,0 +1,288 @@
.. currentmodule:: safir.arq

###############################################
Using the arq Redis queue client and dependency
###############################################

Distributed queues allow your application to decouple slow-running processing tasks from your user-facing endpoint handlers.
arq_ is a simple distributed queue library with an asyncio API that uses Redis to store both queue metadata and results.
To simplify integrating arq_ into your FastAPI application and test suites, Safir provides both an arq client (`~safir.arq.ArqQueue`) with a drop-in mock for testing and an endpoint handler dependency (`safir.dependencies.arq`) that provides an arq_ client.

For information on using arq in general, see the `arq documentation <https://arq-docs.helpmanual.io>`__.
For real-world examples of how this dependency, and arq-based distributed queues in general, are used in FastAPI apps, see our `Times Square <https://github.com/lsst-sqre/times-square>`__ and `Noteburst <https://github.com/lsst-sqre/noteburst>`__ applications.

Quick start
===========

.. _arq-dependency-setup:

Dependency setup and configuration
----------------------------------

In your application's FastAPI setup module, typically :file:`main.py`, you need to initialize `safir.dependencies.arq.ArqDependency` during the startup event:

.. code-block:: python

    from fastapi import Depends, FastAPI
    from safir.dependencies.arq import arq_dependency

    app = FastAPI()


    @app.on_event("startup")
    async def startup() -> None:
        await arq_dependency.initialize(
            mode=config.arq_mode, redis_settings=config.arq_redis_settings
        )

The ``mode`` parameter for `safir.dependencies.arq.ArqDependency.initialize` takes `ArqMode` enum values of either ``"production"`` or ``"test"``.
The ``"production"`` mode configures a real arq_ queue backed by Redis, whereas ``"test"`` configures a mock version of the arq_ queue.

When running in the regular ``"production"`` mode, you need to provide an `arq.connections.RedisSettings` instance.
If your app uses a configuration system like ``pydantic.BaseSettings``, this example ``Config`` class shows how to create a `~arq.connections.RedisSettings` object from a regular Redis URI:

.. code-block:: python

    from urllib.parse import urlparse

    from arq.connections import RedisSettings
    from pydantic import BaseSettings, Field, RedisDsn
    from safir.arq import ArqMode


    class Config(BaseSettings):

        arq_queue_url: RedisDsn = Field(
            "redis://localhost:6379/1", env="APP_ARQ_QUEUE_URL"
        )

        arq_mode: ArqMode = Field(ArqMode.production, env="APP_ARQ_MODE")

        @property
        def arq_redis_settings(self) -> RedisSettings:
            """Create a Redis settings instance for arq."""
            url_parts = urlparse(self.arq_queue_url)
            redis_settings = RedisSettings(
                host=url_parts.hostname or "localhost",
                port=url_parts.port or 6379,
                database=int(url_parts.path.lstrip("/")) if url_parts.path else 0,
            )
            return redis_settings

Worker setup
------------

Workers that run queued tasks are separate application deployments, though they can (but do not need to) share the same codebase as the FastAPI-based front-end application.
A convenient pattern is to co-locate the worker inside a ``worker`` sub-package:

.. code-block:: text

    .
    ├── src
    │   └── yourapp
    │       ├── __init__.py
    │       ├── config.py
    │       ├── main.py
    │       └── worker
    │           ├── __init__.py
    │           ├── functions
    │           │   ├── __init__.py
    │           │   ├── function_a.py
    │           │   └── function_b.py
    │           └── main.py
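
Each worker function is a coroutine that takes the worker context, ``ctx``, as its first argument, followed by the arguments passed from `ArqQueue.enqueue`.
As a minimal sketch (the function name, its arguments, and the return value here are illustrative, not part of Safir's API), :file:`src/yourapp/worker/functions/function_a.py` could look like:

.. code-block:: python

    from typing import Any, Dict


    async def function_a(ctx: Dict[Any, Any], message: str) -> str:
        """An illustrative worker function.

        ``ctx`` is the worker context assembled in the ``on_startup``
        handler below; ``message`` arrives from ``ArqQueue.enqueue``.
        """
        logger = ctx["logger"]
        logger.info("Running function_a")
        return message.upper()

The :file:`src/yourapp/worker/main.py` module looks like:
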
.. code-block:: python

    from __future__ import annotations

    import uuid
    from typing import Any, Dict

    import httpx
    import structlog
    from safir.logging import configure_logging

    from ..config import config
    from .functions import function_a, function_b


    async def startup(ctx: Dict[Any, Any]) -> None:
        """Runs during worker start-up to set up the worker context."""
        configure_logging(
            profile=config.profile,
            log_level=config.log_level,
            name="yourapp",
        )
        logger = structlog.get_logger("yourapp")
        # The instance key uniquely identifies this worker in logs
        instance_key = uuid.uuid4().hex
        logger = logger.bind(worker_instance=instance_key)

        http_client = httpx.AsyncClient()
        ctx["http_client"] = http_client

        ctx["logger"] = logger
        logger.info("Worker start up complete")


    async def shutdown(ctx: Dict[Any, Any]) -> None:
        """Runs during worker shutdown to clean up resources."""
        if "logger" in ctx:
            logger = ctx["logger"]
        else:
            logger = structlog.get_logger("yourapp")
        logger.info("Running worker shutdown.")

        try:
            await ctx["http_client"].aclose()
        except Exception as e:
            logger.warning("Issue closing the http_client: %s", str(e))

        logger.info("Worker shutdown complete.")


    class WorkerSettings:
        """Configuration for the arq worker.

        See `arq.worker.Worker` for details on these attributes.
        """

        functions = [function_a, function_b]

        redis_settings = config.arq_redis_settings

        on_startup = startup

        on_shutdown = shutdown

The ``WorkerSettings`` class is where you configure the queue and declare worker functions.
See `arq.worker.Worker` for details.

The ``on_startup`` and ``on_shutdown`` handlers are ideal places to set up (and tear down) worker state, including network and database clients.
The context variable, ``ctx``, passed to these functions is also passed to the worker functions.

To run a worker, run your application's Docker image with the ``arq`` command followed by the fully-qualified name of the ``WorkerSettings`` class; for the layout above, that is ``arq yourapp.worker.main.WorkerSettings``.

Using the arq dependency in endpoint handlers
---------------------------------------------

The `safir.dependencies.arq.arq_dependency` dependency provides your FastAPI endpoint handlers with an `ArqQueue` client that you can use to add jobs (`ArqQueue.enqueue`) to the queue and to get metadata (`ArqQueue.get_job_metadata`) and results (`ArqQueue.get_job_result`) from the queue:

.. code-block:: python

    from typing import Any, Dict, Optional

    from arq.jobs import JobStatus
    from fastapi import Depends, HTTPException
    from safir.arq import ArqQueue, JobNotFound, JobResultUnavailable
    from safir.dependencies.arq import arq_dependency


    @app.post("/jobs")
    async def post_job(
        arq_queue: ArqQueue = Depends(arq_dependency),
        a: str = "hello",
        b: int = 42,
    ) -> Dict[str, Any]:
        """Create a job."""
        job = await arq_queue.enqueue("test_task", a, a_number=b)
        return {"job_id": job.id}


    @app.get("/jobs/{job_id}")
    async def get_job(
        job_id: str,
        queue_name: Optional[str] = None,
        arq_queue: ArqQueue = Depends(arq_dependency),
    ) -> Dict[str, Any]:
        """Get metadata about a job."""
        try:
            job = await arq_queue.get_job_metadata(
                job_id, queue_name=queue_name
            )
        except JobNotFound:
            raise HTTPException(status_code=404)

        response = {
            "id": job.id,
            "status": job.status,
            "name": job.name,
            "args": job.args,
            "kwargs": job.kwargs,
        }

        if job.status == JobStatus.complete:
            try:
                job_result = await arq_queue.get_job_result(
                    job_id, queue_name=queue_name
                )
            except (JobNotFound, JobResultUnavailable):
                raise HTTPException(status_code=404)
            response["result"] = job_result.result

        return response

For information on the metadata available from jobs, see `JobMetadata` and `JobResult`.

Testing applications with an arq queue
======================================

Unit testing an application with a running distributed queue is difficult, since three components (two instances of the application and a Redis database) must coordinate.
A better unit testing approach is to test the front-end application separately from the worker functions.
To help you do this, the arq dependency allows you to run a mocked version of an arq queue.
With the mocked client, your front-end application can run the three basic client methods as normal: `ArqQueue.enqueue`, `ArqQueue.get_job_metadata`, and `ArqQueue.get_job_result`.
This mocked client is a subclass of `ArqQueue` called `MockArqQueue`.

Configuring the test mode
-------------------------

You get a `MockArqQueue` from the `safir.dependencies.arq.arq_dependency` instance by passing an `ArqMode.test` value to the ``mode`` argument of `safir.dependencies.arq.ArqDependency.initialize` in your application's startup (see :ref:`arq-dependency-setup`).
As the above example shows, you can drive this with an environment variable and then set the arq mode in your tox settings.

Interacting with the queue state
--------------------------------

Your tests can add jobs and get job metadata or results using the normal code paths.
Since queue jobs never run, your test code needs to manually change the status of jobs and set job results.
You can do this by calling the `safir.dependencies.arq.arq_dependency` instance directly from your test to get the `MockArqQueue`, and then using its `MockArqQueue.set_in_progress` and `MockArqQueue.set_complete` methods.
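
As a minimal sketch of that lifecycle (a standalone `MockArqQueue` here stands in for the one the dependency would return), the flow looks like:

.. code-block:: python

    from arq.jobs import JobStatus
    from safir.arq import MockArqQueue


    async def demo() -> None:
        queue = MockArqQueue()
        job = await queue.enqueue("test_task", "hello", a_number=42)
        assert job.status == JobStatus.queued

        # Simulate a worker picking up the job
        await queue.set_in_progress(job.id)
        job = await queue.get_job_metadata(job.id)
        assert job.status == JobStatus.in_progress

        # Simulate the worker finishing the job with a result
        await queue.set_complete(job.id, result="done")
        job_result = await queue.get_job_result(job.id)
        assert job_result.success
        assert job_result.result == "done"

This example adapted from Noteburst shows how the same flow works in a full application test:
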
.. code-block:: python

    import pytest
    from httpx import AsyncClient
    from safir.arq import MockArqQueue
    from safir.dependencies.arq import arq_dependency


    @pytest.mark.asyncio
    async def test_post_nbexec(
        client: AsyncClient, sample_ipynb: str, sample_ipynb_executed: str
    ) -> None:
        arq_queue = await arq_dependency()
        assert isinstance(arq_queue, MockArqQueue)

        response = await client.post(
            "/noteburst/v1/notebooks/",
            json={
                "ipynb": sample_ipynb,
                "kernel_name": "LSST",
            },
        )
        assert response.status_code == 202
        data = response.json()
        assert data["status"] == "queued"
        job_url = response.headers["Location"]
        job_id = data["job_id"]

        # Toggle the job to in-progress; the status should update
        await arq_queue.set_in_progress(job_id)

        response = await client.get(job_url)
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "in_progress"

        # Toggle the job to complete
        await arq_queue.set_complete(job_id, result=sample_ipynb_executed)

        response = await client.get(job_url)
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "complete"
        assert data["success"] is True
        assert data["ipynb"] == sample_ipynb_executed
diff --git a/docs/conf.py b/docs/conf.py
index a7af4781..5dea3de6 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -7,6 +7,7 @@
 rst_epilog = """
+.. _arq: https://arq-docs.helpmanual.io
 .. _FastAPI: https://fastapi.tiangolo.com/
 .. _mypy: http://www.mypy-lang.org
 .. _pre-commit: https://pre-commit.com
@@ -72,6 +73,7 @@
     "python": ("https://docs.python.org/3/", None),
     "sqlalchemy": ("https://docs.sqlalchemy.org/en/latest/", None),
     "structlog": ("https://www.structlog.org/en/stable/", None),
+    "arq": ("https://arq-docs.helpmanual.io", None),
 }
 
 intersphinx_timeout = 10.0  # seconds
@@ -94,6 +96,7 @@
     ("py:class", "starlette.middleware.base.BaseHTTPMiddleware"),
     ("py:class", "starlette.requests.Request"),
     ("py:class", "starlette.responses.Response"),
+    ("py:obj", "JobMetadata.id"),
 ]
 
 # Linkcheck builder ==========================================================
diff --git a/docs/index.rst b/docs/index.rst
index 1813b7ad..9e8959d6 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,6 +32,7 @@ Guides
 .. toctree::
    :maxdepth: 2
 
+   arq
    database
    http-client
    gafaelfawr
diff --git a/setup.cfg b/setup.cfg
index 2ffe48b8..5a45953e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -66,6 +66,9 @@ dev =
     sphinx-prompt
 kubernetes =
     kubernetes_asyncio
+arq =
+    # 0.23 or later includes support for aioredis 2, but is not released yet.
+    arq==0.23a1
 
 [flake8]
 max-line-length = 79
diff --git a/src/safir/arq.py b/src/safir/arq.py
new file mode 100644
index 00000000..5080a0a9
--- /dev/null
+++ b/src/safir/arq.py
@@ -0,0 +1,543 @@
"""An `arq <https://arq-docs.helpmanual.io>`__ client with a mock for
testing.
"""

from __future__ import annotations

import abc
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional, Tuple

from arq import create_pool
from arq.connections import ArqRedis, RedisSettings
from arq.constants import default_queue_name as arq_default_queue_name
from arq.jobs import Job, JobStatus

__all__ = [
    "ArqJobError",
    "JobNotQueued",
    "JobNotFound",
    "JobResultUnavailable",
    "ArqMode",
    "JobMetadata",
    "JobResult",
    "ArqQueue",
    "RedisArqQueue",
    "MockArqQueue",
]


class ArqJobError(Exception):
    """A base class for errors related to arq jobs.
+ + Attributes + ---------- + job_id : `str`, optional + The job ID, or `None` if the job ID is not known in this context. + """ + + def __init__(self, message: str, job_id: Optional[str]) -> None: + super().__init__(message) + self._job_id = job_id + + @property + def job_id(self) -> Optional[str]: + """The job ID, or `None` if the job ID is not known in this context.""" + return self._job_id + + +class JobNotQueued(ArqJobError): + """The job was not successfully queued.""" + + def __init__(self, job_id: Optional[str]) -> None: + super().__init__( + f"Job was not queued because it already exists. id={job_id}", + job_id, + ) + + +class JobNotFound(ArqJobError): + """A job cannot be found.""" + + def __init__(self, job_id: str) -> None: + super().__init__(f"Job could not be found. id={job_id}", job_id) + + +class JobResultUnavailable(ArqJobError): + """The job's result is unavailable.""" + + def __init__(self, job_id: str) -> None: + super().__init__(f"Job result could not be found. id={job_id}", job_id) + + +class ArqMode(str, Enum): + """Mode configuration for the Arq queue.""" + + production = "production" + """Normal usage of arq, with a Redis broker.""" + + test = "test" + """Use the MockArqQueue to test an API service without standing up a + full distributed worker queue. + """ + + +@dataclass +class JobMetadata: + """Information about a queued job. + + Attributes + ---------- + id : str + The `arq.jobs.Job` identifier + name: str + The task name. + args: Any + The positional arguments to the task function. + kwargs: Any + The keyword arguments to the task function. + enqueue_time: datetime.datetime + Time when the job was added to the queue. + status: str + Status of the job. + + States are defined by the `arq.jobs.JobStatus` enumeration: + + - ``deferred`` (in queue, but waiting a predetermined time to become + ready to run) + - ``queued`` (queued to run) + - ``in_progress`` (actively being run by a worker) + - ``complete`` (result is available) + - ``not_found`` (the job cannot be found) + queue_name: str + Name of the queue this job belongs to. + """ + + id: str + """The `~arq.jobs.Job` identifier.""" + + name: str + """The task name.""" + + args: Tuple[Any, ...] + """The positional arguments to the task function.""" + + kwargs: Dict[str, Any] + """The keyword arguments to the task function.""" + + enqueue_time: datetime + """Time when the job was added to the queue.""" + + status: JobStatus + """Status of the job. + + States are defined by the `arq.jobs.JobStatus` enumeration: + + - ``deferred`` (in queue, but waiting a predetermined time to become + ready to run) + - ``queued`` (queued to run) + - ``in_progress`` (actively being run by a worker) + - ``complete`` (result is available) + - ``not_found`` (the job cannot be found) + """ + + queue_name: str + """Name of the queue this job belongs to.""" + + @classmethod + async def from_job(cls, job: Job) -> JobMetadata: + """Initialize JobMetadata from an arq Job. 


        Raises
        ------
        JobNotFound
            Raised if the job is not found.
        """
        job_info = await job.info()
        if job_info is None:
            raise JobNotFound(job.job_id)

        job_status = await job.status()
        if job_status == JobStatus.not_found:
            raise JobNotFound(job.job_id)

        return cls(
            id=job.job_id,
            name=job_info.function,
            args=job_info.args,
            kwargs=job_info.kwargs,
            enqueue_time=job_info.enqueue_time,
            status=job_status,
            # private attribute of Job; not available in JobDef
            # queue_name is available in JobResult
            queue_name=job._queue_name,
        )


@dataclass
class JobResult(JobMetadata):
    """The full result of a job, as well as its metadata.

    Attributes
    ----------
    id : str
        The `~arq.jobs.Job` identifier
    name: str
        The task name.
    args: Any
        The positional arguments to the task function.
    kwargs: Any
        The keyword arguments to the task function.
    enqueue_time: datetime.datetime
        Time when the job was added to the queue.
    status: str
        Status of the job.

        States are defined by the `arq.jobs.JobStatus` enumeration:

        - ``deferred`` (in queue, but waiting a predetermined time to become
          ready to run)
        - ``queued`` (queued to run)
        - ``in_progress`` (actively being run by a worker)
        - ``complete`` (result is available)
        - ``not_found`` (the job cannot be found)
    start_time: datetime.datetime
        Time when the job started.
    finish_time: datetime.datetime
        Time when the job finished.
    success: bool
        `True` if the job returned without an exception, `False` if an
        exception was raised.
    result: Any
        The job's result.
    """

    start_time: datetime
    """Time when the job started."""

    finish_time: datetime
    """Time when the job finished."""

    success: bool
    """`True` if the job returned without an exception, `False` if an
    exception was raised.
    """

    result: Any
    """The job's result."""

    @classmethod
    async def from_job(cls, job: Job) -> JobResult:
        """Initialize the `JobResult` from an arq `~arq.jobs.Job`.

        Raises
        ------
        JobNotFound
            Raised if the job is not found.
        JobResultUnavailable
            Raised if the job result is not available.
        """
        job_info = await job.info()
        if job_info is None:
            raise JobNotFound(job.job_id)

        job_status = await job.status()
        if job_status == JobStatus.not_found:
            raise JobNotFound(job.job_id)

        # Result may be None if the job isn't finished
        result_info = await job.result_info()
        if result_info is None:
            raise JobResultUnavailable(job.job_id)

        return cls(
            id=job.job_id,
            name=job_info.function,
            args=job_info.args,
            kwargs=job_info.kwargs,
            enqueue_time=job_info.enqueue_time,
            start_time=result_info.start_time,
            finish_time=result_info.finish_time,
            success=result_info.success,
            status=job_status,
            queue_name=result_info.queue_name,
            result=result_info.result,
        )


class ArqQueue(metaclass=abc.ABCMeta):
    """A common interface for working with an arq queue that can be
    implemented either with a real Redis backend or an in-memory repository
    for testing.

    See also
    --------
    RedisArqQueue
        Production implementation with a Redis store.
    MockArqQueue
        In-memory implementation for testing and development.
    """

    def __init__(
        self, *, default_queue_name: str = arq_default_queue_name
    ) -> None:
        self._default_queue_name = default_queue_name

    @property
    def default_queue_name(self) -> str:
        """Name of the default queue, if the ``_queue_name`` parameter is
        not set in method calls.
        """
        return self._default_queue_name
+ """ + return self._default_queue_name + + @abc.abstractmethod + async def enqueue( + self, + task_name: str, + *task_args: Any, + _queue_name: Optional[str] = None, + **task_kwargs: Any, + ) -> JobMetadata: + """Add a job to the queue. + + Parameters + ---------- + task_name : `str` + The function name to run. + *args + Positional arguments for the task function. + _queue_name : `str` + Name of the queue. + **kwargs + Keyword arguments passed to the task function. + + Returns + ------- + JobMetadata + Metadata about the queued job. + + Raises + ------ + JobNotQueued + Raised if the job is not successfully added to the queue. + """ + raise NotImplementedError + + @abc.abstractmethod + async def get_job_metadata( + self, job_id: str, queue_name: Optional[str] = None + ) -> JobMetadata: + """Get metadata about a `~arq.jobs.Job`. + + Parameters + ---------- + job_id : `str` + The job's identifier. This is the same as the `JobMetadata.id` + attribute, provided when initially adding a job to the queue. + queue_name : `str`, optional + Name of the queue. + + Returns + ------- + `JobMetadata` + Metadata about the queued job. + + Raises + ------ + JobNotFound + Raised if the job is not found in the queue. + """ + raise NotImplementedError + + @abc.abstractmethod + async def get_job_result( + self, job_id: str, queue_name: Optional[str] = None + ) -> JobResult: + """The job result, if available. + + Parameters + ---------- + job_id : `str` + The job's identifier. This is the same as the `JobMetadata.id` + attribute, provided when initially adding a job to the queue. + queue_name : `str`, optional + Name of the queue. + + Returns + ------- + `JobResult` + The job's result, along with metadata about the queued job. + + Raises + ------ + JobNotFound + Raised if the job is not found in the queue. + JobResultUnavailable + Raised if the job's result is unavailable for any reason. 
+ """ + raise NotImplementedError + + +class RedisArqQueue(ArqQueue): + """A distributed queue, based on arq and Redis.""" + + def __init__( + self, + pool: ArqRedis, + *, + default_queue_name: str = arq_default_queue_name, + ) -> None: + super().__init__(default_queue_name=default_queue_name) + self._pool = pool + + @classmethod + async def initialize( + cls, + redis_settings: RedisSettings, + *, + default_queue_name: str = arq_default_queue_name, + ) -> RedisArqQueue: + """Initialize a RedisArqQueue from Redis settings.""" + pool = await create_pool( + redis_settings, default_queue_name=default_queue_name + ) + return cls(pool) + + async def enqueue( + self, + task_name: str, + *task_args: Any, + _queue_name: Optional[str] = None, + **task_kwargs: Any, + ) -> JobMetadata: + job = await self._pool.enqueue_job( + task_name, + *task_args, + _queue_name=_queue_name or self.default_queue_name, + **task_kwargs, + ) + if job: + return await JobMetadata.from_job(job) + else: + # TODO if implementing hard-coded job IDs, set as an argument + raise JobNotQueued(None) + + def _get_job(self, job_id: str, queue_name: Optional[str] = None) -> Job: + return Job( + job_id, + self._pool, + _queue_name=queue_name or self.default_queue_name, + ) + + async def get_job_metadata( + self, job_id: str, queue_name: Optional[str] = None + ) -> JobMetadata: + job = self._get_job(job_id, queue_name=queue_name) + return await JobMetadata.from_job(job) + + async def get_job_result( + self, job_id: str, queue_name: Optional[str] = None + ) -> JobResult: + job = self._get_job(job_id, queue_name=queue_name) + return await JobResult.from_job(job) + + +class MockArqQueue(ArqQueue): + """A mocked queue for testing API services.""" + + def __init__( + self, *, default_queue_name: str = arq_default_queue_name + ) -> None: + super().__init__(default_queue_name=default_queue_name) + self._job_metadata: Dict[str, Dict[str, JobMetadata]] = { + self.default_queue_name: {} + } + self._job_results: Dict[str, Dict[str, JobResult]] = { + self.default_queue_name: {} + } + + def _resolve_queue_name(self, queue_name: Optional[str]) -> str: + return queue_name or self.default_queue_name + + async def enqueue( + self, + task_name: str, + *task_args: Any, + _queue_name: Optional[str] = None, + **task_kwargs: Any, + ) -> JobMetadata: + queue_name = self._resolve_queue_name(_queue_name) + new_job = JobMetadata( + id=str(uuid.uuid4().hex), + name=task_name, + args=task_args, + kwargs=task_kwargs, + enqueue_time=datetime.now(), + status=JobStatus.queued, + queue_name=queue_name, + ) + self._job_metadata[queue_name][new_job.id] = new_job + return new_job + + async def get_job_metadata( + self, job_id: str, queue_name: Optional[str] = None + ) -> JobMetadata: + queue_name = self._resolve_queue_name(queue_name) + try: + return self._job_metadata[queue_name][job_id] + except KeyError: + raise JobNotFound(job_id) + + async def get_job_result( + self, job_id: str, queue_name: Optional[str] = None + ) -> JobResult: + queue_name = self._resolve_queue_name(queue_name) + try: + return self._job_results[queue_name][job_id] + except KeyError: + raise JobResultUnavailable(job_id) + + async def set_in_progress( + self, job_id: str, queue_name: Optional[str] = None + ) -> None: + """Set a job's status to in progress, for mocking a queue in tests.""" + job = await self.get_job_metadata(job_id, queue_name=queue_name) + job.status = JobStatus.in_progress + + # An in-progress job cannot have a result + if job_id in self._job_results: + del 

    async def set_complete(
        self,
        job_id: str,
        *,
        result: Any,
        success: bool = True,
        queue_name: Optional[str] = None,
    ) -> None:
        """Set a job's result, for mocking a queue in tests."""
        queue_name = self._resolve_queue_name(queue_name)

        job_metadata = await self.get_job_metadata(
            job_id, queue_name=queue_name
        )
        job_metadata.status = JobStatus.complete

        result_info = JobResult(
            id=job_metadata.id,
            name=job_metadata.name,
            args=job_metadata.args,
            kwargs=job_metadata.kwargs,
            status=job_metadata.status,
            enqueue_time=job_metadata.enqueue_time,
            start_time=datetime.now(),
            finish_time=datetime.now(),
            result=result,
            success=success,
            queue_name=queue_name,
        )
        self._job_results[queue_name][job_id] = result_info
diff --git a/src/safir/dependencies/arq.py b/src/safir/dependencies/arq.py
new file mode 100644
index 00000000..957a49ab
--- /dev/null
+++ b/src/safir/dependencies/arq.py
@@ -0,0 +1,92 @@
"""A FastAPI dependency that supplies a Redis connection for
`arq <https://arq-docs.helpmanual.io>`__.
"""

from __future__ import annotations

from typing import Optional

from arq.connections import RedisSettings

from ..arq import ArqMode, ArqQueue, MockArqQueue, RedisArqQueue

__all__ = ["ArqDependency", "arq_dependency"]


class ArqDependency:
    """A FastAPI dependency that maintains a Redis client for enqueueing
    tasks to the worker pool.
    """

    def __init__(self) -> None:
        self._arq_queue: Optional[ArqQueue] = None

    async def initialize(
        self, *, mode: ArqMode, redis_settings: Optional[RedisSettings] = None
    ) -> None:
        """Initialize the dependency (call during the FastAPI start-up
        event).

        Parameters
        ----------
        mode : `safir.arq.ArqMode`
            The mode to operate the queue dependency in. With
            `safir.arq.ArqMode.production`, this method initializes a
            Redis-based arq queue and the dependency creates a
            `safir.arq.RedisArqQueue` client.

            With `safir.arq.ArqMode.test`, this method instead initializes an
            in-memory mocked version of arq that you use with the
            `safir.arq.MockArqQueue` client.
        redis_settings : `arq.connections.RedisSettings`, optional
            The arq Redis settings, required when the ``mode`` is
            `safir.arq.ArqMode.production`. See arq's
            `~arq.connections.RedisSettings` documentation for details on
            this object.

        Examples
        --------
        .. code-block:: python

            from typing import Any, Dict

            from fastapi import Depends, FastAPI
            from safir.arq import ArqMode, ArqQueue
            from safir.dependencies.arq import arq_dependency

            app = FastAPI()


            @app.on_event("startup")
            async def startup() -> None:
                await arq_dependency.initialize(mode=ArqMode.test)


            @app.post("/")
            async def post_job(
                arq_queue: ArqQueue = Depends(arq_dependency),
            ) -> Dict[str, Any]:
                job = await arq_queue.enqueue("test_task", "hello", an_int=42)
                return {"job_id": job.id}
        """
        if mode == ArqMode.production:
            if not redis_settings:
                raise RuntimeError(
                    "The redis_settings argument must be set for arq in "
                    "production."
                )
            self._arq_queue = await RedisArqQueue.initialize(redis_settings)
        else:
            self._arq_queue = MockArqQueue()

    async def __call__(self) -> ArqQueue:
        """Get the arq queue.

        This method is called for you by ``fastapi.Depends``.
        """
        if self._arq_queue is None:
            raise RuntimeError("ArqDependency is not initialized")
        return self._arq_queue


arq_dependency = ArqDependency()
"""Singleton instance of `ArqDependency` that serves as a FastAPI
dependency.
"""
+""" diff --git a/tests/dependencies/arq_test.py b/tests/dependencies/arq_test.py new file mode 100644 index 00000000..7888fcdb --- /dev/null +++ b/tests/dependencies/arq_test.py @@ -0,0 +1,156 @@ +"""Test `safir.dependencies.arq`.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +import pytest +from arq.constants import default_queue_name +from asgi_lifespan import LifespanManager +from fastapi import Depends, FastAPI, HTTPException +from httpx import AsyncClient + +from safir.arq import ArqMode, JobNotFound, JobResultUnavailable, MockArqQueue +from safir.dependencies.arq import arq_dependency + + +@pytest.mark.asyncio +async def test_arq_dependency_mock() -> None: + """Test the arq dependency entirely through the MockArqQueue.""" + app = FastAPI() + + @app.post("/") + async def post_job( + arq_queue: MockArqQueue = Depends(arq_dependency), + ) -> Dict[str, Any]: + """Create a job.""" + job = await arq_queue.enqueue("test_task", "hello", a_number=42) + return { + "job_id": job.id, + "job_status": job.status, + "job_name": job.name, + "job_args": job.args, + "job_kwargs": job.kwargs, + "job_queue_name": job.queue_name, + } + + @app.get("/jobs/{job_id}") + async def get_metadata( + job_id: str, + queue_name: Optional[str] = None, + arq_queue: MockArqQueue = Depends(arq_dependency), + ) -> Dict[str, Any]: + """Get metadata about a job.""" + try: + job = await arq_queue.get_job_metadata( + job_id, queue_name=queue_name + ) + except JobNotFound: + raise HTTPException(status_code=404) + return { + "job_id": job.id, + "job_status": job.status, + "job_name": job.name, + "job_args": job.args, + "job_kwargs": job.kwargs, + "job_queue_name": job.queue_name, + } + + @app.get("/results/{job_id}") + async def get_result( + job_id: str, + queue_name: Optional[str] = None, + arq_queue: MockArqQueue = Depends(arq_dependency), + ) -> Dict[str, Any]: + """Get the results for a job.""" + try: + job_result = await arq_queue.get_job_result( + job_id, queue_name=queue_name + ) + except (JobNotFound, JobResultUnavailable) as e: + raise HTTPException(status_code=404, detail=str(e)) + return { + "job_id": job_result.id, + "job_status": job_result.status, + "job_name": job_result.name, + "job_args": job_result.args, + "job_kwargs": job_result.kwargs, + "job_queue_name": job_result.queue_name, + "job_result": job_result.result, + "job_success": job_result.success, + } + + @app.post("/jobs/{job_id}/inprogress") + async def post_job_inprogress( + job_id: str, + queue_name: Optional[str] = None, + arq_queue: MockArqQueue = Depends(arq_dependency), + ) -> None: + """Toggle a job to in-progress, for testing.""" + try: + await arq_queue.set_in_progress(job_id, queue_name=queue_name) + except JobNotFound as e: + raise HTTPException(status_code=404, detail=str(e)) + + @app.post("/jobs/{job_id}/complete") + async def post_job_complete( + job_id: str, + queue_name: Optional[str] = None, + result: Optional[str] = None, + success: bool = True, + arq_queue: MockArqQueue = Depends(arq_dependency), + ) -> None: + """Toggle a job to complete, for testing.""" + try: + await arq_queue.set_complete( + job_id, result=result, success=success, queue_name=queue_name + ) + except JobNotFound as e: + raise HTTPException(status_code=404, detail=str(e)) + + @app.on_event("startup") + async def startup() -> None: + await arq_dependency.initialize(mode=ArqMode.test, redis_settings=None) + + async with LifespanManager(app): + async with AsyncClient(app=app, base_url="http://example.com") as c: + r = await 
c.post("/") + assert r.status_code == 200 + data = r.json() + job_id = data["job_id"] + assert data["job_status"] == "queued" + assert data["job_name"] == "test_task" + assert data["job_args"] == ["hello"] + assert data["job_kwargs"] == {"a_number": 42} + assert data["job_queue_name"] == default_queue_name + + r = await c.get(f"/jobs/{job_id}") + assert r.status_code == 200 + assert data["job_kwargs"] == {"a_number": 42} + + # Wrong queue name + r = await c.get(f"/jobs/{job_id}?queue_name=queue2") + assert r.status_code == 404 + + # Result should not be available + r = await c.get(f"/results/{job_id}") + assert r.status_code == 404 + data = r.json() + assert data["detail"] == ( + f"Job result could not be found. id={job_id}" + ) + + # Set to in-progress + r = await c.post(f"/jobs/{job_id}/inprogress") + r = await c.get(f"/jobs/{job_id}") + data = r.json() + assert data["job_status"] == "in_progress" + + # Set to successful completion + r = await c.post(f"/jobs/{job_id}/complete?result=done") + r = await c.get(f"/results/{job_id}") + assert r.status_code == 200 + data = r.json() + assert data["job_status"] == "complete" + assert data["job_result"] == "done" + assert data["job_success"] is True diff --git a/tox.ini b/tox.ini index 37bebed6..379e6c91 100644 --- a/tox.ini +++ b/tox.ini @@ -28,6 +28,7 @@ extras = db dev kubernetes + arq [testenv:py] description = Run pytest with PostgreSQL via Docker.