From 6a64762237f8165920e2f43850b5caef7eb4a24e Mon Sep 17 00:00:00 2001 From: Eric Bellm Date: Tue, 7 Jan 2025 21:35:36 -0800 Subject: [PATCH] fix mypy breakages --- .gitignore | 3 +++ src/tasso/cli.py | 6 +++--- src/tasso/main.py | 4 ++-- src/tasso/schema/classification.py | 2 +- src/tasso/schema/classification_run.py | 2 +- src/tasso/schema/user.py | 2 +- src/tasso/storage/base.py | 22 +++++++++++----------- src/tasso/webapp/app.py | 2 +- tests/conftest.py | 2 +- 9 files changed, 24 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 7af665b..9b11504 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ __pycache__/ *.py[cod] *$py.class +# temp files +*.swp + # C extensions *.so diff --git a/src/tasso/cli.py b/src/tasso/cli.py index 1f12924..ec07fe2 100644 --- a/src/tasso/cli.py +++ b/src/tasso/cli.py @@ -72,7 +72,7 @@ async def add_run(name: str) -> None: config.database_url, config.database_password ) - r = ClassificationRun(name=name) + r = ClassificationRun(name=name) # type: ignore[call-arg] async for db_session in db_session_dependency(): store = ClassificationRunStore(db_session) await store.add(r) @@ -104,7 +104,7 @@ async def add_subject(run_id: str, dia_source_id: int, uri: str) -> None: config.database_url, config.database_password ) - s = Subject(run_id=run_id, dia_source_id=dia_source_id, uri=uri) + s = Subject(run_id=run_id, dia_source_id=dia_source_id, uri=uri) # type: ignore[call-arg] async for db_session in db_session_dependency(): store = SubjectStore(db_session) await store.add(s) @@ -120,7 +120,7 @@ async def add_user(username: str) -> None: config.database_url, config.database_password ) - u = User(username=username) + u = User(username=username) # type: ignore[call-arg] async for db_session in db_session_dependency(): store = UserStore(db_session) await store.add(u) diff --git a/src/tasso/main.py b/src/tasso/main.py index c577553..fe2e2fc 100644 --- a/src/tasso/main.py +++ b/src/tasso/main.py @@ -102,11 +102,11 @@ async def validation_exception_handler( request: Request, exc: RequestValidationError | ValidationError ) -> JSONResponse: body = await request.body() - print(f"Request body: {body}") + print(f"Request body: {body}") # type: ignore[str-bytes-safe] print(exc) return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}), + content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}), # type: ignore[union-attr] ) diff --git a/src/tasso/schema/classification.py b/src/tasso/schema/classification.py index 0e668f3..49f6d93 100644 --- a/src/tasso/schema/classification.py +++ b/src/tasso/schema/classification.py @@ -28,4 +28,4 @@ class Classification(Base): flags: Mapped[int | None] time_labeled: Mapped[UtcDatetime] - subjects: Mapped[list["Subject"]] = relationship() # noqa: F821 + subjects: Mapped[list["Subject"]] = relationship() # type: ignore[name-defined] # noqa: F821 diff --git a/src/tasso/schema/classification_run.py b/src/tasso/schema/classification_run.py index c37c81a..47e60b2 100644 --- a/src/tasso/schema/classification_run.py +++ b/src/tasso/schema/classification_run.py @@ -25,4 +25,4 @@ class ClassificationRun(Base): time_stop: Mapped[UtcDatetime | None] max_classifications: Mapped[int] - subjects: Mapped[list["Subject"]] = relationship() # noqa: F821 + subjects: Mapped[list["Subject"]] = relationship() # type: ignore[name-defined] # noqa: F821 diff --git a/src/tasso/schema/user.py b/src/tasso/schema/user.py index b99d06a..7a3c77e 100644 --- a/src/tasso/schema/user.py +++ b/src/tasso/schema/user.py @@ -16,4 +16,4 @@ class User(Base): user_id: Mapped[str] = mapped_column(String(32), primary_key=True) username: Mapped[str] = mapped_column(String(64)) admin: Mapped[bool] - classifications: Mapped[list["Classification"]] = relationship() # noqa: F821 + classifications: Mapped[list["Classification"]] = relationship() # type: ignore[name-defined] # noqa: F821 diff --git a/src/tasso/storage/base.py b/src/tasso/storage/base.py index 679fad8..25c195d 100644 --- a/src/tasso/storage/base.py +++ b/src/tasso/storage/base.py @@ -30,7 +30,7 @@ def __init__( session: Annotated[ async_scoped_session, Depends(db_session_dependency) ], - model: ClassVar, + model: ClassVar, # type: ignore[misc] storage: ClassVar, primary_key: str, ) -> None: @@ -39,7 +39,7 @@ def __init__( self.storage = storage self.primary_key = primary_key - async def add(self, record: ClassVar) -> None: + async def add(self, record: ClassVar) -> None: # type: ignore[misc] """Add a new model record. Parameters @@ -51,7 +51,7 @@ async def add(self, record: ClassVar) -> None: async with self._session.begin(): self._session.add(new) - async def update(self, record: ClassVar) -> None: + async def update(self, record: ClassVar) -> None: # type: ignore[misc] """Update a model record identified by its primary key. Parameters @@ -71,7 +71,7 @@ async def update(self, record: ClassVar) -> None: async with self._session.begin(): await self._session.execute(stmt) - async def delete(self, record: ClassVar) -> bool: + async def delete(self, record: ClassVar) -> bool: # type: ignore[misc] """Delete a record. Parameters @@ -93,7 +93,7 @@ async def delete(self, record: ClassVar) -> bool: result = await self._session.execute(stmt) return result.rowcount > 0 - async def list(self) -> list[ClassVar]: # -> list[self.model]: + async def list(self) -> list[ClassVar]: # type: ignore[misc] """Return a list of model records. Returns @@ -107,7 +107,7 @@ async def list(self) -> list[ClassVar]: # -> list[self.model]: result = await self._session.scalars(stmt) return [self.model.model_validate(a) for a in result.all()] - async def get(self, value: str) -> ClassVar: + async def get(self, value: str) -> ClassVar: # type: ignore[misc] """Get a model record by primary key. Returns @@ -119,14 +119,14 @@ async def get(self, value: str) -> ClassVar: ) async with self._session.begin(): result = await self._session.execute(stmt) - value = result.one_or_none() - print(value) - if value is None: + row = result.one_or_none() + print(row) + if row is None: return None else: - return self.model.model_validate(value[0]) + return self.model.model_validate(row[0]) - async def search(self, key_value: dict[str, str]) -> ClassVar: + async def search(self, key_value: dict[str, str]) -> ClassVar: # type: ignore[misc] """Get model records matching key-value pairs. Parameters diff --git a/src/tasso/webapp/app.py b/src/tasso/webapp/app.py index 22212ff..943047c 100644 --- a/src/tasso/webapp/app.py +++ b/src/tasso/webapp/app.py @@ -27,7 +27,7 @@ async def lifespan(_: FastAPI) -> AsyncGenerator: config.database_url, config.database_password ) assert db_session_dependency._engine is not None # noqa: S101,SLF001 - db_session_dependency._engine.echo = config.database_echo # noqa: SLF001 + db_session_dependency._engine.echo = config.database_echo # type: ignore[attr-defined] # noqa: SLF001 # App runs here... yield diff --git a/tests/conftest.py b/tests/conftest.py index c3d6de1..3d72611 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,7 +27,7 @@ async def app() -> AsyncIterator[FastAPI]: async def client(app: FastAPI) -> AsyncIterator[AsyncClient]: """Return an ``httpx.AsyncClient`` configured to talk to the test app.""" async with AsyncClient( - transport=ASGITransport(app=app), # type: ignore[arg-type] + transport=ASGITransport(app=app), base_url="https://example.com/", ) as client: yield client