diff --git a/Makefile b/Makefile index ca1d6de..da6778c 100644 --- a/Makefile +++ b/Makefile @@ -32,9 +32,10 @@ lint lint_diff lint_package lint_tests: [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) +# Ignore unused imports (F401), unused variables (F841), `print` statements (T201), and commented-out code (ERA001). format format_diff: [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) - [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 $(PYTHON_FILES) spell_check: poetry run codespell --toml pyproject.toml diff --git a/examples/basic/document_loader.py b/examples/basic/document_loader.py index 017ce90..4f7e591 100644 --- a/examples/basic/document_loader.py +++ b/examples/basic/document_loader.py @@ -50,7 +50,7 @@ def import_mlb_teams_2012() -> None: db = SQLDatabase.from_uri(CRATEDB_SQLALCHEMY_URL) # TODO: Use new URL @ langchain-cratedb. url = "https://github.com/crate-workbench/langchain/raw/cratedb/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql" - sql = requests.get(url).text + sql = requests.get(url, timeout=10).text for statement in sqlparse.split(sql): db.run(statement) db.run("REFRESH TABLE mlb_teams_2012") diff --git a/examples/basic/vector_search.py b/examples/basic/vector_search.py index e5004ac..5332548 100644 --- a/examples/basic/vector_search.py +++ b/examples/basic/vector_search.py @@ -52,7 +52,7 @@ def get_documents() -> t.List[Document]: # Load a document, and split it into chunks. 
url = "https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt" - text = requests.get(url).text + text = requests.get(url, timeout=10).text return text_splitter.create_documents([text]) diff --git a/langchain_cratedb/vectorstores/main.py b/langchain_cratedb/vectorstores/main.py index c96d0eb..fa3ed3c 100644 --- a/langchain_cratedb/vectorstores/main.py +++ b/langchain_cratedb/vectorstores/main.py @@ -295,7 +295,7 @@ def delete( # CrateDB: Calling ``delete`` must not raise an exception # when deleting IDs that do not exist. if self.EmbeddingStore is None: - return + return None return super().delete(ids=ids, collection_only=collection_only, **kwargs) def _ensure_storage(self) -> None: @@ -367,7 +367,7 @@ def add_embeddings( def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: """Return docs and scores from results.""" - docs = [ + return [ ( Document( id=str(result.EmbeddingStore.id), @@ -378,7 +378,6 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa ) for result in results ] - return docs def get_by_ids(self, ids: Sequence[str], /) -> List[Document]: """Get documents by ids.""" @@ -447,9 +446,9 @@ def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, - filter: Optional[dict] = None, + filter: Optional[dict] = None, # noqa: A002 ) -> List[Tuple[Document, float]]: - assert not self._async_engine, "This method must be called without async_mode" + assert not self._async_engine, "This method must be called without async_mode" # noqa: S101 results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) @@ -460,7 +459,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, str]] = None, # noqa: A002 **kwargs: Any, ) -> List[Tuple[Document, 
float]]: """Return docs selected using the maximal marginal relevance with score @@ -486,7 +485,7 @@ def max_marginal_relevance_search_with_score_by_vector( """ import numpy as np - assert not self._async_engine, "This method must be called without async_mode" + assert not self._async_engine, "This method must be called without async_mode" # noqa: S101 results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -506,7 +505,7 @@ def __query_collection( self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, str]] = None, # noqa: A002 ) -> List[Any]: """Query the collection.""" self._init_models(embedding) @@ -523,7 +522,7 @@ def _query_collection_multi( collections: List[Any], embedding: List[float], k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, str]] = None, # noqa: A002 ) -> List[Any]: """Query the collection.""" self._init_models(embedding) @@ -547,7 +546,7 @@ def _query_collection_multi( self.EmbeddingStore, # TODO: Original pgvector code uses `self.distance_strategy`. # CrateDB currently only supports EUCLIDEAN. 
- # self.distance_strategy(embedding).label("distance") + # self.distance_strategy(embedding).label("distance") # noqa: E501,ERA001 sa.func.vector_similarity( self.EmbeddingStore.embedding, # TODO: Just reference the `embedding` symbol here, don't @@ -649,13 +648,13 @@ def _handle_field_filter( # native is trusted input native = COMPARISONS_TO_NATIVE[operator] return self.EmbeddingStore.cmetadata[field].op(native)(filter_value) - elif operator == "$between": + if operator == "$between": # Use AND with two comparisons low, high = filter_value lower_bound = self.EmbeddingStore.cmetadata[field].op(">=")(low) upper_bound = self.EmbeddingStore.cmetadata[field].op("<=")(high) return sa.and_(lower_bound, upper_bound) - elif operator in {"$in", "$nin", "$like", "$ilike"}: + if operator in {"$in", "$nin", "$like", "$ilike"}: # We'll do force coercion to text if operator in {"$in", "$nin"}: for val in filter_value: @@ -673,15 +672,14 @@ def _handle_field_filter( if operator in {"$in"}: return queried_field.in_([str(val) for val in filter_value]) - elif operator in {"$nin"}: + if operator in {"$nin"}: return ~queried_field.in_([str(val) for val in filter_value]) - elif operator in {"$like"}: + if operator in {"$like"}: return queried_field.like(filter_value) - elif operator in {"$ilike"}: + if operator in {"$ilike"}: return queried_field.ilike(filter_value) - else: - raise NotImplementedError() - elif operator == "$exists": + raise NotImplementedError() + if operator == "$exists": if not isinstance(filter_value, bool): raise ValueError( "Expected a boolean value for $exists " @@ -691,5 +689,4 @@ def _handle_field_filter( sa.func.any(sa.func.object_keys(self.EmbeddingStore.cmetadata)) ) return condition if filter_value else ~condition - else: - raise NotImplementedError() + raise NotImplementedError() diff --git a/langchain_cratedb/vectorstores/model.py b/langchain_cratedb/vectorstores/model.py index cd30a14..22e328c 100644 --- a/langchain_cratedb/vectorstores/model.py +++ 
b/langchain_cratedb/vectorstores/model.py @@ -25,7 +25,7 @@ def __init__(self, dimensions: Optional[int] = None): Base: Any = declarative_base() # Optional: Use a custom schema for the langchain tables. - # Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any + # Base = declarative_base(metadata=MetaData(schema="langchain")) # type: Any # noqa: E501,ERA001 class BaseModel(Base): """Base model for the SQL stores.""" diff --git a/langchain_cratedb/vectorstores/multi.py b/langchain_cratedb/vectorstores/multi.py index cc36deb..550ef8e 100644 --- a/langchain_cratedb/vectorstores/multi.py +++ b/langchain_cratedb/vectorstores/multi.py @@ -51,7 +51,7 @@ def __init__( None, DBConnection, sa.Engine, sa.ext.asyncio.AsyncEngine, str ] = None, embedding_length: Optional[int] = None, - collection_names: List[str] = [_LANGCHAIN_DEFAULT_COLLECTION_NAME], + collection_names: Optional[List[str]] = None, collection_metadata: Optional[dict] = None, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, @@ -90,7 +90,7 @@ def __init__( self.async_mode = async_mode self.embedding_function = embeddings self._embedding_length = embedding_length - self.collection_names = collection_names + self.collection_names = collection_names or [_LANGCHAIN_DEFAULT_COLLECTION_NAME] self.collection_metadata = collection_metadata self._distance_strategy = distance_strategy self.pre_delete_collection = pre_delete_collection @@ -150,9 +150,9 @@ def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, - filter: Optional[dict] = None, + filter: Optional[dict] = None, # noqa: A002 ) -> List[Tuple[Document, float]]: - assert not self._async_engine, "This method must be called without async_mode" + assert not self._async_engine, "This method must be called without async_mode" # noqa: S101 results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) @@ 
-163,7 +163,7 @@ def max_marginal_relevance_search_with_score_by_vector( k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, str]] = None, # noqa: A002 **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs selected using the maximal marginal relevance with score @@ -189,7 +189,7 @@ def max_marginal_relevance_search_with_score_by_vector( """ import numpy as np - assert not self._async_engine, "This method must be called without async_mode" + assert not self._async_engine, "This method must be called without async_mode" # noqa: S101 results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -209,7 +209,7 @@ def __query_collection( self, embedding: List[float], k: int = 4, - filter: Optional[Dict[str, str]] = None, + filter: Optional[Dict[str, str]] = None, # noqa: A002 ) -> List[Any]: """Query multiple collections.""" self._init_models(embedding) diff --git a/pyproject.toml b/pyproject.toml index aa74189..513155c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,11 +96,49 @@ langchain-postgres = "==0.0.12" sqlalchemy-cratedb = ">=0.40.1" [tool.ruff.lint] -select = ["E", "F", "I", "T201"] + +select = [ + # Builtins + "A", + # Bugbear + "B", + # comprehensions + "C4", + # Pycodestyle + "E", + # eradicate + "ERA", + # Pyflakes + "F", + # isort + "I", + # pandas-vet + "PD", + # return + "RET", + # Bandit + "S", + # print + "T20", + "W", + # flake8-2020 + "YTT", +] [tool.ruff.lint.per-file-ignores] -"docs/*.ipynb" = ["F401", "F821", "T201"] -"examples/*.py" = ["F401", "F821", "T201"] +"docs/*.ipynb" = [ + "F401", + "F821", + "T201", + "ERA001", # Found commented-out code +] +"examples/*.py" = [ + "F401", + "F821", + "T20", # `print` found. 
+] +"tests/*" = ["S101"] # Use of `assert` detected +".github/scripts/*" = ["S101"] # Use of `assert` detected [tool.coverage.run] omit = [ diff --git a/tests/feature/vectorstore/test_vector_cratedb_main.py b/tests/feature/vectorstore/test_vector_cratedb_main.py index 28c245c..8e9f55e 100644 --- a/tests/feature/vectorstore/test_vector_cratedb_main.py +++ b/tests/feature/vectorstore/test_vector_cratedb_main.py @@ -174,7 +174,7 @@ def test_cratedb_with_filter_match(engine: sa.Engine) -> None: pre_delete_collection=True, ) # TODO: Original: - # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501 + # assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501,ERA001 output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"}) prune_document_ids(output) docs, scores = zip(*output) @@ -250,7 +250,7 @@ def test_cratedb_delete_collection(engine: sa.Engine, session: sa.orm.Session) - collection_foo = store_foo.get_collection(session) collection_bar = store_bar.get_collection(session) if collection_foo is None or collection_bar is None: - assert False, "Expected CollectionStore objects but received None" + raise AssertionError("Expected CollectionStore objects but received None") assert collection_foo.embeddings[0].cmetadata == {"document": "foo"} assert collection_bar.embeddings[0].cmetadata == {"document": "bar"} @@ -261,7 +261,7 @@ def test_cratedb_delete_collection(engine: sa.Engine, session: sa.orm.Session) - collection_foo = store_foo.get_collection(session) collection_bar = store_bar.get_collection(session) if collection_bar is None: - assert False, "Expected CollectionStore object but received None" + raise AssertionError("Expected CollectionStore object but received None") assert collection_foo is None assert collection_bar.embeddings[0].cmetadata == {"document": "bar"} @@ -287,10 +287,9 @@ def test_cratedb_collection_with_metadata( ) collection = 
cratedb_vector.get_collection(session) if collection is None: - assert False, "Expected a CollectionStore object but received None" - else: - assert collection.name == "test_collection" - assert collection.cmetadata == {"foo": "bar"} + raise AssertionError("Expected a CollectionStore object but received None") + assert collection.name == "test_collection" + assert collection.cmetadata == {"foo": "bar"} def test_cratedb_collection_no_embedding_dimension( diff --git a/tests/feature/vectorstore/util.py b/tests/feature/vectorstore/util.py index 92f13ce..adbffa4 100644 --- a/tests/feature/vectorstore/util.py +++ b/tests/feature/vectorstore/util.py @@ -42,7 +42,7 @@ def ensure_collection(session: sa.orm.Session, name: str) -> None: try: session.execute( sa.text( - f"INSERT INTO {COLLECTION_TABLE_NAME} (uuid, name, cmetadata) " + f"INSERT INTO {COLLECTION_TABLE_NAME} (uuid, name, cmetadata) " # noqa: S608 f"VALUES ('uuid-{name}', '{name}', {{}});" ) ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 742bb84..404934e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -44,7 +44,6 @@ def test_file(run_file: t.Callable, file: Path) -> None: except openai.OpenAIError as ex: if "The api_key client option must be set" not in str(ex): raise - else: - raise pytest.skip( - "Skipping test because `OPENAI_API_KEY` is not defined" - ) from ex + raise pytest.skip( + "Skipping test because `OPENAI_API_KEY` is not defined" + ) from ex diff --git a/tests/util/pytest.py b/tests/util/pytest.py index 79ea9c8..cf75ae4 100644 --- a/tests/util/pytest.py +++ b/tests/util/pytest.py @@ -30,7 +30,7 @@ def run_module_function( try: mod = importlib.import_module(path.stem) except ImportError as ex: - raise ImportError(f"Module not found at {filepath}: {ex}") + raise ImportError(f"Module not found at {filepath}: {ex}") from ex fun = getattr(mod, entrypoint) # Wrap the entrypoint function into a pytest test case, and run it.