Skip to content

Commit

Permalink
feat: Switching to official Ollama python client
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Oct 13, 2024
1 parent 3ce0afd commit c7ee067
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 62 deletions.
70 changes: 40 additions & 30 deletions chromadb/test/ef/test_ollama_ef.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,43 @@
import os

import pytest
import httpx
from httpx import HTTPError, ConnectError

from chromadb.utils.embedding_functions import OllamaEmbeddingFunction


def test_ollama() -> None:
    """
    To set up the Ollama server, follow instructions at: https://github.com/ollama/ollama?tab=readme-ov-file
    Export the OLLAMA_SERVER_URL and OLLAMA_MODEL environment variables.
    """
    # Skip (rather than fail) when the test environment is not configured.
    if (
        os.environ.get("OLLAMA_SERVER_URL") is None
        or os.environ.get("OLLAMA_MODEL") is None
    ):
        pytest.skip(
            "OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test."
        )
    # Probe the configured server; any HTTP/connection failure means it is
    # not reachable, so the test is skipped instead of producing a false failure.
    try:
        response = httpx.get(os.environ.get("OLLAMA_SERVER_URL", ""))
        # If the response was successful, no Exception will be raised
        response.raise_for_status()
    except (HTTPError, ConnectError):
        pytest.skip("Ollama server not running. Skipping test.")
    # Build the embedding function against the server's /embeddings endpoint,
    # falling back to "nomic-embed-text" if OLLAMA_MODEL is empty.
    ef = OllamaEmbeddingFunction(
        model_name=os.environ.get("OLLAMA_MODEL") or "nomic-embed-text",
        url=f"{os.environ.get('OLLAMA_SERVER_URL')}/embeddings",
    )
    # NOTE(review): the visible span ends after constructing `ef`; the
    # assertions exercising it presumably follow — confirm against the full file.

from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)


def test_ollama_default_model() -> None:
    """Embedding two documents with the default model yields two 384-dim vectors."""
    pytest.importorskip("ollama", reason="ollama not installed")
    docs = ["Here is an article about llamas...", "this is another article"]
    result = OllamaEmbeddingFunction()(docs)
    assert result is not None
    assert len(result) == 2
    for vector in result:
        assert len(vector) == 384


def test_ollama_unknown_model() -> None:
    """Requesting a model the server does not have surfaces a not-found error."""
    pytest.importorskip("ollama", reason="ollama not installed")
    bad_model = "unknown-model"
    embed = OllamaEmbeddingFunction(model_name=bad_model)
    with pytest.raises(Exception) as excinfo:
        embed(["Here is an article about llamas...", "this is another article"])
    assert f'model "{bad_model}" not found' in str(excinfo.value)


def test_ollama_wrong_base_url() -> None:
    """Pointing the client at a closed port fails with a connection error."""
    pytest.importorskip("ollama", reason="ollama not installed")
    embed = OllamaEmbeddingFunction(url="http://localhost:12345")
    with pytest.raises(Exception) as excinfo:
        embed(["Here is an article about llamas...", "this is another article"])
    assert "Connection refused" in str(excinfo.value)


def test_ollama_ask_user_to_install() -> None:
    """Without the ollama package, constructing the EF raises a helpful ValueError."""
    ollama_installed = True
    try:
        from ollama import Client  # noqa: F401
    except ImportError:
        ollama_installed = False
    if ollama_installed:
        # The error path under test only exists when the package is absent.
        pytest.skip("ollama python package is installed")
    with pytest.raises(ValueError) as excinfo:
        OllamaEmbeddingFunction()
    assert "The ollama python package is not installed" in str(excinfo.value)
47 changes: 21 additions & 26 deletions chromadb/utils/embedding_functions/ollama_embedding_function.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
import logging
from typing import Union, cast

import httpx
from typing import Union, cast, Optional

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "chroma/all-minilm-l6-v2-f32"


class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    Generates embeddings for a list of texts using the official Ollama
    python client (https://pypi.org/project/ollama/).
    """

    def __init__(
        self, url: Optional[str] = None, model_name: Optional[str] = DEFAULT_MODEL_NAME
    ) -> None:
        """
        Initialize the Ollama Embedding Function.

        Args:
            url (Optional[str]): The base URL of the Ollama server. When None,
                the ollama client falls back to its default host
                (http://localhost:11434).
            model_name (Optional[str]): The name of the model to use for text
                embeddings, e.g. "nomic-embed-text" (defaults to
                "chroma/all-minilm-l6-v2-f32"; see https://ollama.com/library
                for available models).

        Raises:
            ValueError: If the ollama python package is not installed.
        """
        try:
            # Imported lazily so that chromadb itself does not hard-require
            # the ollama package.
            from ollama import Client
        except ImportError:
            raise ValueError(
                "The ollama python package is not installed. Please install it with `pip install ollama`"
            )
        self._client = Client(host=url)
        # Guard against callers passing an explicit model_name=None.
        self._model_name = model_name or DEFAULT_MODEL_NAME

    def __call__(self, input: Union[Documents, str]) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Union[Documents, str]): A list of texts (or a single text)
                to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts.

        Example:
            >>> ollama_ef = OllamaEmbeddingFunction()
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = ollama_ef(texts)
        """
        # The official client's embed() accepts either a single string or a
        # list of strings, so `input` can be forwarded as-is.
        response = self._client.embed(model=self._model_name, input=input)
        return cast(Embeddings, response["embeddings"])
10 changes: 4 additions & 6 deletions docs/docs.trychroma.com/pages/integrations/ollama.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@
title: Ollama Embeddings
---

Chroma provides a convenient wrapper around [Ollama's](https://github.com/ollama/ollama) [python client](https://pypi.org/project/ollama/). You can use
the `OllamaEmbeddingFunction` embedding function to generate embeddings for your documents with
a [model](https://github.com/ollama/ollama?tab=readme-ov-file#model-library) of your choice.

{% tabs group="code-lang" %}
{% tab label="Python" %}

```python
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction

ollama_ef = OllamaEmbeddingFunction(
    model_name="chroma/all-minilm-l6-v2-f32",
)

embeddings = ollama_ef(["This is my first text to embed",
Expand Down

0 comments on commit c7ee067

Please sign in to comment.