Skip to content

Commit

Permalink
Add tags parameter to decorator (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
diasdm authored Dec 13, 2024
1 parent 2097f2c commit 463621f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
21 changes: 15 additions & 6 deletions aiohttp_apischema/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from http import HTTPStatus
from pathlib import Path
from types import UnionType
from typing import Any, Literal, TypedDict, TypeGuard, TypeVar, cast, get_args, get_origin
from typing import Any, Iterable, Literal, TypedDict, TypeGuard, TypeVar, cast, get_args, get_origin

from aiohttp import web
from aiohttp.hdrs import METH_ALL
Expand Down Expand Up @@ -59,6 +59,7 @@ class _EndpointData(TypedDict, total=False):
desc: str
resps: dict[int, TypeAdapter[Any]]
summary: str
tags: list[str]

class _Endpoint(TypedDict, total=False):
desc: str
Expand All @@ -84,6 +85,7 @@ class _OperationObject(TypedDict, total=False):
requestBody: _RequestBodyObject
responses: dict[str, _ResponseObject]
summary: str
tags: list[str]

class _PathObject(TypedDict, total=False):
delete: _OperationObject
Expand Down Expand Up @@ -143,7 +145,7 @@ def __init__(self, info: Info | None = None):
info = {"title": "API", "version": "1.0"}
self._openapi: _OpenApi = {"openapi": "3.1.0", "info": info}

def _save_handler(self, handler: APIHandler[APIResponse[object, int]]) -> _EndpointData:
def _save_handler(self, handler: APIHandler[APIResponse[object, int]], tags: list[str]) -> _EndpointData:
ep_data: _EndpointData = {}
docs = inspect.getdoc(handler)
if docs:
Expand All @@ -154,6 +156,8 @@ def _save_handler(self, handler: APIHandler[APIResponse[object, int]]) -> _Endpo
ep_data["summary"] = summary
if desc:
ep_data["desc"] = desc
if tags:
ep_data["tags"] = tags

sig = inspect.signature(handler, eval_str=True)
params = iter(sig.parameters.values())
Expand Down Expand Up @@ -184,7 +188,7 @@ def _save_handler(self, handler: APIHandler[APIResponse[object, int]]) -> _Endpo

return ep_data

def api_view(self) -> Callable[[type[_View]], type[_View]]:
def api_view(self, tags: Iterable[str] = ()) -> Callable[[type[_View]], type[_View]]:
def decorator(view: type[_View]) -> type[_View]:
self._endpoints[view] = {"meths": {}}

Expand All @@ -200,7 +204,7 @@ def decorator(view: type[_View]) -> type[_View]:

methods = ((getattr(view, m), m) for m in map(str.lower, METH_ALL) if hasattr(view, m))
for func, method in methods:
ep_data = self._save_handler(func)
ep_data = self._save_handler(func, tags=list(tags))
self._endpoints[view]["meths"][method] = ep_data
ta = ep_data.get("body")
if ta:
Expand All @@ -210,9 +214,9 @@ def decorator(view: type[_View]) -> type[_View]:

return decorator

def api(self) -> Callable[[APIHandler[_Resp]], Callable[[web.Request], Awaitable[_Resp]]]:
def api(self, tags: Iterable[str] = ()) -> Callable[[APIHandler[_Resp]], Callable[[web.Request], Awaitable[_Resp]]]:
def decorator(handler: APIHandler[_Resp]) -> Callable[[web.Request], Awaitable[_Resp]]:
ep_data = self._save_handler(handler)
ep_data = self._save_handler(handler, tags=list(tags))
ta = ep_data.get("body")
if ta:
@functools.wraps(handler)
Expand Down Expand Up @@ -265,10 +269,15 @@ async def _on_startup(self, app: web.Application) -> None:
operation: _OperationObject = {"operationId": route.handler.__name__}
summary = endpoints.get("summary")
desc = endpoints.get("desc")
tags = endpoints.get("tags")

if summary:
operation["summary"] = summary
if desc:
operation["description"] = desc
if tags:
operation["tags"] = tags

path_data[method] = operation

body = endpoints.get("body")
Expand Down
24 changes: 24 additions & 0 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,27 @@ async def put(self, body: NewPoll) -> APIResponse[Poll]:
assert len(result) == 1
assert result[0]["loc"] == ["choices"]
assert result[0]["type"] == "too_short"


async def test_tags(aiohttp_client: AiohttpClient) -> None:
schema_gen = SchemaGenerator()

tags = ("a_tag", "b_tag")

@schema_gen.api(tags=tags)
async def get_number(
request: web.Request,
) -> APIResponse[tuple[Poll, ...], Literal[200]]:
"""Number."""
return APIResponse((POLL1,)) # pragma: no cover

app = web.Application()
schema_gen.setup(app)
app.router.add_get("/number", get_number)

client = await aiohttp_client(app)
async with client.get("/schema") as resp:
assert resp.ok
schema = await resp.json()

assert schema["paths"]["/number"]["get"]["tags"] == ["a_tag", "b_tag"]

0 comments on commit 463621f

Please sign in to comment.