Skip to content

Commit

Permalink
fix mypy breakages
Browse files Browse the repository at this point in the history
  • Loading branch information
ebellm committed Jan 8, 2025
1 parent b43ecd2 commit 6a64762
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ __pycache__/
*.py[cod]
*$py.class

# temp files
*.swp

# C extensions
*.so

Expand Down
6 changes: 3 additions & 3 deletions src/tasso/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def add_run(name: str) -> None:
config.database_url, config.database_password
)

r = ClassificationRun(name=name)
r = ClassificationRun(name=name) # type: ignore[call-arg]
async for db_session in db_session_dependency():
store = ClassificationRunStore(db_session)
await store.add(r)
Expand Down Expand Up @@ -104,7 +104,7 @@ async def add_subject(run_id: str, dia_source_id: int, uri: str) -> None:
config.database_url, config.database_password
)

s = Subject(run_id=run_id, dia_source_id=dia_source_id, uri=uri)
s = Subject(run_id=run_id, dia_source_id=dia_source_id, uri=uri) # type: ignore[call-arg]
async for db_session in db_session_dependency():
store = SubjectStore(db_session)
await store.add(s)
Expand All @@ -120,7 +120,7 @@ async def add_user(username: str) -> None:
config.database_url, config.database_password
)

u = User(username=username)
u = User(username=username) # type: ignore[call-arg]
async for db_session in db_session_dependency():
store = UserStore(db_session)
await store.add(u)
Expand Down
4 changes: 2 additions & 2 deletions src/tasso/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ async def validation_exception_handler(
request: Request, exc: RequestValidationError | ValidationError
) -> JSONResponse:
body = await request.body()
print(f"Request body: {body}")
print(f"Request body: {body}") # type: ignore[str-bytes-safe]
print(exc)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}), # type: ignore[union-attr]
)


Expand Down
2 changes: 1 addition & 1 deletion src/tasso/schema/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ class Classification(Base):
flags: Mapped[int | None]
time_labeled: Mapped[UtcDatetime]

subjects: Mapped[list["Subject"]] = relationship() # noqa: F821
subjects: Mapped[list["Subject"]] = relationship() # type: ignore[name-defined] # noqa: F821
2 changes: 1 addition & 1 deletion src/tasso/schema/classification_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ class ClassificationRun(Base):
time_stop: Mapped[UtcDatetime | None]
max_classifications: Mapped[int]

subjects: Mapped[list["Subject"]] = relationship() # noqa: F821
subjects: Mapped[list["Subject"]] = relationship() # type: ignore[name-defined] # noqa: F821
2 changes: 1 addition & 1 deletion src/tasso/schema/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ class User(Base):
user_id: Mapped[str] = mapped_column(String(32), primary_key=True)
username: Mapped[str] = mapped_column(String(64))
admin: Mapped[bool]
classifications: Mapped[list["Classification"]] = relationship() # noqa: F821
classifications: Mapped[list["Classification"]] = relationship() # type: ignore[name-defined] # noqa: F821
22 changes: 11 additions & 11 deletions src/tasso/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
session: Annotated[
async_scoped_session, Depends(db_session_dependency)
],
model: ClassVar,
model: ClassVar, # type: ignore[misc]
storage: ClassVar,
primary_key: str,
) -> None:
Expand All @@ -39,7 +39,7 @@ def __init__(
self.storage = storage
self.primary_key = primary_key

async def add(self, record: ClassVar) -> None:
async def add(self, record: ClassVar) -> None: # type: ignore[misc]
"""Add a new model record.
Parameters
Expand All @@ -51,7 +51,7 @@ async def add(self, record: ClassVar) -> None:
async with self._session.begin():
self._session.add(new)

async def update(self, record: ClassVar) -> None:
async def update(self, record: ClassVar) -> None: # type: ignore[misc]
"""Update a model record identified by its primary key.
Parameters
Expand All @@ -71,7 +71,7 @@ async def update(self, record: ClassVar) -> None:
async with self._session.begin():
await self._session.execute(stmt)

async def delete(self, record: ClassVar) -> bool:
async def delete(self, record: ClassVar) -> bool: # type: ignore[misc]
"""Delete a record.
Parameters
Expand All @@ -93,7 +93,7 @@ async def delete(self, record: ClassVar) -> bool:
result = await self._session.execute(stmt)
return result.rowcount > 0

async def list(self) -> list[ClassVar]: # -> list[self.model]:
async def list(self) -> list[ClassVar]: # type: ignore[misc]
"""Return a list of model records.
Returns
Expand All @@ -107,7 +107,7 @@ async def list(self) -> list[ClassVar]: # -> list[self.model]:
result = await self._session.scalars(stmt)
return [self.model.model_validate(a) for a in result.all()]

async def get(self, value: str) -> ClassVar:
async def get(self, value: str) -> ClassVar: # type: ignore[misc]
"""Get a model record by primary key.
Returns
Expand All @@ -119,14 +119,14 @@ async def get(self, value: str) -> ClassVar:
)
async with self._session.begin():
result = await self._session.execute(stmt)
value = result.one_or_none()
print(value)
if value is None:
row = result.one_or_none()
print(row)
if row is None:
return None
else:
return self.model.model_validate(value[0])
return self.model.model_validate(row[0])

async def search(self, key_value: dict[str, str]) -> ClassVar:
async def search(self, key_value: dict[str, str]) -> ClassVar: # type: ignore[misc]
"""Get model records matching key-value pairs.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion src/tasso/webapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def lifespan(_: FastAPI) -> AsyncGenerator:
config.database_url, config.database_password
)
assert db_session_dependency._engine is not None # noqa: S101,SLF001
db_session_dependency._engine.echo = config.database_echo # noqa: SLF001
db_session_dependency._engine.echo = config.database_echo # type: ignore[attr-defined] # noqa: SLF001

# App runs here...
yield
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def app() -> AsyncIterator[FastAPI]:
async def client(app: FastAPI) -> AsyncIterator[AsyncClient]:
"""Return an ``httpx.AsyncClient`` configured to talk to the test app."""
async with AsyncClient(
transport=ASGITransport(app=app), # type: ignore[arg-type]
transport=ASGITransport(app=app),
base_url="https://example.com/",
) as client:
yield client

0 comments on commit 6a64762

Please sign in to comment.