feat: add RedisKVStorage, MilvusVectorStorage, and NebulaGraphStorage #71

Open · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -44,6 +44,7 @@ jobs:
       - name: Build and Test
         env:
           NANO_GRAPHRAG_TEST_IGNORE_NEO4J: true
+          NANO_GRAPHRAG_TEST_IGNORE_NEBULA: true
         run: |
           python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./
       - name: Check codecov file
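The new variable mirrors the existing `NANO_GRAPHRAG_TEST_IGNORE_NEO4J` switch, so the Nebula tests presumably skip themselves when it is set. A minimal sketch of such a guard, assuming a hypothetical test module name:

```python
# tests/test_nebula_storage.py (hypothetical) -- skip the whole module when
# the CI environment opts out of NebulaGraph-backed tests
import os

import pytest

if os.environ.get("NANO_GRAPHRAG_TEST_IGNORE_NEBULA"):
    pytest.skip("NebulaGraph service not available", allow_module_level=True)
```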
3 changes: 3 additions & 0 deletions nano_graphrag/_storage/__init__.py
@@ -3,3 +3,6 @@
 from .vdb_hnswlib import HNSWVectorStorage
 from .vdb_nanovectordb import NanoVectorDBStorage
 from .kv_json import JsonKVStorage
+from .kv_redis import RedisKVStorage
+from .vdb_milvus import MilvusVectorStorage
Owner (review comment on the line above): Maybe remove this line, so Milvus is not imported by default?
+from .gdb_nebula import NebulaGraphStorage
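One way to act on the comment above without losing the public symbol is a module-level `__getattr__` (PEP 562), so `pymilvus` is only imported when `MilvusVectorStorage` is actually referenced. This is a sketch of the suggestion, not what the PR currently does:

```python
# nano_graphrag/_storage/__init__.py -- lazy-import variant (sketch)
def __getattr__(name: str):
    if name == "MilvusVectorStorage":
        # Deferred import: pymilvus is only loaded on first access
        from .vdb_milvus import MilvusVectorStorage
        return MilvusVectorStorage
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```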
974 changes: 974 additions & 0 deletions nano_graphrag/_storage/gdb_nebula.py

Large diffs are not rendered by default.

83 changes: 83 additions & 0 deletions nano_graphrag/_storage/kv_redis.py
@@ -0,0 +1,83 @@
import json
from dataclasses import dataclass, field

import redis
from redis.exceptions import ConnectionError

from ..base import BaseKVStorage
from .._utils import get_workdir_last_folder_name, logger


@dataclass
class RedisKVStorage(BaseKVStorage):
    _redis: redis.Redis = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        try:
            host = self.global_config["addon_params"].get("redis_host", "localhost")
            port = int(self.global_config["addon_params"].get("redis_port", 6379))
            user = self.global_config["addon_params"].get("redis_user", None)
            password = self.global_config["addon_params"].get("redis_password", None)
            db = self.global_config["addon_params"].get("redis_db", 0)
            self._redis = redis.Redis(host=host, port=port, username=user, password=password, db=db)
            self._redis.ping()
            logger.info(f"Connected to Redis at {host}:{port}")
        except ConnectionError:
            logger.error(f"Failed to connect to Redis at {host}:{port}")
            raise

        self._namespace = f"kv_store_{get_workdir_last_folder_name(self.global_config['working_dir'])}"
        logger.info(f"Initialized Redis KV storage for namespace: {self._namespace}")

    async def all_keys(self) -> list[str]:
        # Keys are stored as "<namespace>:<id>"; strip the namespace prefix
        return [key.decode().split(":", 1)[1] for key in self._redis.keys(f"{self._namespace}:*")]

    async def index_done_callback(self):
        # Redis persists data itself, so no explicit flush is needed
        pass

    async def get_by_id(self, id):
        value = self._redis.get(f"{self._namespace}:{id}")
        return json.loads(value) if value else None

    async def get_by_ids(self, ids, fields=None):
        # Pipeline the GETs so all ids are fetched in a single round trip
        pipeline = self._redis.pipeline()
        for id in ids:
            pipeline.get(f"{self._namespace}:{id}")
        values = pipeline.execute()

        results = []
        for value in values:
            if value:
                data = json.loads(value)
                if fields:
                    results.append({k: v for k, v in data.items() if k in fields})
                else:
                    results.append(data)
            else:
                results.append(None)
        return results

    async def filter_keys(self, data: list[str]) -> set[str]:
        # Return the subset of keys that are NOT already present in Redis
        pipeline = self._redis.pipeline()
        for key in data:
            pipeline.exists(f"{self._namespace}:{key}")
        flags = pipeline.execute()
        return {key for key, found in zip(data, flags) if not found}

    async def upsert(self, data: dict[str, dict]):
        pipeline = self._redis.pipeline()
        for key, value in data.items():
            pipeline.set(f"{self._namespace}:{key}", json.dumps(value, ensure_ascii=False))
        pipeline.execute()

    async def drop(self):
        keys = self._redis.keys(f"{self._namespace}:*")
        if keys:
            self._redis.delete(*keys)

    def __getstate__(self):
        # The live client holds sockets and cannot be pickled; drop it here
        state = self.__dict__.copy()
        del state["_redis"]
        return state

    def __setstate__(self, state):
        # Rebuild the connection after unpickling
        self.__dict__.update(state)
        self.__post_init__()
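The `__getstate__`/`__setstate__` pair lets instances survive pickling by dropping the live client and reconnecting on load. A minimal round-trip sketch, assuming a reachable local Redis and the constructor arguments that `GraphRAG` passes to its KV storages:

```python
import pickle

from nano_graphrag._storage import RedisKVStorage

# Hypothetical construction; GraphRAG normally supplies these arguments
storage = RedisKVStorage(
    namespace="full_docs",
    global_config={"working_dir": "./my_graph", "addon_params": {}},
)

# The redis.Redis client is dropped on pickle and rebuilt by __post_init__
clone = pickle.loads(pickle.dumps(storage))
```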
86 changes: 86 additions & 0 deletions nano_graphrag/_storage/vdb_milvus.py
@@ -0,0 +1,86 @@
import asyncio
import os
from dataclasses import dataclass

import numpy as np
from pymilvus import MilvusClient

from .._utils import get_workdir_last_folder_name, logger
from ..base import BaseVectorStorage


@dataclass
class MilvusVectorStorage(BaseVectorStorage):

    @staticmethod
    def create_collection_if_not_exist(
        client, collection_name: str, max_id_length: int, dimension: int, **kwargs
    ):
        if client.has_collection(collection_name):
            return
        client.create_collection(
            collection_name, max_length=max_id_length, id_type="string", dimension=dimension, **kwargs
        )

    def __post_init__(self):
        # The collection name is needed on both the remote and the Lite paths
        self.collection_name = get_workdir_last_folder_name(self.global_config["working_dir"])
        self.milvus_uri = self.global_config["addon_params"].get("milvus_uri", "")
        if self.milvus_uri:
            self.milvus_user = self.global_config["addon_params"].get("milvus_user", "")
            self.milvus_password = self.global_config["addon_params"].get("milvus_password", "")
            self._client = MilvusClient(self.milvus_uri, self.milvus_user, self.milvus_password)
        else:
            # No URI configured: fall back to embedded Milvus Lite in the working dir
            self._client_file_name = os.path.join(
                self.global_config["working_dir"], "milvus_lite.db"
            )
            self._client = MilvusClient(self._client_file_name)

        self.cosine_better_than_threshold: float = 0.2
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self.max_id_length = 256
        MilvusVectorStorage.create_collection_if_not_exist(
            self._client,
            self.collection_name,
            max_id_length=self.max_id_length,
            dimension=self.embedding_func.embedding_dim,
        )

    async def upsert(self, data: dict[str, dict]):
        logger.info(f"Inserting {len(data)} vectors to {self.collection_name}")
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        # Embed in batches to respect the configured embedding batch size
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = embeddings[i]
        # Upsert in chunks of 1024 to keep individual requests small
        batch_size = 1024
        results = []
        for i in range(0, len(list_data), batch_size):
            batch = list_data[i : i + batch_size]
            batch_result = self._client.upsert(collection_name=self.collection_name, data=batch)
            results.append(batch_result)

        total_upsert_count = sum(result.get("upsert_count", 0) for result in results)
        return {"upsert_count": total_upsert_count}

    async def query(self, query: str, top_k=5):
        embedding = await self.embedding_func([query])
        # "radius" drops hits whose cosine similarity falls below the threshold
        results = self._client.search(
            collection_name=self.collection_name,
            data=embedding,
            limit=top_k,
            output_fields=list(self.meta_fields),
            search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}},
        )
        return [
            {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
            for dp in results[0]
        ]
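As configured above, the backend is chosen by `addon_params`: with `milvus_uri` set, the storage connects to a Milvus server; without it, it falls back to an embedded `milvus_lite.db` file in the working directory. A wiring sketch, with placeholder URI and credentials:

```python
from nano_graphrag import GraphRAG
from nano_graphrag._storage import MilvusVectorStorage

# Sketch only: addon_params keys follow the .get(...) calls in this PR
rag = GraphRAG(
    working_dir="./my_graph",
    vector_db_storage_cls=MilvusVectorStorage,
    addon_params={
        "milvus_uri": "http://localhost:19530",  # omit to use Milvus Lite
        "milvus_user": "root",
        "milvus_password": "secret",
    },
)
```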
3 changes: 3 additions & 0 deletions nano_graphrag/_utils.py
@@ -134,6 +134,9 @@ def list_of_list_to_csv(data: list[list]):
         ]
     )

+def get_workdir_last_folder_name(workdir: str) -> str:
+    return os.path.basename(os.path.normpath(workdir))
+

 # -----------------------------------------------------------------------------------
 # Refer the utils functions of the official GraphRAG implementation:
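The helper normalizes away any trailing separator before taking the basename, so both forms of a working-directory path map to the same storage namespace:

```python
from nano_graphrag._utils import get_workdir_last_folder_name

assert get_workdir_last_folder_name("/data/projects/my_graph") == "my_graph"
assert get_workdir_last_folder_name("/data/projects/my_graph/") == "my_graph"
```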
3 changes: 2 additions & 1 deletion nano_graphrag/graphrag.py
@@ -122,6 +122,7 @@ class GraphRAG:
     vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
     graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
+    community_report_storage_cls: Type[BaseKVStorage] = JsonKVStorage
     enable_llm_cache: bool = True

     # extension
@@ -165,7 +166,7 @@ def __post_init__(self):
             else None
         )

-        self.community_reports = self.key_string_value_json_storage_cls(
+        self.community_reports = self.community_report_storage_cls(
             namespace="community_reports", global_config=asdict(self)
         )
         self.chunk_entity_relation_graph = self.graph_storage_cls(
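Taken together, the new field makes the community-report store pluggable alongside the existing vector and graph storage knobs. A hypothetical end-to-end wiring of the three backends this PR adds; NebulaGraph's connection settings are not visible in this diff, so none are passed here:

```python
from nano_graphrag import GraphRAG
from nano_graphrag._storage import (
    MilvusVectorStorage,
    NebulaGraphStorage,
    RedisKVStorage,
)

# Sketch only: Milvus/Redis addon_params follow the .get(...) keys in this PR
rag = GraphRAG(
    working_dir="./my_graph",
    vector_db_storage_cls=MilvusVectorStorage,
    graph_storage_cls=NebulaGraphStorage,
    community_report_storage_cls=RedisKVStorage,
    addon_params={"milvus_uri": "http://localhost:19530"},
)
```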
6 changes: 5 additions & 1 deletion requirements.txt
@@ -8,4 +8,8 @@ hnswlib
 xxhash
 tenacity
 dspy-ai
-neo4j
+neo4j
Owner (review comment on `pymilvus` below): Remove the milvus dep, so those who want to use this component can handle the cross-platform install themselves. Let's not introduce install problems for everyone else.
+pymilvus
+redis
+nebula3-python
+ng_nx
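If `pymilvus` is dropped from requirements as requested, `vdb_milvus.py` would need to fail with an actionable message when the optional dependency is missing. A sketch of a guarded import:

```python
# Top of nano_graphrag/_storage/vdb_milvus.py (sketch)
try:
    from pymilvus import MilvusClient
except ImportError as exc:
    raise ImportError(
        "MilvusVectorStorage requires the optional `pymilvus` package; "
        "install it with `pip install pymilvus`"
    ) from exc
```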