Formatting and Linting: Stronger configuration for Ruff
amotl committed Dec 24, 2024
1 parent 8fd737f commit 5ba8b06
Showing 11 changed files with 81 additions and 47 deletions.
3 changes: 2 additions & 1 deletion Makefile
@@ -32,9 +32,10 @@ lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

+# Ignore unused imports (F401), unused variables (F841), `print` statements (T201), and commented-out code (ERA001).
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=ERA001 $(PYTHON_FILES)

spell_check:
poetry run codespell --toml pyproject.toml
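Note: the new `format` invocation autofixes everything except rule families that would be too invasive while formatting. A minimal sketch of what each exempted code flags (hypothetical snippet, not from the repository):

    import json  # F401: module imported but unused

    def demo() -> None:
        result = 42  # F841: local variable assigned but never used
        print("debugging output")  # T201: `print` found
        # return result  # ERA001: commented-out code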
2 changes: 1 addition & 1 deletion examples/basic/document_loader.py
@@ -50,7 +50,7 @@ def import_mlb_teams_2012() -> None:
db = SQLDatabase.from_uri(CRATEDB_SQLALCHEMY_URL)
# TODO: Use new URL @ langchain-cratedb.
url = "https://github.com/crate-workbench/langchain/raw/cratedb/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql"
-sql = requests.get(url).text
+sql = requests.get(url, timeout=10).text
for statement in sqlparse.split(sql):
db.run(statement)
db.run("REFRESH TABLE mlb_teams_2012")
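Note: the added `timeout=10` likely satisfies Bandit's S113 (request without timeout), now enforced via the broadened `S` selection; without it, `requests.get()` can block indefinitely on an unresponsive server. A minimal sketch (URL illustrative):

    import requests

    try:
        response = requests.get("https://example.org/data.sql", timeout=10)
        response.raise_for_status()
    except requests.exceptions.Timeout:
        print("Server did not respond within 10 seconds")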
2 changes: 1 addition & 1 deletion examples/basic/vector_search.py
@@ -52,7 +52,7 @@ def get_documents() -> t.List[Document]:

# Load a document, and split it into chunks.
url = "https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt"
-text = requests.get(url).text
+text = requests.get(url, timeout=10).text
return text_splitter.create_documents([text])


37 changes: 17 additions & 20 deletions langchain_cratedb/vectorstores/main.py
@@ -295,7 +295,7 @@ def delete(
# CrateDB: Calling ``delete`` must not raise an exception
# when deleting IDs that do not exist.
if self.EmbeddingStore is None:
-return
+return None
return super().delete(ids=ids, collection_only=collection_only, **kwargs)

def _ensure_storage(self) -> None:
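Note: switching the bare `return` to `return None` resolves flake8-return's RET502, enabled by the new `RET` selection: a function that returns a value on one path should not return implicitly on another. A minimal illustration (hypothetical function):

    from typing import List, Optional

    def first_match(items: List[str], needle: str) -> Optional[str]:
        for item in items:
            if item == needle:
                return item
        return None  # a bare `return` here would trigger RET502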
@@ -367,7 +367,7 @@ def add_embeddings(

def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
"""Return docs and scores from results."""
-docs = [
+return [
(
Document(
id=str(result.EmbeddingStore.id),
@@ -378,7 +378,6 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
)
for result in results
]
-return docs

def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
"""Get documents by ids."""
@@ -447,9 +446,9 @@ def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
-filter: Optional[dict] = None,
+filter: Optional[dict] = None,  # noqa: A002
) -> List[Tuple[Document, float]]:
-assert not self._async_engine, "This method must be called without async_mode"
+assert not self._async_engine, "This method must be called without async_mode"  # noqa: S101
results = self.__query_collection(embedding=embedding, k=k, filter=filter)

return self._results_to_docs_and_scores(results)
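Note: `# noqa: A002` acknowledges flake8-builtins' A002, since the `filter` argument shadows the builtin `filter()` but is presumably kept for keyword-API compatibility; `# noqa: S101` keeps the `assert` guards that Bandit's S101 would otherwise reject outside of test code. A hedged sketch of the pattern (names illustrative):

    from typing import List, Optional

    def search(query: str, filter: Optional[dict] = None) -> List[str]:  # noqa: A002
        # `filter` shadows the builtin (A002), but the keyword is part of
        # the public API, so the rule is silenced instead of renaming it.
        assert filter is None or isinstance(filter, dict)  # noqa: S101
        return [query]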
@@ -460,7 +459,7 @@ def max_marginal_relevance_search_with_score_by_vector(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[Dict[str, str]] = None,  # noqa: A002
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance with score
@@ -486,7 +485,7 @@ def max_marginal_relevance_search_with_score_by_vector(
"""
import numpy as np

-assert not self._async_engine, "This method must be called without async_mode"
+assert not self._async_engine, "This method must be called without async_mode"  # noqa: S101
results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)

embedding_list = [result.EmbeddingStore.embedding for result in results]
@@ -506,7 +505,7 @@ def __query_collection(
self,
embedding: List[float],
k: int = 4,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[Dict[str, str]] = None,  # noqa: A002
) -> List[Any]:
"""Query the collection."""
self._init_models(embedding)
@@ -523,7 +522,7 @@ def _query_collection_multi(
collections: List[Any],
embedding: List[float],
k: int = 4,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[Dict[str, str]] = None,  # noqa: A002
) -> List[Any]:
"""Query the collection."""
self._init_models(embedding)
@@ -547,7 +546,7 @@ def _query_collection_multi(
self.EmbeddingStore,
# TODO: Original pgvector code uses `self.distance_strategy`.
# CrateDB currently only supports EUCLIDEAN.
# self.distance_strategy(embedding).label("distance")
# self.distance_strategy(embedding).label("distance") # noqa: E501,ERA001
sa.func.vector_similarity(
self.EmbeddingStore.embedding,
# TODO: Just reference the `embedding` symbol here, don't
@@ -649,13 +648,13 @@ def _handle_field_filter(
# native is trusted input
native = COMPARISONS_TO_NATIVE[operator]
return self.EmbeddingStore.cmetadata[field].op(native)(filter_value)
elif operator == "$between":
if operator == "$between":
# Use AND with two comparisons
low, high = filter_value
lower_bound = self.EmbeddingStore.cmetadata[field].op(">=")(low)
upper_bound = self.EmbeddingStore.cmetadata[field].op("<=")(high)
return sa.and_(lower_bound, upper_bound)
-elif operator in {"$in", "$nin", "$like", "$ilike"}:
+if operator in {"$in", "$nin", "$like", "$ilike"}:
# We'll do force coercion to text
if operator in {"$in", "$nin"}:
for val in filter_value:
@@ -673,15 +672,14 @@ def _handle_field_filter(

if operator in {"$in"}:
return queried_field.in_([str(val) for val in filter_value])
-elif operator in {"$nin"}:
+if operator in {"$nin"}:
return ~queried_field.in_([str(val) for val in filter_value])
-elif operator in {"$like"}:
+if operator in {"$like"}:
return queried_field.like(filter_value)
-elif operator in {"$ilike"}:
+if operator in {"$ilike"}:
return queried_field.ilike(filter_value)
-else:
-raise NotImplementedError()
-elif operator == "$exists":
+raise NotImplementedError()
+if operator == "$exists":
if not isinstance(filter_value, bool):
raise ValueError(
"Expected a boolean value for $exists "
Expand All @@ -691,5 +689,4 @@ def _handle_field_filter(
sa.func.any(sa.func.object_keys(self.EmbeddingStore.cmetadata))
)
return condition if filter_value else ~condition
-else:
-raise NotImplementedError()
+raise NotImplementedError()
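Note: the `elif` cascade flattens to plain `if` statements because every branch ends in `return` or `raise`; this is flake8-return's RET505 (unnecessary `elif`/`else` after `return`), and the final bare `else:` falls to the same family. Reduced illustration (hypothetical function):

    def describe(operator: str) -> str:
        if operator == "$between":
            return "range comparison"
        if operator in {"$in", "$nin"}:
            return "membership test"
        # No `elif`/`else` needed: earlier branches already returned (RET505).
        raise NotImplementedError(f"Unsupported operator: {operator}")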
2 changes: 1 addition & 1 deletion langchain_cratedb/vectorstores/model.py
@@ -25,7 +25,7 @@ def __init__(self, dimensions: Optional[int] = None):
Base: Any = declarative_base()

# Optional: Use a custom schema for the langchain tables.
-# Base = declarative_base(metadata=MetaData(schema="langchain"))  # type: Any
+# Base = declarative_base(metadata=MetaData(schema="langchain"))  # type: Any  # noqa: E501,ERA001

class BaseModel(Base):
"""Base model for the SQL stores."""
14 changes: 7 additions & 7 deletions langchain_cratedb/vectorstores/multi.py
@@ -51,7 +51,7 @@ def __init__(
None, DBConnection, sa.Engine, sa.ext.asyncio.AsyncEngine, str
] = None,
embedding_length: Optional[int] = None,
-collection_names: List[str] = [_LANGCHAIN_DEFAULT_COLLECTION_NAME],
+collection_names: Optional[List[str]] = None,
collection_metadata: Optional[dict] = None,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
pre_delete_collection: bool = False,
@@ -90,7 +90,7 @@ def __init__(
self.async_mode = async_mode
self.embedding_function = embeddings
self._embedding_length = embedding_length
-self.collection_names = collection_names
+self.collection_names = collection_names or [_LANGCHAIN_DEFAULT_COLLECTION_NAME]
self.collection_metadata = collection_metadata
self._distance_strategy = distance_strategy
self.pre_delete_collection = pre_delete_collection
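Note: moving the default from a list literal to `None` fixes flake8-bugbear's B006 (mutable default argument): a list in the signature is created once and shared across every call, so mutation by one caller leaks into the next, while the `or` fallback builds a fresh list each time. A minimal sketch of the hazard and the fix:

    from typing import List, Optional

    def risky(names: List[str] = []) -> List[str]:  # B006: one shared list!
        names.append("default")
        return names

    def safe(names: Optional[List[str]] = None) -> List[str]:
        return names or ["default"]  # a fresh list on every call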
@@ -150,9 +150,9 @@ def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
-filter: Optional[dict] = None,
+filter: Optional[dict] = None,  # noqa: A002
) -> List[Tuple[Document, float]]:
-assert not self._async_engine, "This method must be called without async_mode"
+assert not self._async_engine, "This method must be called without async_mode"  # noqa: S101
results = self.__query_collection(embedding=embedding, k=k, filter=filter)

return self._results_to_docs_and_scores(results)
@@ -163,7 +163,7 @@ def max_marginal_relevance_search_with_score_by_vector(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[Dict[str, str]] = None,  # noqa: A002
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs selected using the maximal marginal relevance with score
@@ -189,7 +189,7 @@
"""
import numpy as np

-assert not self._async_engine, "This method must be called without async_mode"
+assert not self._async_engine, "This method must be called without async_mode"  # noqa: S101
results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)

embedding_list = [result.EmbeddingStore.embedding for result in results]
@@ -209,7 +209,7 @@ def __query_collection(
self,
embedding: List[float],
k: int = 4,
-filter: Optional[Dict[str, str]] = None,
+filter: Optional[Dict[str, str]] = None,  # noqa: A002
) -> List[Any]:
"""Query multiple collections."""
self._init_models(embedding)
44 changes: 41 additions & 3 deletions pyproject.toml
@@ -96,11 +96,49 @@ langchain-postgres = "==0.0.12"
sqlalchemy-cratedb = ">=0.40.1"

[tool.ruff.lint]
select = ["E", "F", "I", "T201"]

+select = [
+# Builtins
+"A",
+# Bugbear
+"B",
+# comprehensions
+"C4",
+# Pycodestyle
+"E",
+# eradicate
+"ERA",
+# Pyflakes
+"F",
+# isort
+"I",
+# pandas-vet
+"PD",
+# return
+"RET",
+# Bandit
+"S",
+# print
+"T20",
+"W",
+# flake8-2020
+"YTT",
+]

[tool.ruff.lint.per-file-ignores]
"docs/*.ipynb" = ["F401", "F821", "T201"]
"examples/*.py" = ["F401", "F821", "T201"]
"docs/*.ipynb" = [
"F401",
"F821",
"T201",
"ERA001", # Found commented-out code
]
"examples/*.py" = [
"F401",
"F821",
"T20", # `print` found.
]
"tests/*" = ["S101"] # Use of `assert` detected
".github/scripts/*" = ["S101"] # Use of `assert` detected

[tool.coverage.run]
omit = [
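Note: the per-file ignores carve out places where the strict defaults would fight legitimate idioms: notebooks and examples may `print` and keep illustrative imports, and test code keeps plain `assert`. For instance, Bandit's S101 would flag the following, but the `tests/*` entry permits it (file name illustrative):

    # tests/test_arithmetic.py
    def test_addition() -> None:
        # pytest relies on bare `assert`, so S101 is ignored under tests/*.
        assert 1 + 1 == 2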
13 changes: 6 additions & 7 deletions tests/feature/vectorstore/test_vector_cratedb_main.py
@@ -174,7 +174,7 @@ def test_cratedb_with_filter_match(engine: sa.Engine) -> None:
pre_delete_collection=True,
)
# TODO: Original:
-# assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]  # noqa: E501
+# assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]  # noqa: E501,ERA001
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
prune_document_ids(output)
docs, scores = zip(*output)
@@ -250,7 +250,7 @@ def test_cratedb_delete_collection(engine: sa.Engine, session: sa.orm.Session) -
collection_foo = store_foo.get_collection(session)
collection_bar = store_bar.get_collection(session)
if collection_foo is None or collection_bar is None:
-assert False, "Expected CollectionStore objects but received None"
+raise AssertionError("Expected CollectionStore objects but received None")
assert collection_foo.embeddings[0].cmetadata == {"document": "foo"}
assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}

@@ -261,7 +261,7 @@ def test_cratedb_delete_collection(engine: sa.Engine, session: sa.orm.Session) -
collection_foo = store_foo.get_collection(session)
collection_bar = store_bar.get_collection(session)
if collection_bar is None:
-assert False, "Expected CollectionStore object but received None"
+raise AssertionError("Expected CollectionStore object but received None")
assert collection_foo is None
assert collection_bar.embeddings[0].cmetadata == {"document": "bar"}

@@ -287,10 +287,9 @@ def test_cratedb_collection_with_metadata(
)
collection = cratedb_vector.get_collection(session)
if collection is None:
-assert False, "Expected a CollectionStore object but received None"
-else:
-assert collection.name == "test_collection"
-assert collection.cmetadata == {"foo": "bar"}
+raise AssertionError("Expected a CollectionStore object but received None")
+assert collection.name == "test_collection"
+assert collection.cmetadata == {"foo": "bar"}


def test_cratedb_collection_no_embedding_dimension(
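Note: swapping `assert False, ...` for `raise AssertionError(...)` addresses flake8-bugbear's B011: `assert` statements are stripped when Python runs with `-O`, so an `assert False` guard can silently vanish, while an explicit raise always fires. It also lets the happy path dedent out of the old `else:` block. Shape of the fix (names illustrative):

    from typing import Any, Optional

    def require(collection: Optional[Any]) -> Any:
        if collection is None:
            # `assert False, "..."` disappears under `python -O` (B011);
            # an explicit raise always fires.
            raise AssertionError("Expected a CollectionStore object but received None")
        return collection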
2 changes: 1 addition & 1 deletion tests/feature/vectorstore/util.py
@@ -42,7 +42,7 @@ def ensure_collection(session: sa.orm.Session, name: str) -> None:
try:
session.execute(
sa.text(
f"INSERT INTO {COLLECTION_TABLE_NAME} (uuid, name, cmetadata) "
f"INSERT INTO {COLLECTION_TABLE_NAME} (uuid, name, cmetadata) " # noqa: S608
f"VALUES ('uuid-{name}', '{name}', {{}});"
)
)
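Note: Bandit's S608 flags SQL assembled from f-strings as potential injection. Here the interpolated values are fixed test constants, so the helper opts out with `# noqa: S608`; when values come from outside, bound parameters are the usual alternative, sketched below (table name illustrative):

    import sqlalchemy as sa

    # Bound parameters keep user-supplied values out of the SQL string itself.
    statement = sa.text(
        "INSERT INTO collection (uuid, name, cmetadata) VALUES (:uuid, :name, {})"
    )
    # session.execute(statement, {"uuid": "uuid-foo", "name": "foo"})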
7 changes: 3 additions & 4 deletions tests/test_examples.py
@@ -44,7 +44,6 @@ def test_file(run_file: t.Callable, file: Path) -> None:
except openai.OpenAIError as ex:
if "The api_key client option must be set" not in str(ex):
raise
-else:
-raise pytest.skip(
-"Skipping test because `OPENAI_API_KEY` is not defined"
-) from ex
+raise pytest.skip(
+"Skipping test because `OPENAI_API_KEY` is not defined"
+) from ex
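Note: dropping the `else:` resolves RET506 (unnecessary `else` after `raise`): when the `if` branch re-raises, control only reaches the following statement when the condition was false, so the extra indentation adds nothing. Schematically (function and message illustrative):

    def reraise_or_skip(ex: Exception) -> None:
        if "The api_key client option must be set" not in str(ex):
            raise ex
        # Reached only when the message matched; no `else:` needed (RET506).
        raise RuntimeError("`OPENAI_API_KEY` is not defined") from ex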
2 changes: 1 addition & 1 deletion tests/util/pytest.py
@@ -30,7 +30,7 @@ def run_module_function(
try:
mod = importlib.import_module(path.stem)
except ImportError as ex:
raise ImportError(f"Module not found at {filepath}: {ex}")
raise ImportError(f"Module not found at {filepath}: {ex}") from ex
fun = getattr(mod, entrypoint)

# Wrap the entrypoint function into a pytest test case, and run it.
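Note: appending `from ex` satisfies flake8-bugbear's B904: raising a new exception inside an `except` block without `from` discards the causal chain, so the original traceback gets lost. Minimal form (module name illustrative):

    import importlib

    try:
        mod = importlib.import_module("does_not_exist")
    except ImportError as ex:
        # `from ex` preserves the original traceback as __cause__ (B904).
        raise ImportError("Module not found at the given path") from ex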
