-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tests: Add and validate programs in
examples
folder
It's the traditional trio `chat_history.py`, `document_loader.py`, and `vector_search.py`. They have been slotted into the `cratedb-examples` repository beforehand.
- Loading branch information
Showing
7 changed files
with
294 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
""" | ||
Demonstrate chat history / conversational memory with CrateDB. | ||
Synopsis:: | ||
# Install prerequisites. | ||
pip install --upgrade langchain-cratedb | ||
# Start database. | ||
docker run --rm -it --publish=4200:4200 crate/crate:nightly | ||
# Optionally set environment variable to configure CrateDB connection URL. | ||
export CRATEDB_SQLALCHEMY_URL="crate://crate@localhost/?schema=doc" | ||
# Run program. | ||
python examples/basic/chat_history.py | ||
""" # noqa: E501 | ||
# /// script | ||
# requires-python = ">=3.9" | ||
# dependencies = [ | ||
# "langchain-cratedb", | ||
# ] | ||
# /// | ||
|
||
import os | ||
from pprint import pprint | ||
|
||
from langchain_cratedb import CrateDBChatMessageHistory | ||
|
||
CRATEDB_SQLALCHEMY_URL = os.environ.get( | ||
"CRATEDB_SQLALCHEMY_URL", "crate://crate@localhost/?schema=testdrive" | ||
) | ||
|
||
|
||
def main() -> None: | ||
chat_history = CrateDBChatMessageHistory( | ||
session_id="test_session", | ||
connection=CRATEDB_SQLALCHEMY_URL, | ||
) | ||
chat_history.add_user_message("Hello") | ||
chat_history.add_ai_message("Hi") | ||
pprint(chat_history.messages) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
""" | ||
Exercise the LangChain/CrateDB document loader. | ||
How to use the SQL document loader, based on SQLAlchemy. | ||
The example uses the canonical `mlb_teams_2012.csv`, | ||
converted to SQL, see `mlb_teams_2012.sql`. | ||
Synopsis:: | ||
# Install prerequisites. | ||
pip install --upgrade langchain-cratedb langchain-community | ||
# Start database. | ||
docker run --rm -it --publish=4200:4200 crate/crate:nightly | ||
# Optionally set environment variable to configure CrateDB connection URL. | ||
export CRATEDB_SQLALCHEMY_URL="crate://crate@localhost/?schema=doc" | ||
# Run program. | ||
python examples/basic/document_loader.py | ||
""" # noqa: E501 | ||
# /// script | ||
# requires-python = ">=3.9" | ||
# dependencies = [ | ||
# "langchain-cratedb", | ||
# ] | ||
# /// | ||
|
||
import os | ||
from pprint import pprint | ||
|
||
import requests | ||
import sqlparse | ||
from langchain_community.utilities import SQLDatabase | ||
|
||
from langchain_cratedb import CrateDBLoader | ||
|
||
CRATEDB_SQLALCHEMY_URL = os.environ.get( | ||
"CRATEDB_SQLALCHEMY_URL", "crate://crate@localhost/?schema=testdrive" | ||
) | ||
|
||
|
||
def import_mlb_teams_2012() -> None: | ||
""" | ||
Import data into database table `mlb_teams_2012`. | ||
TODO: Refactor into general purpose package. | ||
""" | ||
db = SQLDatabase.from_uri(CRATEDB_SQLALCHEMY_URL) | ||
# TODO: Use new URL @ langchain-cratedb. | ||
url = "https://github.com/crate-workbench/langchain/raw/cratedb/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql" | ||
sql = requests.get(url).text | ||
for statement in sqlparse.split(sql): | ||
db.run(statement) | ||
db.run("REFRESH TABLE mlb_teams_2012") | ||
|
||
|
||
def main() -> None: | ||
# Load data. | ||
import_mlb_teams_2012() | ||
|
||
db = SQLDatabase.from_uri(CRATEDB_SQLALCHEMY_URL) | ||
|
||
# Query data. | ||
loader = CrateDBLoader( | ||
query="SELECT * FROM mlb_teams_2012 LIMIT 3;", | ||
db=db, | ||
include_rownum_into_metadata=True, | ||
) | ||
docs = loader.load() | ||
pprint(docs) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
""" | ||
Use CrateDB Vector Search with OpenAI embeddings. | ||
As input data, the example uses the canonical `state_of_the_union.txt`. | ||
Synopsis:: | ||
# Install prerequisites. | ||
pip install --upgrade langchain-cratedb langchain-openai | ||
# Start database. | ||
docker run --rm -it --publish=4200:4200 crate/crate:nightly | ||
# Configure: Set environment variables to configure OpenAI authentication token | ||
# and optionally CrateDB connection URL. | ||
export OPENAI_API_KEY="<API KEY>" | ||
export CRATEDB_SQLALCHEMY_URL="crate://crate@localhost/?schema=doc" | ||
# Run program. | ||
python examples/basic/vector_search.py | ||
""" # noqa: E501 | ||
# /// script | ||
# requires-python = ">=3.9" | ||
# dependencies = [ | ||
# "langchain-openai", | ||
# "langchain-cratedb", | ||
# ] | ||
# /// | ||
|
||
import os | ||
import typing as t | ||
|
||
import requests | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain_core.documents import Document | ||
from langchain_openai import OpenAIEmbeddings | ||
|
||
from langchain_cratedb import CrateDBVectorStore | ||
|
||
CRATEDB_SQLALCHEMY_URL = os.environ.get( | ||
"CRATEDB_SQLALCHEMY_URL", "crate://crate@localhost/?schema=testdrive" | ||
) | ||
|
||
|
||
def get_documents() -> t.List[Document]: | ||
""" | ||
Acquire data, return as LangChain documents. | ||
""" | ||
|
||
# Define text splitter. | ||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) | ||
|
||
# Load a document, and split it into chunks. | ||
url = "https://github.com/langchain-ai/langchain/raw/v0.0.325/docs/docs/modules/state_of_the_union.txt" | ||
text = requests.get(url).text | ||
return text_splitter.create_documents([text]) | ||
|
||
|
||
def main() -> None: | ||
# Acquire documents. | ||
documents = get_documents() | ||
|
||
# Embed each chunk, and load them into the vector store. | ||
vector_store = CrateDBVectorStore.from_documents( | ||
documents, OpenAIEmbeddings(), connection=CRATEDB_SQLALCHEMY_URL | ||
) | ||
|
||
# Invoke a query, and display the first result. | ||
query = "What did the president say about Ketanji Brown Jackson" | ||
docs = vector_store.similarity_search(query) | ||
print(docs[0].page_content) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,50 @@ | ||
import typing as t | ||
from pathlib import Path | ||
|
||
import openai | ||
import pytest | ||
from _pytest.capture import CaptureFixture | ||
from _pytest.fixtures import FixtureRequest | ||
from _pytest.python import Metafunc | ||
|
||
from tests.util.pytest import run_module_function | ||
from tests.util.python import generate_file_tests, list_python_files | ||
|
||
ROOT = Path(__file__).parent.parent | ||
EXAMPLES_FOLDER = ROOT / "examples" | ||
|
||
|
||
def test_dummy(request: FixtureRequest, capsys: CaptureFixture) -> None: | ||
outcome = run_module_function( | ||
request=request, filepath=EXAMPLES_FOLDER / "dummy.py" | ||
) | ||
assert isinstance(outcome[0], Path) | ||
assert outcome[0].name == "dummy.py" | ||
assert outcome[1] == 0 | ||
assert outcome[2] == "test_dummy.main" | ||
# Configure example programs to skip testing. | ||
SKIP_FILES: t.List[str] = [] | ||
|
||
|
||
def test_dummy(run_file: t.Callable, capsys: CaptureFixture) -> None: | ||
run_file(EXAMPLES_FOLDER / "dummy.py") | ||
out, err = capsys.readouterr() | ||
assert out == "Hallo, Räuber Hotzenplotz.\n" | ||
|
||
|
||
def pytest_generate_tests(metafunc: Metafunc) -> None: | ||
""" | ||
Generate pytest test case per example program. | ||
""" | ||
examples_root = EXAMPLES_FOLDER | ||
file_paths = list_python_files(examples_root) | ||
generate_file_tests(metafunc, file_paths=file_paths) | ||
|
||
|
||
def test_file(run_file: t.Callable, file: Path) -> None: | ||
""" | ||
Execute Python code, one test case per .py file. | ||
Skip test cases that trip when no OpenAI API key is configured. | ||
""" | ||
if file.name in SKIP_FILES: | ||
raise pytest.skip(f"FIXME: Skipping file: {file.name}") | ||
try: | ||
run_file(file) | ||
except openai.OpenAIError as ex: | ||
if "The api_key client option must be set" not in str(ex): | ||
raise | ||
else: | ||
raise pytest.skip( | ||
"Skipping test because `OPENAI_API_KEY` is not defined" | ||
) from ex |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# TODO: Vendored from `pueblo.testing.snippet`. | ||
# Reason: It trips with an import error after just installing it. | ||
# ImportError: cannot import name 'FixtureDef' from 'pytest' | ||
# https://github.com/pyveci/pueblo/issues/129 | ||
import typing as t | ||
from os import PathLike | ||
from pathlib import Path | ||
|
||
import pytest | ||
from _pytest.fixtures import FixtureRequest | ||
from _pytest.python import Metafunc | ||
|
||
from tests.util.pytest import run_module_function | ||
|
||
|
||
def list_python_files(path: Path) -> t.List[Path]: | ||
""" | ||
Enumerate all Python files found in given directory, recursively. | ||
""" | ||
return list(path.rglob("*.py")) | ||
|
||
|
||
def generate_file_tests( | ||
metafunc: Metafunc, file_paths: t.List[Path], fixture_name: str = "file" | ||
) -> None: | ||
""" | ||
Generate test cases for Python example programs. | ||
""" | ||
if fixture_name in metafunc.fixturenames: | ||
names = [nb_path.name for nb_path in file_paths] | ||
metafunc.parametrize(fixture_name, file_paths, ids=names) | ||
|
||
|
||
@pytest.fixture | ||
def run_file(request: FixtureRequest) -> t.Callable: | ||
""" | ||
Invoke Python example programs as pytest test cases. This fixture is a factory | ||
that returns a function that can be used for invocation. | ||
TODO: Wrap outcome into a better shape. | ||
At least, use a dictionary, optimally an object. | ||
""" | ||
|
||
def _runner( | ||
path: Path, | ||
) -> t.Tuple[t.Union[PathLike[str], str], t.Union[int, None], str]: | ||
outcome = run_module_function(request=request, filepath=path) | ||
assert isinstance(outcome[0], Path) | ||
return outcome | ||
|
||
return _runner |