-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Switching to official Ollama python client
- Loading branch information
Showing
3 changed files
with
65 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,43 @@ | ||
import os | ||
|
||
import pytest | ||
import httpx | ||
from httpx import HTTPError, ConnectError | ||
|
||
from chromadb.utils.embedding_functions import OllamaEmbeddingFunction | ||
|
||
|
||
def test_ollama() -> None: | ||
""" | ||
To set up the Ollama server, follow instructions at: https://github.com/ollama/ollama?tab=readme-ov-file | ||
Export the OLLAMA_SERVER_URL and OLLAMA_MODEL environment variables. | ||
""" | ||
if ( | ||
os.environ.get("OLLAMA_SERVER_URL") is None | ||
or os.environ.get("OLLAMA_MODEL") is None | ||
): | ||
pytest.skip( | ||
"OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test." | ||
) | ||
try: | ||
response = httpx.get(os.environ.get("OLLAMA_SERVER_URL", "")) | ||
# If the response was successful, no Exception will be raised | ||
response.raise_for_status() | ||
except (HTTPError, ConnectError): | ||
pytest.skip("Ollama server not running. Skipping test.") | ||
ef = OllamaEmbeddingFunction( | ||
model_name=os.environ.get("OLLAMA_MODEL") or "nomic-embed-text", | ||
url=f"{os.environ.get('OLLAMA_SERVER_URL')}/embeddings", | ||
) | ||
|
||
from chromadb.utils.embedding_functions.ollama_embedding_function import ( | ||
OllamaEmbeddingFunction, | ||
) | ||
|
||
|
||
def test_ollama_default_model() -> None: | ||
pytest.importorskip("ollama", reason="ollama not installed") | ||
ef = OllamaEmbeddingFunction() | ||
embeddings = ef(["Here is an article about llamas...", "this is another article"]) | ||
assert embeddings is not None | ||
assert len(embeddings) == 2 | ||
assert all(len(e) == 384 for e in embeddings) | ||
|
||
|
||
def test_ollama_unknown_model() -> None: | ||
pytest.importorskip("ollama", reason="ollama not installed") | ||
model_name = "unknown-model" | ||
ef = OllamaEmbeddingFunction(model_name=model_name) | ||
with pytest.raises(Exception) as e: | ||
ef(["Here is an article about llamas...", "this is another article"]) | ||
assert f'model "{model_name}" not found' in str(e.value) | ||
|
||
|
||
def test_ollama_wrong_base_url() -> None: | ||
pytest.importorskip("ollama", reason="ollama not installed") | ||
ef = OllamaEmbeddingFunction(url="http://localhost:12345") | ||
with pytest.raises(Exception) as e: | ||
ef(["Here is an article about llamas...", "this is another article"]) | ||
assert "Connection refused" in str(e.value) | ||
|
||
|
||
def test_ollama_ask_user_to_install() -> None: | ||
try: | ||
from ollama import Client # noqa: F401 | ||
except ImportError: | ||
pass | ||
else: | ||
pytest.skip("ollama python package is installed") | ||
with pytest.raises(ValueError) as e: | ||
OllamaEmbeddingFunction() | ||
assert "The ollama python package is not installed" in str(e.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters