From 463621fc2c43f952b66518de3d38ed3298e8f9f2 Mon Sep 17 00:00:00 2001 From: David Dias <32164122+diasdm@users.noreply.github.com> Date: Fri, 13 Dec 2024 19:02:48 +0000 Subject: [PATCH] Add tags parameter to decorator (#57) --- aiohttp_apischema/generator.py | 21 +++++++++++++++------ tests/test_generator.py | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/aiohttp_apischema/generator.py b/aiohttp_apischema/generator.py index 91eae5b..1c9f402 100644 --- a/aiohttp_apischema/generator.py +++ b/aiohttp_apischema/generator.py @@ -5,7 +5,7 @@ from http import HTTPStatus from pathlib import Path from types import UnionType -from typing import Any, Literal, TypedDict, TypeGuard, TypeVar, cast, get_args, get_origin +from typing import Any, Iterable, Literal, TypedDict, TypeGuard, TypeVar, cast, get_args, get_origin from aiohttp import web from aiohttp.hdrs import METH_ALL @@ -59,6 +59,7 @@ class _EndpointData(TypedDict, total=False): desc: str resps: dict[int, TypeAdapter[Any]] summary: str + tags: list[str] class _Endpoint(TypedDict, total=False): desc: str @@ -84,6 +85,7 @@ class _OperationObject(TypedDict, total=False): requestBody: _RequestBodyObject responses: dict[str, _ResponseObject] summary: str + tags: list[str] class _PathObject(TypedDict, total=False): delete: _OperationObject @@ -143,7 +145,7 @@ def __init__(self, info: Info | None = None): info = {"title": "API", "version": "1.0"} self._openapi: _OpenApi = {"openapi": "3.1.0", "info": info} - def _save_handler(self, handler: APIHandler[APIResponse[object, int]]) -> _EndpointData: + def _save_handler(self, handler: APIHandler[APIResponse[object, int]], tags: list[str]) -> _EndpointData: ep_data: _EndpointData = {} docs = inspect.getdoc(handler) if docs: @@ -154,6 +156,8 @@ def _save_handler(self, handler: APIHandler[APIResponse[object, int]]) -> _Endpo ep_data["summary"] = summary if desc: ep_data["desc"] = desc + if tags: + ep_data["tags"] = tags sig = inspect.signature(handler, eval_str=True) params = iter(sig.parameters.values()) @@ -184,7 +188,7 @@ def _save_handler(self, handler: APIHandler[APIResponse[object, int]]) -> _Endpo return ep_data - def api_view(self) -> Callable[[type[_View]], type[_View]]: + def api_view(self, tags: Iterable[str] = ()) -> Callable[[type[_View]], type[_View]]: def decorator(view: type[_View]) -> type[_View]: self._endpoints[view] = {"meths": {}} @@ -200,7 +204,7 @@ def decorator(view: type[_View]) -> type[_View]: methods = ((getattr(view, m), m) for m in map(str.lower, METH_ALL) if hasattr(view, m)) for func, method in methods: - ep_data = self._save_handler(func) + ep_data = self._save_handler(func, tags=list(tags)) self._endpoints[view]["meths"][method] = ep_data ta = ep_data.get("body") if ta: @@ -210,9 +214,9 @@ def decorator(view: type[_View]) -> type[_View]: return decorator - def api(self) -> Callable[[APIHandler[_Resp]], Callable[[web.Request], Awaitable[_Resp]]]: + def api(self, tags: Iterable[str] = ()) -> Callable[[APIHandler[_Resp]], Callable[[web.Request], Awaitable[_Resp]]]: def decorator(handler: APIHandler[_Resp]) -> Callable[[web.Request], Awaitable[_Resp]]: - ep_data = self._save_handler(handler) + ep_data = self._save_handler(handler, tags=list(tags)) ta = ep_data.get("body") if ta: @functools.wraps(handler) @@ -265,10 +269,15 @@ async def _on_startup(self, app: web.Application) -> None: operation: _OperationObject = {"operationId": route.handler.__name__} summary = endpoints.get("summary") desc = endpoints.get("desc") + tags = endpoints.get("tags") + if summary: operation["summary"] = summary if desc: operation["description"] = desc + if tags: + operation["tags"] = tags + path_data[method] = operation body = endpoints.get("body") diff --git a/tests/test_generator.py b/tests/test_generator.py index 0f6454d..ddc192a 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -247,3 +247,27 @@ async def put(self, body: NewPoll) -> APIResponse[Poll]: assert len(result) == 1 assert result[0]["loc"] == ["choices"] assert result[0]["type"] == "too_short" + + +async def test_tags(aiohttp_client: AiohttpClient) -> None: + schema_gen = SchemaGenerator() + + tags = ("a_tag", "b_tag") + + @schema_gen.api(tags=tags) + async def get_number( + request: web.Request, + ) -> APIResponse[tuple[Poll, ...], Literal[200]]: + """Number.""" + return APIResponse((POLL1,)) # pragma: no cover + + app = web.Application() + schema_gen.setup(app) + app.router.add_get("/number", get_number) + + client = await aiohttp_client(app) + async with client.get("/schema") as resp: + assert resp.ok + schema = await resp.json() + + assert schema["paths"]["/number"]["get"]["tags"] == ["a_tag", "b_tag"]