# Review notes for the patch to nano_graphrag/_llm.py (commentary only; the
# patch text that follows is unchanged).
#
# Summary: Azure OpenAI embeddings get their own cached client, separate from
# the one returned by get_azure_openai_async_client_instance().
#
# What the hunks do:
#   1. Import HttpResponseError from azure.core.exceptions.
#   2. Add the module-level global global_azure_openai_emb_async_client,
#      initialised to None.
#   3. Add get_azure_openai_emb_async_client_instance(): when the global is
#      None it builds
#      AsyncAzureOpenAI(azure_endpoint=os.environ["AZURE_ENDPOINT_EMB"]),
#      stores it in the global, and returns the global.
#   4. azure_openai_embedding() obtains its client from the new getter instead
#      of get_azure_openai_async_client_instance(). The embeddings.create()
#      arguments (model, input, encoding_format) and the np.array(...) return
#      expression are the same on both sides of the diff.
#
# NOTE(review): HttpResponseError is not referenced in any hunk shown here, and
#   the retry predicate visible above azure_openai_embedding still lists only
#   RateLimitError and APIConnectionError - confirm whether it was meant to be
#   added there or is used elsewhere in the file.
# NOTE(review): os.environ["AZURE_ENDPOINT_EMB"] is a subscript lookup with no
#   fallback; assuming `os` is the standard-library module (its import is not
#   visible in these hunks), the first call raises KeyError when the variable
#   is unset - confirm that is the intended failure mode.
# NOTE(review): the removed and added `model="text-embedding-3-small", ...`
#   lines are textually identical in this copy, so that change is presumably
#   whitespace-only - confirm and drop it if so.
# NOTE(review): going by the hunk header (-287,8 +295,11), three of the added
#   lines in azure_openai_embedding appear to be empty lines - confirm they
#   are wanted.
# NOTE(review): in this copy the patch's line breaks and indentation are
#   collapsed onto a single line; they presumably need to be restored before
#   the patch can be applied.
diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index 974c339..979c2f7 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -4,6 +4,7 @@ import aioboto3 from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError +from azure.core.exceptions import HttpResponseError from tenacity import ( retry, @@ -19,6 +20,7 @@ global_openai_async_client = None global_azure_openai_async_client = None global_amazon_bedrock_async_client = None +global_azure_openai_emb_async_client = None def get_openai_async_client_instance(): @@ -34,6 +36,12 @@ def get_azure_openai_async_client_instance(): global_azure_openai_async_client = AsyncAzureOpenAI() return global_azure_openai_async_client +def get_azure_openai_emb_async_client_instance(): + global global_azure_openai_emb_async_client + if global_azure_openai_emb_async_client is None: + global_azure_openai_emb_async_client = AsyncAzureOpenAI(azure_endpoint=os.environ["AZURE_ENDPOINT_EMB"]) + return global_azure_openai_emb_async_client + def get_amazon_bedrock_async_client_instance(): global global_amazon_bedrock_async_client @@ -287,8 +295,11 @@ async def azure_gpt_4o_mini_complete( retry=retry_if_exception_type((RateLimitError, APIConnectionError)), ) async def azure_openai_embedding(texts: list[str]) -> np.ndarray: - azure_openai_client = get_azure_openai_async_client_instance() + + azure_openai_client = get_azure_openai_emb_async_client_instance() + response = await azure_openai_client.embeddings.create( - model="text-embedding-3-small", input=texts, encoding_format="float" + model="text-embedding-3-small", input=texts, encoding_format="float" ) + return np.array([dp.embedding for dp in response.data])