Skip to content

Commit

Permalink
Fix: Correct environment variable handling for limited Azure OpenAI G…
Browse files Browse the repository at this point in the history
…PT and embedding access by properly differentiating endpoints.
  • Loading branch information
CatilonyZhang committed Dec 17, 2024
1 parent 18fa3a4 commit 86ea920
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions nano_graphrag/_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import aioboto3
from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
from azure.core.exceptions import HttpResponseError

from tenacity import (
retry,
Expand All @@ -19,6 +20,7 @@
global_openai_async_client = None
global_azure_openai_async_client = None
global_amazon_bedrock_async_client = None
global_azure_openai_emb_async_client = None


def get_openai_async_client_instance():
Expand All @@ -34,6 +36,12 @@ def get_azure_openai_async_client_instance():
global_azure_openai_async_client = AsyncAzureOpenAI()
return global_azure_openai_async_client

def get_azure_openai_emb_async_client_instance():
global global_azure_openai_emb_async_client
if global_azure_openai_emb_async_client is None:
global_azure_openai_emb_async_client = AsyncAzureOpenAI(azure_endpoint=os.environ["AZURE_ENDPOINT_EMB"])
return global_azure_openai_emb_async_client


def get_amazon_bedrock_async_client_instance():
global global_amazon_bedrock_async_client
Expand Down Expand Up @@ -287,8 +295,11 @@ async def azure_gpt_4o_mini_complete(
retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
azure_openai_client = get_azure_openai_async_client_instance()

azure_openai_client = get_azure_openai_emb_async_client_instance()

response = await azure_openai_client.embeddings.create(
model="text-embedding-3-small", input=texts, encoding_format="float"
model="text-embedding-3-small", input=texts, encoding_format="float"
)

return np.array([dp.embedding for dp in response.data])

0 comments on commit 86ea920

Please sign in to comment.