From 9e4f0ba63da473fdbfd68363be8ee68742991cc2 Mon Sep 17 00:00:00 2001 From: utopia2077 Date: Mon, 7 Oct 2024 04:05:19 +0800 Subject: [PATCH 1/3] feat: add RedisKVStorage, MilvusVectorStorage, and NebulaGraphStorage --- .github/workflows/test.yml | 1 + nano_graphrag/_storage/__init__.py | 3 + nano_graphrag/_storage/gdb_nebula.py | 985 +++++++++++++++++++++++++++ nano_graphrag/_storage/kv_redis.py | 83 +++ nano_graphrag/_storage/vdb_milvus.py | 86 +++ nano_graphrag/_utils.py | 3 + nano_graphrag/graphrag.py | 3 +- tests/test_nebula_storage.py | 315 +++++++++ 8 files changed, 1478 insertions(+), 1 deletion(-) create mode 100644 nano_graphrag/_storage/gdb_nebula.py create mode 100644 nano_graphrag/_storage/kv_redis.py create mode 100644 nano_graphrag/_storage/vdb_milvus.py create mode 100644 tests/test_nebula_storage.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e975b44..048f3c3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,6 +44,7 @@ jobs: - name: Build and Test env: NANO_GRAPHRAG_TEST_IGNORE_NEO4J: true + NANO_GRAPHRAG_TEST_IGNORE_NEBULA: true run: | python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./ - name: Check codecov file diff --git a/nano_graphrag/_storage/__init__.py b/nano_graphrag/_storage/__init__.py index c8184ab..8a06881 100644 --- a/nano_graphrag/_storage/__init__.py +++ b/nano_graphrag/_storage/__init__.py @@ -3,3 +3,6 @@ from .vdb_hnswlib import HNSWVectorStorage from .vdb_nanovectordb import NanoVectorDBStorage from .kv_json import JsonKVStorage +from .kv_redis import RedisKVStorage +from .vdb_milvus import MilvusVectorStorage +from .gdb_nebula import NebulaGraphStorage diff --git a/nano_graphrag/_storage/gdb_nebula.py b/nano_graphrag/_storage/gdb_nebula.py new file mode 100644 index 0000000..5472d50 --- /dev/null +++ b/nano_graphrag/_storage/gdb_nebula.py @@ -0,0 +1,985 @@ +import asyncio +from contextlib import contextmanager +import json +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Optional, Union, Literal +import networkx as nx +from ng_nx import NebulaReader +import numpy as np +from networkx.classes.reportviews import NodeView +from nano_graphrag._storage.gdb_networkx import NetworkXStorage +from nano_graphrag.prompt import GRAPH_FIELD_SEP +from ng_nx import NebulaWriter +from ng_nx.utils import NebulaGraphConfig + +from .._utils import get_workdir_last_folder_name, logger +from ..base import ( + BaseGraphStorage, + BaseKVStorage, + SingleCommunitySchema, +) + +@dataclass +class NebulaGraphStorage(BaseGraphStorage,BaseKVStorage): + # TODO: consider add @RANK to enable multi-edge per src-->tgt + + # NOTE: client can only be used when space exists. + client: Any = None # lazy dependency, thus Any typed here. 
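+    # It is populated by _initialize_session_pool() in __post_init__, once the space and schema exist.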
+ + # ng_nx + config: Any = None + reader: Any = None + writer_cls: Any = None + _graph: Any = None + _graph_homogenous: Any = None + + # DEFAULT SCHEMA + VID_LENGTH: int = 256 + # We are using a Homogeneous Graph Model for Graph Index + INIT_EDGE_TYPE: str = "RELATED_TO" + INIT_EDGE_PROPERTIES: list[dict[str, str]] = field(default_factory=lambda: [ + {"name": "weight", "type": {"type": "float"}, "DEFAULT": 0.0}, + {"name": "description", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "order", "type": {"type": "int"}, "DEFAULT": 1}, + {"name": "source_id", "type": {"type": "string"}, "DEFAULT": "''"}, + ]) + + INIT_EDGE_INDEXES: list[dict[str, str]] = field(default_factory=lambda: [ + {"index_name": "relation_index", "fields_str": "()"}, + ]) + + INIT_VERTEX_TYPE: str = "entity" + INIT_VERTEX_PROPERTIES: list[dict[str, str]] = field(default_factory=lambda: [ + {"name": "entity_name", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "entity_type", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "description", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "source_id", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "clusters", "type": {"type": "string"}, "DEFAULT": "''"}, + ]) + INIT_VERTEX_INDEXES: list[dict[str, str]] = field(default_factory=lambda: [ + {"index_name": "entity_index", "fields_str": "()"}, + ]) + + # Schema For Meta Knowledge Graph + # Community Report Vertex Type and Edge Type + COMMUNITY_VERTEX_TYPE: str = "community" + COMMUNITY_VERTEX_PROPERTIES: list[dict[str, str]] = field(default_factory=lambda: [ + {"name": "level", "type": {"type": "int"}, "DEFAULT": -1}, + {"name": "cluster", "type": {"type": "int"}, "DEFAULT": -1}, + {"name": "title", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "report_string", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "report_json", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "chunk_ids", "type": {"type": "string"}, "DEFAULT": "''"}, + {"name": "occurrence", "type": {"type": "float"}, "DEFAULT": 0.0}, + {"name": "sub_communities", "type": {"type": "string"}, "DEFAULT": "''"}, + ]) + COMMUNITY_VERTEX_INDEXES: list[dict[str, str]] = field(default_factory=lambda: [ + {"index_name": "community_vertex_index", "fields_str": "()"}, + ]) + COMMUNITY_EDGE_TYPE: str = "ENTITY_WITHIN_COMMUNITY" + COMMUNITY_EDGE_PROPERTIES: list[dict[str, str]] = field(default_factory=lambda: [ + {"name": "level", "type": {"type": "int"}, "DEFAULT": -1}, + ]) + COMMUNITY_EDGE_INDEXES: list[dict[str, str]] = field(default_factory=lambda: [ + {"index_name": "community_edge_index", "fields_str": "()"}, + ]) + + HEATBEAT_TIME: int = 10 + IMPLEMENTED_PARAM_TYPES: list[type] = field(default_factory=lambda: [str]) + INSERT_BATCH_SIZE: int = 64 + JSON_FIELDS: list[str] = field(default_factory=lambda: ["report_json", "chunk_ids", "sub_communities"]) + + @staticmethod + def _graph_exists(session: Any, space_name: str) -> bool: + try: + spaces = session.execute("SHOW SPACES;").column_values("Name") + return any(space_name == space.as_string() for space in spaces) + except Exception as e: + error_message = f"Failed to check if graph space '{space_name}' exists: {e}" + logger.error(error_message) + raise RuntimeError(error_message) from e + + @staticmethod + def _label_exists(session: Any, space_name: str, type: Literal["tag", "edge"], label: str): + TAGS = "TAGS" + EDGES = "EDGES" + try: + session.execute(f"USE {space_name};") + rest = session.execute(f"SHOW {TAGS if type == 'tag' else 
EDGES};").column_values("Name") + return any(label == l.as_string() for l in rest) + except Exception as e: + error_message = f"Failed to check if '{type}' exists in graph space '{space_name}': {e}" + logger.error(error_message) + raise RuntimeError(error_message) from e + + @staticmethod + def _index_exists(session: Any, space_name: str, type: Literal["tag", "edge"], index_name: str): + try: + session.execute(f"USE {space_name};") + rest = session.execute(f"SHOW {type} INDEXES;").column_values("Index Name") + return any(index_name == i.as_string() for i in rest) + except Exception as e: + error_message = f"Failed to check if index '{index_name}' exists in graph space '{space_name}': {e}" + logger.error(error_message) + raise RuntimeError(error_message) from e + + async def _create_graph(self, session: Any, space_name: str, delay_time: float = 20, vid_length: int = 256): + from nebula3.data.ResultSet import ResultSet + + session.execute( + f"CREATE SPACE IF NOT EXISTS {space_name} (replica_factor = 1, vid_type=FIXED_STRING({vid_length}));" + ) + + backoff_time_counter = 0 + attempt = 0 + while True: + attempt += 1 + result: ResultSet = session.execute( + f"DESCRIBE SPACE {space_name}; USE {space_name};" + ) + if result.is_succeeded(): + logger.info(f"Graph space {space_name} created successfully") + break + else: + if backoff_time_counter < delay_time: + backoff_time = 2**attempt + backoff_time_counter += backoff_time + await asyncio.sleep(backoff_time) + else: + session.release() + raise ValueError( + f"Graph Space {space_name} creation failed in {backoff_time_counter} seconds" + ) + + async def _init_nebulagraph_schema(self) -> None: + with self._session() as session: + # Ensure graph space exists + if not self._graph_exists(session, self.space): + logger.info(f"Creating graph space {self.space}") + await self._create_graph(session, self.space, vid_length=self.VID_LENGTH) + + # Ensure initial schema exists + create_flag = False + if not self._label_exists(session, self.space, "tag", self.INIT_VERTEX_TYPE): + prop_list = [f"`{prop['name']}` {prop['type']['type']} DEFAULT {prop['DEFAULT']}" for prop in self.INIT_VERTEX_PROPERTIES] + prop_str = ",".join(prop_list) + CREATE_TAG_QUERY = f"CREATE TAG IF NOT EXISTS `{self.INIT_VERTEX_TYPE}` ({prop_str})" + session.execute(CREATE_TAG_QUERY) + create_flag = True + if not self._label_exists(session, self.space, "edge", self.INIT_EDGE_TYPE): + prop_list = [f"`{prop['name']}` {prop['type']['type']} DEFAULT {prop['DEFAULT']}" for prop in self.INIT_EDGE_PROPERTIES] + prop_str = ",".join(prop_list) + CREATE_EDGE_QUERY = f"CREATE EDGE IF NOT EXISTS `{self.INIT_EDGE_TYPE}` ({prop_str})" + session.execute(CREATE_EDGE_QUERY) + create_flag = True + if not self._label_exists(session, self.space, "tag", self.COMMUNITY_VERTEX_TYPE): + prop_list = [f"`{prop['name']}` {prop['type']['type']} DEFAULT {prop['DEFAULT']}" for prop in self.COMMUNITY_VERTEX_PROPERTIES] + prop_str = ",".join(prop_list) + CREATE_TAG_QUERY = f"CREATE TAG IF NOT EXISTS `{self.COMMUNITY_VERTEX_TYPE}` ({prop_str})" + session.execute(CREATE_TAG_QUERY) + create_flag = True + if not self._label_exists(session, self.space, "edge", self.COMMUNITY_EDGE_TYPE): + prop_list = [f"`{prop['name']}` {prop['type']['type']} DEFAULT {prop['DEFAULT']}" for prop in self.COMMUNITY_EDGE_PROPERTIES] + prop_str = ",".join(prop_list) + CREATE_EDGE_QUERY = f"CREATE EDGE IF NOT EXISTS `{self.COMMUNITY_EDGE_TYPE}` ({prop_str})" + session.execute(CREATE_EDGE_QUERY) + create_flag = True + + if create_flag: + # Wait for 
schema changes to propagate + logger.info(f"Waiting {2 * self.HEATBEAT_TIME + 1}s for schema changes to propagate") + await asyncio.sleep(2 * self.HEATBEAT_TIME + 1) + + # Verify tag creation + if not self._label_exists(session, self.space, "tag", self.INIT_VERTEX_TYPE): + raise RuntimeError(f"Failed to create tag {self.INIT_VERTEX_TYPE}") + + # Verify edge creation + if not self._label_exists(session, self.space, "edge", self.INIT_EDGE_TYPE): + raise RuntimeError(f"Failed to create edge {self.INIT_EDGE_TYPE}") + + # Ensure Meta Knowledge Graph Schema + if not self._label_exists(session, self.space, "tag", self.COMMUNITY_VERTEX_TYPE): + raise RuntimeError(f"Failed to create tag {self.COMMUNITY_VERTEX_TYPE}") + if not self._label_exists(session, self.space, "edge", self.COMMUNITY_EDGE_TYPE): + raise RuntimeError(f"Failed to create edge {self.COMMUNITY_EDGE_TYPE}") + + logger.info(f"Successfully created initial schema for graph space {self.space}") + + # Ensure initial indexes exists + create_flag = False + for tag_index in self.INIT_VERTEX_INDEXES: + if not self._index_exists(session, self.space, "tag", tag_index['index_name']): + session.execute(f"CREATE TAG INDEX IF NOT EXISTS `{tag_index['index_name']}` ON `{self.INIT_VERTEX_TYPE}` {tag_index['fields_str']}") + logger.info(f"Created tag index {tag_index['index_name']} on {self.INIT_VERTEX_TYPE}") + create_flag = True + for edge_index in self.INIT_EDGE_INDEXES: + if not self._index_exists(session, self.space, "edge", edge_index['index_name']): + session.execute(f"CREATE EDGE INDEX IF NOT EXISTS `{edge_index['index_name']}` ON `{self.INIT_EDGE_TYPE}` {edge_index['fields_str']}") + logger.info(f"Created edge index {edge_index['index_name']} on {self.INIT_EDGE_TYPE}") + create_flag = True + # Ensure Meta Knowledge Graph Indexes + for tag_index in self.COMMUNITY_VERTEX_INDEXES: + if not self._index_exists(session, self.space, "tag", tag_index['index_name']): + session.execute(f"CREATE TAG INDEX IF NOT EXISTS `{tag_index['index_name']}` ON `{self.COMMUNITY_VERTEX_TYPE}` {tag_index['fields_str']}") + logger.info(f"Created tag index {tag_index['index_name']} on {self.COMMUNITY_VERTEX_TYPE}") + create_flag = True + for edge_index in self.COMMUNITY_EDGE_INDEXES: + if not self._index_exists(session, self.space, "edge", edge_index['index_name']): + session.execute(f"CREATE EDGE INDEX IF NOT EXISTS `{edge_index['index_name']}` ON `{self.COMMUNITY_EDGE_TYPE}` {edge_index['fields_str']}") + logger.info(f"Created edge index {edge_index['index_name']} on {self.COMMUNITY_EDGE_TYPE}") + create_flag = True + + if create_flag: + # Wait for index creation to complete + logger.info(f"Waiting {2 * self.HEATBEAT_TIME + 1}s for index creation to complete") + await asyncio.sleep(2 * self.HEATBEAT_TIME + 1) + + # Verify index creation + for tag_index in self.INIT_VERTEX_INDEXES: + if not self._index_exists(session, self.space, "tag", tag_index['index_name']): + raise RuntimeError(f"Failed to create tag index {tag_index['index_name']}") + for edge_index in self.INIT_EDGE_INDEXES: + if not self._index_exists(session, self.space, "edge", edge_index['index_name']): + raise RuntimeError(f"Failed to create edge index {edge_index['index_name']}") + + # Ensure Meta Knowledge Graph Indexes + for tag_index in self.COMMUNITY_VERTEX_INDEXES: + if not self._index_exists(session, self.space, "tag", tag_index['index_name']): + raise RuntimeError(f"Failed to create tag index {tag_index['index_name']}") + for edge_index in self.COMMUNITY_EDGE_INDEXES: + if not 
self._index_exists(session, self.space, "edge", edge_index['index_name']): + raise RuntimeError(f"Failed to create edge index {edge_index['index_name']}") + + def get_graphd_addresses(self) -> list[tuple[str, int]]: + graphd_host_address: list[tuple[str, int]] = [] + for host in self.graphd_hosts.split(","): + # sanity check + if ":" not in host: + raise ValueError(f"Invalid graphd host {host}, should be host:port") + host, port = host.split(":") + if not port.isdigit(): + raise ValueError(f"Invalid port {port} in host {host}") + if int(port) < 0 or int(port) > 65535: + raise ValueError(f"Invalid port {port} in host {host}") + if not host: + raise ValueError(f"Invalid host {host}") + graphd_host_address.append((host, int(port))) + return graphd_host_address + + @contextmanager + def _session(self) -> Any: + """ + Only used for space creation and schema initialization. + """ + from nebula3.Config import Config, SSL_config + from nebula3.gclient.net.ConnectionPool import ConnectionPool + + conn_pool = ConnectionPool() + graphd_host_address_list = self.get_graphd_addresses() + try: + + conn_pool.init( + graphd_host_address_list, + configs=Config(), + ssl_conf=SSL_config() if self.use_tls else None + ) + + session = conn_pool.get_session(self.username, self.password) + try: + yield session + finally: + session.release() + finally: + conn_pool.close() + + def _initialize_session_pool(self) -> None: + """ + Initialize and set up the session pool as a singleton. + The space is created and schema initialized before this method is called. + """ + from nebula3.gclient.net.SessionPool import SessionPool + from nebula3.Config import SessionPoolConfig, SSL_config + + graphd_host_address = self.get_graphd_addresses() + try: + session_pool = SessionPool( + self.username, + self.password, + self.space, + graphd_host_address + ) + + session_pool_config = SessionPoolConfig() + session_pool.init( + session_pool_config, + ssl_configs=SSL_config() if self.use_tls else None + ) + self.client: SessionPool = session_pool + except Exception as e: + raise RuntimeError(f"Failed to initialize session pool: {e}") from e + + def __post_init__(self): + self.space: str = get_workdir_last_folder_name(self.global_config["working_dir"]) + self.use_tls: bool = self.global_config["addon_params"].get("use_tls", False) + self.graphd_hosts: str = self.global_config["addon_params"].get("graphd_hosts", None) + self.metad_hosts: str = self.global_config["addon_params"].get("metad_hosts", None) + self.username: str = self.global_config["addon_params"].get("username", "root") + self.password: str = self.global_config["addon_params"].get("password", "nebula") + self.VID_LENGTH: int = 256 + + if not self.graphd_hosts or not self.metad_hosts: + raise ValueError("Missing required connection information: graphd_hosts and metad_hosts not provided") + + asyncio.run(self._init_nebulagraph_schema()) + + self.config = NebulaGraphConfig( + space=self.space, + graphd_hosts=self.graphd_hosts, + metad_hosts=self.metad_hosts + ) + self.reader = MyNebulaReader( + edges=[self.INIT_EDGE_TYPE], + properties=[[prop['name'] for prop in self.INIT_EDGE_PROPERTIES]], + nebula_config=self.config, + limit=1000000 + ) + self.writer_cls = NebulaWriter + self._graph = None + self._initialize_session_pool() + self._clustering_algorithms = { + "leiden": self._leiden_clustering, + } + self._node_embed_algorithms = { + "node2vec": self._node2vec_embed, + } + + + async def has_node(self, node_id: str) -> bool: + if node_id is None or not isinstance(node_id, str): + 
raise ValueError(f"Invalid node_id {node_id}") + if node_id == "": + raise ValueError(f"Invalid node_id {node_id}") + + if len(node_id.encode('utf-8')) > self.VID_LENGTH: + return False + + try: + result_n = self.client.execute_py( + f"MATCH (n:{self.INIT_VERTEX_TYPE}) WHERE id(n) == $node_id RETURN n;", params={"node_id": node_id} + ).column_values("n") + return len(result_n) > 0 + except Exception as e: + raise RuntimeError(f"Failed to check if node {node_id} exists: {e}") from e + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + if source_node_id is None or not isinstance(source_node_id, str) or source_node_id == "": + raise ValueError(f"Invalid source_node_id {source_node_id}") + if target_node_id is None or not isinstance(target_node_id, str) or target_node_id == "": + raise ValueError(f"Invalid target_node_id {target_node_id}") + if len(source_node_id.encode('utf-8')) > self.VID_LENGTH or len(target_node_id.encode('utf-8')) > self.VID_LENGTH: + return False + + (sorted_source_node_id, sorted_target_node_id) = sorted([source_node_id, target_node_id]) + + try: + result = self.client.execute_py( + f"MATCH (n)-[e:{self.INIT_EDGE_TYPE}]-(m) WHERE id(n) == $source_node_id AND id(m) == $target_node_id RETURN e", + params={"source_node_id": sorted_source_node_id, "target_node_id": sorted_target_node_id} + ).column_values("e") + return len(result) > 0 + except Exception as e: + raise RuntimeError(f"Failed to check if edge exists between {sorted_source_node_id} and {sorted_target_node_id}: {e}") from e + + + async def get_node(self, node_id: str) -> Union[dict, NodeView, None]: + if node_id is None or not isinstance(node_id, str) or node_id == "": + raise ValueError(f"Invalid node_id {node_id}") + if len(node_id.encode('utf-8')) > self.VID_LENGTH: + return None + + try: + result = self.client.execute_py( + f"MATCH (n:{self.INIT_VERTEX_TYPE}) WHERE id(n) == $node_id RETURN n", + params={"node_id": node_id} + ).column_values("n") + + if not result: + return None + + node = result[0].as_node() + return { + "id": node_id, + **{k: v.cast() for k, v in node.properties().items() if not (k == 'clusters' and v.cast() == '')} # compatibility for _find_most_related_community_from_entities + } + except Exception as e: + raise RuntimeError(f"Failed to get node {node_id}: {e}") from e + + async def node_degree(self, node_id: str) -> int: + if node_id is None or not isinstance(node_id, str) or node_id == "": + raise ValueError(f"Invalid node_id {node_id}") + try: + result = self.client.execute_py( + f"MATCH (n)-[e:{self.INIT_EDGE_TYPE}]-() WHERE id(n) == $node_id RETURN count(e) AS Degree", + params={"node_id": node_id} + ).column_values("Degree") + if not result: + return 0 + return result[0].cast() + except Exception as e: + raise RuntimeError(f"Failed to get node degree for {node_id}: {e}") from e + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """ + ref:https://github.com/microsoft/graphrag/blob/718d1ef441137a6aed3bb9c445aabbdf612c03b9/graphrag/index/verbs/graph/compute_edge_combined_degree.py + it is defined as the number of neighbors of the source node plus the number of neighbors of the target node. 
+ """ + if src_id is None or not isinstance(src_id, str) or src_id == "": + raise ValueError(f"Invalid src_id {src_id}") + if tgt_id is None or not isinstance(tgt_id, str) or tgt_id == "": + raise ValueError(f"Invalid tgt_id {tgt_id}") + + try: + source_degree = await self.node_degree(src_id) + target_degree = await self.node_degree(tgt_id) + return source_degree + target_degree + except Exception as e: + raise RuntimeError(f"Failed to compute edge degree between {src_id} and {tgt_id}: {e}") from e + + + async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: + if source_node_id is None or not isinstance(source_node_id, str) or source_node_id == "": + raise ValueError(f"Invalid source_node_id {source_node_id}") + if target_node_id is None or not isinstance(target_node_id, str) or target_node_id == "": + raise ValueError(f"Invalid target_node_id {target_node_id}") + + # NebulaGraph is unordered, so we need to sort the source and target node ids + (sorted_source_node_id, sorted_target_node_id) = sorted([source_node_id, target_node_id]) + + try: + result: list[dict] = self.client.execute_py( + f"MATCH (n)-[e:{self.INIT_EDGE_TYPE}]-(m) WHERE id(n) == $src_id AND id(m) == $tgt_id RETURN e", + params={"src_id": sorted_source_node_id, "tgt_id": sorted_target_node_id} + ).as_primitive() + + if not result: + return None + + if len(result) > 1: + logger.warning(f"Found multiple edges between {source_node_id} and {target_node_id}") + + edge_primitive = result[0] + if not edge_primitive or not edge_primitive.get('e'): + return None + edge_props = {k:v.cast() for k,v in edge_primitive['e']['props'].items()} + edge = { + 'source_node_id': source_node_id, + 'target_node_id': target_node_id, + **edge_props + } + return edge + + except Exception as e: + raise RuntimeError(f"Failed to get edge between {source_node_id} and {target_node_id}: {e}") from e + + async def get_node_edges( + self, source_node_id: str + ) -> Union[list[tuple[str, str]], None]: + if source_node_id is None or not isinstance(source_node_id, str) or source_node_id == "": + raise ValueError(f"Invalid source_node_id {source_node_id}") + try: + result = self.client.execute_py( + f"MATCH (n)-[e:{self.INIT_EDGE_TYPE}]-(m) WHERE id(n) == $src_id RETURN id(n) AS source, id(m) AS target", + params={"src_id": source_node_id} + ).as_primitive() + + if not result: + return [] + + edges = [(edge['source'], edge['target']) for edge in result] + return edges if edges else None + except Exception as e: + raise RuntimeError(f"Failed to get edges for node {source_node_id}: {e}") from e + + async def upsert_node(self, node_id: str, node_data: dict[str, str], label: Optional[str] = None): + if node_id is None or not isinstance(node_id, str) or node_id == "": + raise ValueError(f"Invalid node_id {node_id}") + if node_data is None or not isinstance(node_data, dict): + raise ValueError(f"Invalid node_data {node_data}") + if not node_data: + raise ValueError(f"Invalid node_data {node_data}") + if len(node_id.encode('utf-8')) > self.VID_LENGTH: + return + + from uuid import uuid4 + label = label or self.INIT_VERTEX_TYPE + if 'entity_name' not in node_data: + node_data['entity_name'] = node_id + prop_all_names = list(node_data.keys()) + prop_name = ",".join( + [prop for prop in prop_all_names if node_data[prop] is not None] + ) + props_ngql: list[str] = [] + props_map: dict[str, Any] = {} + for prop in prop_all_names: + if node_data[prop] is None: + continue + if any([isinstance(node_data[prop], t) for t in 
self.IMPLEMENTED_PARAM_TYPES]): + new_key = "k_" + uuid4().hex + props_ngql.append(f"${new_key}") + props_map[new_key] = node_data[prop] + else: + props_ngql.append(str(node_data[prop])) + prop_val = ",".join(props_ngql) + + query = ( + f"INSERT VERTEX `{label}`({prop_name}) " + f" VALUES '{escape_bucket(node_id)}':({prop_val});\n" + ) + logger.debug(f"upsert_node()\nDML query: {query}") + result = self.client.execute_py(query, props_map) + + if not result.is_succeeded(): + raise RuntimeError(f"Failed to upsert node {escape_bucket(node_id)}: {result} with query {query}") + + async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str], label: Optional[str] = None): + if source_node_id is None or not isinstance(source_node_id, str) or source_node_id == "": + raise ValueError(f"Invalid source_node_id {source_node_id}") + if target_node_id is None or not isinstance(target_node_id, str) or target_node_id == "": + raise ValueError(f"Invalid target_node_id {target_node_id}") + if edge_data is None or not isinstance(edge_data, dict): + raise ValueError(f"Invalid edge_data {edge_data}") + if not edge_data: + raise ValueError(f"Invalid edge_data {edge_data}") + + if len(source_node_id.encode('utf-8')) > self.VID_LENGTH or len(target_node_id.encode('utf-8')) > self.VID_LENGTH: + return + + (sorted_source_node_id, sorted_target_node_id) = sorted([source_node_id, target_node_id]) + + from uuid import uuid4 + label = label or self.INIT_EDGE_TYPE + + prop_all_names = list(edge_data.keys()) + prop_name = ",".join( + [f"`{prop}`" for prop in prop_all_names if edge_data[prop] is not None] + ) + props_ngql: list[str] = [] + props_map: dict[str, Any] = {} + for prop in prop_all_names: + if edge_data[prop] is None: + continue + if any([isinstance(edge_data[prop], t) for t in self.IMPLEMENTED_PARAM_TYPES]): + new_key = "k_" + uuid4().hex + props_ngql.append(f"${new_key}") + props_map[new_key] = edge_data[prop] + else: + props_ngql.append(str(edge_data[prop])) + prop_val = ",".join(props_ngql) + query = ( + f"INSERT EDGE `{label}`({prop_name}) " + f" VALUES '{escape_bucket(sorted_source_node_id)}'->'{escape_bucket(sorted_target_node_id)}':({prop_val});\n" + ) + logger.debug(f"upsert_edge()\nDML query: {query}") + result = self.client.execute_py(query, props_map) + if not result.is_succeeded(): + raise RuntimeError(f"Failed to upsert edge between {escape_bucket(sorted_source_node_id)} and {escape_bucket(sorted_target_node_id)}: {result} with query {query}") + + + async def clustering(self, algorithm: str): + if algorithm not in self._clustering_algorithms: + raise ValueError(f"Clustering algorithm {algorithm} not supported") + await self._clustering_algorithms[algorithm]() + + + async def _cluster_data_to_graph(self, cluster_data: dict[str, list[dict[str, str]]]): + """ + 将社区数据转换为图数据,并写入 NebulaGraph + :param cluster_data: 社区数据,字典类型,键为 entity ID,值为社区数据列表 + """ + # community node: (level, cluster), with id cluster_{cluster} + # cluster_{cluster} is the key, and value is a list of (level, cluster) + data: dict[str, list[int, int]] = defaultdict(list) + # entity --> cluster + # (entity_id, cluster_id, level) + edges: list[tuple[int, int, int]] = [] + rank = 0 # just a placeholder, no used yet + for node_id, clusters in cluster_data.items(): + for cluster in clusters: + cluster_id = cluster["cluster"] + level = cluster["level"] + cluster_node_id = f'{cluster_id}' + if cluster_node_id not in data: + data[cluster_node_id] = [level, cluster_id] + edges.append((escape_bucket(node_id), 
cluster_node_id, rank, [level])) + + # add clusters info to entity node , for compatibility + async def update_node_clusters(node_id, clusters): + if await self.has_node(node_id): + clusters_json = json.dumps(clusters, ensure_ascii=False) + update_query = ( + f"UPDATE VERTEX ON {self.INIT_VERTEX_TYPE} '{escape_bucket(node_id)}' " + f"SET clusters = '{clusters_json}';" + ) + result = self.client.execute_py(update_query) + if not result.is_succeeded(): + raise RuntimeError(f"update {node_id} clusters failed: {result.error_msg()}") + + update_tasks = [] + for node_id, clusters in cluster_data.items(): + # TODO 确认这里的clusters只有数字 + update_tasks.append(update_node_clusters(node_id, clusters)) + + await asyncio.gather(*update_tasks) + + # Community Vertex data + ng_nx_community_writer = self.writer_cls( + data=data, + nebula_config=self.config, + ) + ng_nx_community_writer.set_options( + label=self.COMMUNITY_VERTEX_TYPE, + properties=["level", "cluster"], + write_mode="insert", + sink="nebulagraph_vertex", + ) + ng_nx_community_writer.write() + + ng_nx_community_edge_writer = self.writer_cls( + data=edges, + nebula_config=self.config, + ) + ng_nx_community_edge_writer.set_options( + label=self.COMMUNITY_EDGE_TYPE, + properties=["level"], + write_mode="insert", + sink="nebulagraph_edge", + ) + ng_nx_community_edge_writer.write() + + return + + + async def _leiden_clustering(self): + from graspologic.partition import hierarchical_leiden + + # TODO: introduce Cache mechanism for this. + self._graph = self.reader.read() + self._graph_homogenous = nx.Graph(self._graph) + + graph = NetworkXStorage.stable_largest_connected_component(self._graph_homogenous) + community_mapping = hierarchical_leiden( + graph, + max_cluster_size=self.global_config["max_graph_cluster_size"], + random_seed=self.global_config["graph_cluster_seed"], + ) + + node_communities: dict[str, list[dict[str, str]]] = defaultdict(list) + __levels = defaultdict(set) + + for partition in community_mapping: + level_key = partition.level + cluster_id = partition.cluster + node_communities[partition.node].append( + {"level": level_key, "cluster": cluster_id} + ) + __levels[level_key].add(cluster_id) + node_communities = dict(node_communities) + __levels = {k: len(v) for k, v in __levels.items()} + logger.info(f"Each level has communities: {dict(__levels)}") + await self._cluster_data_to_graph(node_communities) + + async def community_schema(self) -> dict[str, SingleCommunitySchema]: + # 初始化一个默认字典,字典的默认值是一个包含社区信息的字典 + results = defaultdict( + lambda: dict( + level=None, # 社区的层级 + title=None, # 社区的标题 + edges=set(), # 社区内的边集合 + nodes=set(), # 社区内的节点集合 + chunk_ids=set(), # 社区内的分块ID集合 + occurrence=0.0, # 社区的出现频率 + sub_communities=[], # 子社区列表 + ) + ) + max_num_ids = 0 # 初始化最大ID数量 + levels = defaultdict(set) # 初始化一个默认字典,用于存储每个层级的社区 + + + communities_result = self.client.execute_py( + f"MATCH (n:{self.COMMUNITY_VERTEX_TYPE}) RETURN n" + ).column_values("n") + + for community in communities_result: + community_data = community.as_node() + community_id = community_data.get_id().cast() + community_properties = {k: v.cast() for k, v in community_data.properties().items()} + level = community_properties.get("level") + community_key = str(community_properties.get("cluster")) + title = community_id + + # 获取community连接的node + node_query = f"MATCH (n:{self.INIT_VERTEX_TYPE})-[:{self.COMMUNITY_EDGE_TYPE}]->(c:{self.COMMUNITY_VERTEX_TYPE}) WHERE id(c) == '{community_id}' RETURN n;" + nodes_in_community = self.client.execute_py(node_query) + nodes = set() + 
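+            # accumulate this community's member nodes, their intra-graph edges, and source chunk ids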
edges = set() + chunk_ids = set() + for node in nodes_in_community.column_values("n"): + node_data = node.as_node() + node_id = node_data.get_id().cast() + nodes.add(node_id) + node_properties = {k: v.cast() for k, v in node_data.properties().items()} + chunk_ids.update(node_properties.get("source_id", "").split(GRAPH_FIELD_SEP)) + + # 获取node连接的边, 处理成(src,tgt)的形式 + edge_query = f"MATCH (n:{self.INIT_VERTEX_TYPE})-[:{self.INIT_EDGE_TYPE}]-(m:{self.INIT_VERTEX_TYPE}) WHERE id(n) == '{escape_bucket(node_id)}' RETURN m;" + edges_in_node = self.client.execute_py(edge_query) + for edge in edges_in_node.column_values("m"): + dst_node = edge.as_node() + dst_node_id = dst_node.get_id().cast() + edges.add((node_id, dst_node_id)) + + results[community_key].update( + level=level, + title=title, + chunk_ids=chunk_ids, + nodes=nodes, + edges=edges, + ) + + levels[level].add(community_key) + max_num_ids = max(max_num_ids, len(chunk_ids)) + + ordered_levels = sorted(levels.keys()) # 对层级进行排序 + # 遍历排序后的层级,计算子社区 + for i, curr_level in enumerate(ordered_levels[:-1]): + next_level = ordered_levels[i + 1] # 获取下一个层级 + this_level_comms = levels[curr_level] # 获取当前层级的社区 + next_level_comms = levels[next_level] # 获取下一个层级的社区 + + for comm in this_level_comms: + # 遍历当前层级的社区,计算子社区 + results[comm]["sub_communities"] = [ + c + for c in next_level_comms + if results[c]["nodes"].issubset(results[comm]["nodes"]) # 如果下一个层级的社区节点是当前层级社区节点的子集,则将其添加为子社区 + ] + + # 将集合类型的字段转换为列表,并计算社区的出现频率 + for k, v in results.items(): + v["edges"] = list(v["edges"]) + v["edges"] = [list(e) for e in v["edges"]] + v["nodes"] = list(v["nodes"]) + v["chunk_ids"] = list(v["chunk_ids"]) + v["occurrence"] = len(v["chunk_ids"]) / max_num_ids # 计算社区的出现频率 + return dict(results) + + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + if algorithm not in self._node_embed_algorithms: + raise ValueError(f"Node embedding algorithm {algorithm} not supported") + # TODO: persist node2vec/graph embedding to NebulaGraph + return await self._node_embed_algorithms[algorithm]() + + async def _node2vec_embed(self): + from graspologic import embed + + # TOOD: introduce Cache mechanism for this to scale even better. 
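+        # re-read the full graph from NebulaGraph before embedding (same pattern as _leiden_clustering)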
+        self._graph = self.reader.read()
+        self._graph_homogenous = nx.Graph(self._graph)
+
+        embeddings, nodes = embed.node2vec_embed(
+            self._graph,
+            **self.global_config["node2vec_params"],
+        )
+
+        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
+        return embeddings, nodes_ids
+
+    async def index_done_callback(self):
+        # close the client
+        self.client.close()
+        # TODO: introduce cache mechanism, then we could leverage this callback
+
+    ## ↓ KV Storage Implementation ↓ ##
+
+    async def all_keys(self) -> list[str]:
+        communities_result = self.client.execute_py(
+            f"MATCH (n:{self.COMMUNITY_VERTEX_TYPE}) RETURN n"
+        ).column_values("n")
+        return [c.as_node().get_id().cast() for c in communities_result]
+
+    async def get_by_id(self, id):
+        try:
+            result = self.client.execute_py(
+                f"MATCH (n:{self.COMMUNITY_VERTEX_TYPE}) WHERE id(n) == $id RETURN n",
+                params={"id": id}
+            ).column_values("n")
+
+            if not result:
+                return None
+
+            node = result[0].as_node()
+            properties = {k: v.cast() for k, v in node.properties().items()}
+            properties = self._parse_json_fields(properties, self.JSON_FIELDS)
+            return properties
+        except Exception as e:
+            raise RuntimeError(f"Failed to get community {id}: {e}") from e
+
+    async def get_by_ids(self, ids, fields=None):
+        try:
+            id_list = ", ".join([f"'{id}'" for id in ids])
+            query = f"MATCH (n:{self.COMMUNITY_VERTEX_TYPE}) WHERE id(n) IN [{id_list}] RETURN n"
+            result = self.client.execute_py(query).column_values("n")
+
+            communities = []
+            for node in result:
+                vertex = node.as_node()
+                properties = {k: v.cast() for k, v in vertex.properties().items()}
+                properties = self._parse_json_fields(properties, self.JSON_FIELDS)
+                # keep the vertex id so the results can be re-ordered below
+                properties['id'] = vertex.get_id().cast()
+                if fields:
+                    properties = {k: v for k, v in properties.items() if k in fields or k == 'id'}
+                communities.append(properties)
+
+            ordered_communities = []
+            for id in ids:
+                community = next((c for c in communities if c.get('id') == id), None)
+                ordered_communities.append(community)
+
+            return ordered_communities
+        except Exception as e:
+            raise RuntimeError(f"Failed to get communities: {e}") from e
+
+    async def filter_keys(self, data: list[str]) -> set[str]:
+        raise NotImplementedError
+
+    async def upsert(self, data: dict[str, dict]):
+        async def upsert_one(id, properties):
+            try:
+                update_properties = {
+                    'report_json': properties.get('report_json', {}),
+                    'chunk_ids': properties.get('chunk_ids', []),
+                    'sub_communities': properties.get('sub_communities', []),
+                    'occurrence': properties.get('occurrence', 0.0),
+                    'title': properties.get('report_json', {}).get('title', ''),
+                    'report_string': properties.get('report_string', '')
+                }
+
+                update_properties = self._dump_json_fields(update_properties, self.JSON_FIELDS)
+
+                update_query = f"UPDATE VERTEX ON {self.COMMUNITY_VERTEX_TYPE} '{id}' SET report_json = $report_json, chunk_ids = $chunk_ids, sub_communities = $sub_communities, occurrence = $occurrence, title = $title, report_string = $report_string"
+
+                result = self.client.execute_py(update_query, params=update_properties)
+                if not result.is_succeeded():
+                    raise RuntimeError(f"Update community {id} Failed: {result.error_msg()}")
+
+            except Exception as e:
+                raise RuntimeError(f"Update community {id} Failed: {e}") from e
+
+        tasks = []
+        for id, properties in data.items():
+            task = upsert_one(id, properties)
+            tasks.append(task)
+
+        await asyncio.gather(*tasks)
+        logger.info(f"Successfully updated {len(data)} communities")
+
+    async def drop(self):
+        try:
+            # Delete all community vertices and edges
+            delete_query = f"LOOKUP ON 
{self.COMMUNITY_VERTEX_TYPE} YIELD id(vertex) AS ID | DELETE VERTEX $-.ID WITH EDGE;" + result = self.client.execute_py(delete_query) + if not result.is_succeeded(): + raise RuntimeError(f"Failed to delete community vertices: {result.error_msg()}") + + logger.info("Successfully dropped all community vertices and edges") + except Exception as e: + raise RuntimeError(f"Failed to drop community data: {e}") from e + + def _parse_json_fields(self, properties, json_fields): + """ + Parse the json fields in the properties. + """ + for k, v in properties.items(): + if k in json_fields: + properties[k] = json.loads(v) + return properties + + def _dump_json_fields(self, properties, json_fields): + """ + Dump the json fields in the properties. + """ + for k, v in properties.items(): + if k in json_fields: + properties[k] = json.dumps(v,ensure_ascii=False) + return properties + + + + +class MyNebulaReader(NebulaReader): + """ + If the property has the name 'order', it will cause an error. + Therefore, it needs to be enclosed with backticks. + """ + def read(self): + from ng_nx.utils import result_to_df + with self.connection_pool.session_context( + self.nebula_user, self.nebula_password + ) as session: + assert session.execute( + f"USE {self.space}" + ).is_succeeded(), f"Failed to use space {self.space}" + result_list = [] + g = nx.MultiDiGraph() + for i in range(len(self.edges)): + edge = self.edges[i] + properties = self.properties[i] + properties_query_field = "" + for property in properties: + properties_query_field += f", e.`{property}` AS `{property}`" + if self.with_rank: + properties_query_field += ", rank(e) AS `__rank__`" + result = session.execute( + f"MATCH ()-[e:`{edge}`]->() RETURN src(e) AS src, dst(e) AS dst{properties_query_field} LIMIT {self.limit}" + ) + # print(f'query: MATCH ()-[e:`{edge}`]->() RETURN src(e) AS src, dst(e) AS dst{properties_query_field} LIMIT {self.limit}') + # print(f"Result: {result}") + assert result.is_succeeded() + result_list.append(result) + + # merge all result + for i, result in enumerate(result_list): + _df = result_to_df(result) + # TBD, consider add label of edge + properties = self.properties[i] if self.properties[i] else None + if self.with_rank: + properties = properties + ["__rank__"] + _g = nx.from_pandas_edgelist( + _df, + "src", + "dst", + properties, + create_using=nx.MultiDiGraph(), + edge_key="__rank__", + ) + else: + _g = nx.from_pandas_edgelist( + _df, + "src", + "dst", + properties, + create_using=nx.MultiDiGraph(), + ) + g = nx.compose(g, _g) + return g + +def escape_bucket(input): + if isinstance(input,str): + return input.replace("'", "\\'").replace('"', '\\"') + elif isinstance(input,list): + return [escape_bucket(i) for i in input] + elif isinstance(input,dict): + return {k: escape_bucket(v) for k, v in input.items()} + else: + return input \ No newline at end of file diff --git a/nano_graphrag/_storage/kv_redis.py b/nano_graphrag/_storage/kv_redis.py new file mode 100644 index 0000000..c6d9172 --- /dev/null +++ b/nano_graphrag/_storage/kv_redis.py @@ -0,0 +1,83 @@ +import json +from dataclasses import dataclass, field + +from ..base import BaseKVStorage +import redis +from redis.exceptions import ConnectionError +from .._utils import get_workdir_last_folder_name, logger + +@dataclass +class RedisKVStorage(BaseKVStorage): + _redis: redis.Redis = field(init=False, repr=False, compare=False) + def __post_init__(self): + try: + host = self.global_config["addon_params"].get("redis_host", "localhost") + port = 
self.global_config["addon_params"].get("redis_port", "6379") + user = self.global_config["addon_params"].get("redis_user", None) + password = self.global_config["addon_params"].get("redis_password", None) + db = self.global_config["addon_params"].get("redis_db", 0) + self._redis = redis.Redis(host=host, port=port, username=user, password=password, db=db) + self._redis.ping() + logger.info(f"Connected to Redis at {host}:{port}") + except ConnectionError: + logger.error(f"Failed to connect to Redis at {host}:{port}") + raise + + self._namespace = f"kv_store_{get_workdir_last_folder_name(self.global_config['working_dir'])}" + logger.info(f"Initialized Redis KV storage for namespace: {self._namespace}") + + async def all_keys(self) -> list[str]: + return [key.decode().split(':', 1)[1] for key in self._redis.keys(f"{self._namespace}:*")] + + async def index_done_callback(self): + # Redis automatically persists data, so no explicit action needed + pass + + async def get_by_id(self, id): + value = self._redis.get(f"{self._namespace}:{id}") + return json.loads(value) if value else None + + async def get_by_ids(self, ids, fields=None): + pipeline = self._redis.pipeline() + for id in ids: + pipeline.get(f"{self._namespace}:{id}") + values = pipeline.execute() + + results = [] + for value in values: + if value: + data = json.loads(value) + if fields: + results.append({k: v for k, v in data.items() if k in fields}) + else: + results.append(data) + else: + results.append(None) + return results + + async def filter_keys(self, data: list[str]) -> set[str]: + pipeline = self._redis.pipeline() + for key in data: + pipeline.exists(f"{self._namespace}:{key}") + exists = pipeline.execute() + return set([key for key, exists in zip(data, exists) if not exists]) + + async def upsert(self, data: dict[str, dict]): + pipeline = self._redis.pipeline() + for key, value in data.items(): + pipeline.set(f"{self._namespace}:{key}", json.dumps(value,ensure_ascii=False)) + pipeline.execute() + + async def drop(self): + keys = self._redis.keys(f"{self._namespace}:*") + if keys: + self._redis.delete(*keys) + + def __getstate__(self): + state = self.__dict__.copy() + del state['_redis'] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.__post_init__() \ No newline at end of file diff --git a/nano_graphrag/_storage/vdb_milvus.py b/nano_graphrag/_storage/vdb_milvus.py new file mode 100644 index 0000000..2e89c45 --- /dev/null +++ b/nano_graphrag/_storage/vdb_milvus.py @@ -0,0 +1,86 @@ +import asyncio +import os +from dataclasses import dataclass +import numpy as np +from pymilvus import MilvusClient + +from .._utils import get_workdir_last_folder_name, logger +from ..base import BaseVectorStorage + + +@dataclass +class MilvusVectorStorage(BaseVectorStorage): + + @staticmethod + def create_collection_if_not_exist(client, collection_name: str,max_id_length: int, dimension: int,**kwargs): + if client.has_collection(collection_name): + return + client.create_collection( + collection_name, max_length=max_id_length, id_type="string", dimension=dimension, **kwargs + ) + + + def __post_init__(self): + self.milvus_uri = self.global_config["addon_params"].get("milvus_uri", "") + if self.milvus_uri: + self.milvus_user = self.global_config["addon_params"].get("milvus_user", "") + self.milvus_password = self.global_config["addon_params"].get("milvus_password", "") + self.collection_name = get_workdir_last_folder_name(self.global_config["working_dir"]) + self._client = MilvusClient(self.milvus_uri, 
self.milvus_user, self.milvus_password)
+        else:
+            self._client_file_name = os.path.join(
+                self.global_config["working_dir"], "milvus_lite.db"
+            )
+            # the collection name must be set on this branch too, otherwise
+            # create_collection_if_not_exist() below fails with an AttributeError
+            self.collection_name = get_workdir_last_folder_name(self.global_config["working_dir"])
+            self._client = MilvusClient(self._client_file_name)
+
+        self.cosine_better_than_threshold: float = 0.2
+        self._max_batch_size = self.global_config["embedding_batch_num"]
+        self.max_id_length = 256
+        MilvusVectorStorage.create_collection_if_not_exist(
+            self._client, self.collection_name, max_id_length=self.max_id_length, dimension=self.embedding_func.embedding_dim,
+        )
+
+    async def upsert(self, data: dict[str, dict]):
+        logger.info(f"Inserting {len(data)} vectors to {self.collection_name}")
+        list_data = [
+            {
+                "id": k,
+                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
+            }
+            for k, v in data.items()
+        ]
+        contents = [v["content"] for v in data.values()]
+        batches = [
+            contents[i : i + self._max_batch_size]
+            for i in range(0, len(contents), self._max_batch_size)
+        ]
+        embeddings_list = await asyncio.gather(
+            *[self.embedding_func(batch) for batch in batches]
+        )
+        embeddings = np.concatenate(embeddings_list)
+        for i, d in enumerate(list_data):
+            d["vector"] = embeddings[i]
+        batch_size = 1024
+        results = []
+        for i in range(0, len(list_data), batch_size):
+            batch = list_data[i:i+batch_size]
+            batch_result = self._client.upsert(collection_name=self.collection_name, data=batch)
+            results.append(batch_result)
+
+        total_upsert_count = sum(result.get('upsert_count', 0) for result in results)
+        results = {'upsert_count': total_upsert_count}
+        return results
+
+    async def query(self, query: str, top_k=5):
+        embedding = await self.embedding_func([query])
+        results = self._client.search(
+            collection_name=self.collection_name,
+            data=embedding,
+            limit=top_k,
+            output_fields=list(self.meta_fields),
+            search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}},
+        )
+        return [
+            {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
+            for dp in results[0]
+        ]
diff --git a/nano_graphrag/_utils.py b/nano_graphrag/_utils.py
index 5185959..bea6829 100644
--- a/nano_graphrag/_utils.py
+++ b/nano_graphrag/_utils.py
@@ -134,6 +134,9 @@ def list_of_list_to_csv(data: list[list]):
         ]
     )
 
+def get_workdir_last_folder_name(workdir: str) -> str:
+    return os.path.basename(os.path.normpath(workdir))
+
 # -----------------------------------------------------------------------------------
 # Refer the utils functions of the official GraphRAG implementation:
diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py
index 2c9e1be..cfc93e9 100644
--- a/nano_graphrag/graphrag.py
+++ b/nano_graphrag/graphrag.py
@@ -122,6 +122,7 @@ class GraphRAG:
     vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
     vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
     graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
+    community_report_storage_cls: Type[BaseKVStorage] = JsonKVStorage
     enable_llm_cache: bool = True
 
     # extension
@@ -165,7 +166,7 @@ def __post_init__(self):
             else None
         )
 
-        self.community_reports = self.key_string_value_json_storage_cls(
+        self.community_reports = self.community_report_storage_cls(
             namespace="community_reports", global_config=asdict(self)
         )
         self.chunk_entity_relation_graph = self.graph_storage_cls(
diff --git a/tests/test_nebula_storage.py b/tests/test_nebula_storage.py
new file mode 100644
index 0000000..595b50e
--- /dev/null
+++ b/tests/test_nebula_storage.py
@@ -0,0 +1,315 @@
+from functools import wraps
+import os
+import pytest
+import numpy as np
+import json
+from nano_graphrag import GraphRAG +from nano_graphrag._storage import NebulaGraphStorage +from nano_graphrag._utils import wrap_embedding_func_with_attrs + +if os.environ.get("NANO_GRAPHRAG_TEST_IGNORE_NEBULA", False): + pytest.skip("skipping nebula tests", allow_module_level=True) + +@wrap_embedding_func_with_attrs(embedding_dim=384, max_token_size=8192) +async def mock_embedding(texts: list[str]) -> np.ndarray: + return np.random.rand(len(texts), 384) + + +@pytest.fixture(scope="module") +def nebula_config(): + return { + "graphd_hosts": os.environ.get("NEBULA_GRAPHD_HOSTS", "localhost:9669"), + "metad_hosts": os.environ.get("NEBULA_METAD_HOSTS","localhost:9559"), + "username": os.environ.get("NEBULA_USERNAME", "root"), + "password": os.environ.get("NEBULA_PASSWORD", "nebula"), + } + + +@pytest.fixture +def nebula_graph_storage(nebula_config): + rag = GraphRAG( + working_dir="./tests/nebula_test", + embedding_func=mock_embedding, + graph_storage_cls=NebulaGraphStorage, + community_report_storage_cls = NebulaGraphStorage, + addon_params=nebula_config, + ) + storage = rag.chunk_entity_relation_graph + return storage + +@pytest.fixture +def nebula_kv_storage(nebula_config): + rag = GraphRAG( + working_dir="./tests/nebula_test", + embedding_func=mock_embedding, + graph_storage_cls=NebulaGraphStorage, + community_report_storage_cls = NebulaGraphStorage, + addon_params=nebula_config, + ) + storage = rag.community_reports + return storage + + +def test_nebula_storage_init(): + rag = GraphRAG( + working_dir="./tests/neo4j_test", + embedding_func=mock_embedding, + ) + with pytest.raises(ValueError): + storage = NebulaGraphStorage( + namespace="nanographrag_test", global_config=rag.__dict__ + ) + +def delete_all_data(nebula_graph_storage): + """Delete all tags and edges from the Nebula Graph database.""" + TAG_NAMES = [nebula_graph_storage.INIT_VERTEX_TYPE, nebula_graph_storage.COMMUNITY_VERTEX_TYPE] + # EDGE_NAMES = [nebula_graph_storage.COMMUNITY_EDGE_TYPE, nebula_graph_storage.COMMUNITY_EDGE_TYPE] + + # TAG_INDEXS = [idx['index_name'] for idx in nebula_graph_storage.INIT_VERTEX_INDEXES] + [idx['index_name'] for idx in nebula_graph_storage.COMMUNITY_VERTEX_INDEXES] + # EDGE_INDEXS =[idx['index_name'] for idx in nebula_graph_storage.INIT_EDGE_INDEXES] + [idx['index_name'] for idx in nebula_graph_storage.COMMUNITY_EDGE_INDEXES] + try: + for tag in TAG_NAMES: + delete_query = f"LOOKUP ON {tag} YIELD id(vertex) AS ID | DELETE VERTEX $-.ID WITH EDGE;" + result = nebula_graph_storage.client.execute_py(delete_query) + if not result.is_succeeded(): + raise RuntimeError(f"Failed to delete vertices: {result.error_msg()}") + + except Exception as e: + print(f"Error occurred while deleting data: {e}") + +def delete_space(nebula_graph_storage): + """Delete the space from the Nebula Graph database.""" + try: + # Get the current space name + space_name = nebula_graph_storage.space + + # Drop the specified space + drop_space_query = f"DROP SPACE IF EXISTS {space_name}" + result = nebula_graph_storage.client.execute_py(drop_space_query) + if not result.is_succeeded(): + raise RuntimeError(f"Failed to drop space: {result.error_msg()}") + + print(f"Successfully dropped space: {space_name}") + except Exception as e: + print(f"Error occurred while dropping space: {e}") + + +def reset_graph(func): + @wraps(func) + async def new_func(nebula_graph_storage, *args, **kwargs): + delete_all_data(nebula_graph_storage) + results = await func(nebula_graph_storage, *args, **kwargs) + delete_all_data(nebula_graph_storage) + 
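+        # data is cleared both before and after the test body, so each test runs against an empty graph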
return results + + return new_func + + +@pytest.mark.asyncio +@reset_graph +async def test_upsert_and_get_node(nebula_graph_storage): + node_id = "node1" + node_data = {"entity_name": "Entity 1", "description": "description111"} + return_data = {"id": node_id, **node_data} + + await nebula_graph_storage.upsert_node(node_id, node_data) + + result = await nebula_graph_storage.get_node(node_id) + assert all(result.get(key) == value for key, value in return_data.items()) + + has_node = await nebula_graph_storage.has_node(node_id) + assert has_node is True + + non_existent_node = await nebula_graph_storage.get_node("non_existent") + assert non_existent_node is None + + has_non_existent_node = await nebula_graph_storage.has_node("non_existent") + assert has_non_existent_node is False + + + + +@pytest.mark.asyncio +@reset_graph +async def test_upsert_and_get_edge(nebula_graph_storage): + source_id = "node1" + target_id = "node2" + node_data = {"entity_name": "Entity 1", "description": "description111"} + edge_data = {"weight": 1.0, "description": "connection"} + + await nebula_graph_storage.upsert_node(source_id, node_data) + await nebula_graph_storage.upsert_node(target_id, node_data) + await nebula_graph_storage.upsert_edge(source_id, target_id, edge_data) + + result = await nebula_graph_storage.get_edge(source_id, target_id) + assert all(result.get(key) == value for key, value in edge_data.items()) + + has_edge = await nebula_graph_storage.has_edge(source_id, target_id) + assert has_edge is True + + # 测试不存在的边 + non_existent_edge = await nebula_graph_storage.get_edge("non_existent1", "non_existent2") + assert non_existent_edge is None + + has_non_existent_edge = await nebula_graph_storage.has_edge("non_existent1", "non_existent2") + assert has_non_existent_edge is False + + + + +@pytest.mark.asyncio +@reset_graph +async def test_node_degree(nebula_graph_storage): + node_id = "center" + node_data = {"entity_name": "Entity 1", "description": "description111"} + edge_data = {"weight": 1.0, "description": "connection"} + await nebula_graph_storage.upsert_node(node_id, node_data) + + num_neighbors = 5 + for i in range(num_neighbors): + neighbor_id = f"neighbor{i}" + await nebula_graph_storage.upsert_node(neighbor_id, node_data) + await nebula_graph_storage.upsert_edge(node_id, neighbor_id, edge_data) + + degree = await nebula_graph_storage.node_degree(node_id) + assert degree == num_neighbors + + non_existent_degree = await nebula_graph_storage.node_degree("non_existent") + assert non_existent_degree == 0 + + + +@pytest.mark.asyncio +@reset_graph +async def test_edge_degree(nebula_graph_storage): + source_id = "node1" + target_id = "node2" + node_data = {"entity_name": "Entity 1", "description": "description111"} + edge_data = {"weight": 1.0, "description": "connection"} + + await nebula_graph_storage.upsert_node(source_id, node_data) + await nebula_graph_storage.upsert_node(target_id, node_data) + await nebula_graph_storage.upsert_edge(source_id, target_id, edge_data) + + num_source_neighbors = 3 + for i in range(num_source_neighbors): + neighbor_id = f"neighbor{i}" + await nebula_graph_storage.upsert_node(neighbor_id, node_data) + await nebula_graph_storage.upsert_edge(source_id, neighbor_id, edge_data) + + num_target_neighbors = 2 + for i in range(num_target_neighbors): + neighbor_id = f"target_neighbor{i}" + await nebula_graph_storage.upsert_node(neighbor_id, node_data) + await nebula_graph_storage.upsert_edge(target_id, neighbor_id, edge_data) + + expected_edge_degree = (num_source_neighbors + 
1) + (num_target_neighbors + 1) + edge_degree = await nebula_graph_storage.edge_degree(source_id, target_id) + assert edge_degree == expected_edge_degree + + non_existent_edge_degree = await nebula_graph_storage.edge_degree("non_existent1", "non_existent2") + assert non_existent_edge_degree == 0 + + + + +@pytest.mark.asyncio +@reset_graph +async def test_get_node_edges(nebula_graph_storage): + center_id = "center" + await nebula_graph_storage.upsert_node(center_id, {"entity_name": "Center Node"}) + + expected_edges = [] + for i in range(3): + neighbor_id = f"neighbor{i}" + await nebula_graph_storage.upsert_node(neighbor_id, {"entity_name": f"Neighbor {i}"}) + await nebula_graph_storage.upsert_edge(center_id, neighbor_id, {"weight": 1.0, "description": "connection"}) + expected_edges.append((center_id, neighbor_id)) + + result = await nebula_graph_storage.get_node_edges(center_id) + assert all(any(edge[0] == r[0] and edge[1] == r[1] for r in result) for edge in expected_edges) + + +@pytest.mark.parametrize("algorithm", ["leiden"]) +@pytest.mark.asyncio +@reset_graph +async def test_clustering(nebula_graph_storage, algorithm): + for i in range(10): + await nebula_graph_storage.upsert_node(f"NODE{i}", {"source_id": f"chunk{i}"}) + + for i in range(9): + await nebula_graph_storage.upsert_edge(f"NODE{i}", f"NODE{i+1}", {"weight": 1.0}) + + await nebula_graph_storage.clustering(algorithm=algorithm) + + community_schema = await nebula_graph_storage.community_schema() + + assert len(community_schema) > 0 + + for community in community_schema.values(): + assert "level" in community + assert "title" in community + assert "edges" in community + assert "nodes" in community + assert "chunk_ids" in community + assert "occurrence" in community + assert "sub_communities" in community + + all_nodes = set() + for community in community_schema.values(): + all_nodes.update(community["nodes"]) + assert len(all_nodes) == 10 + + +@pytest.mark.parametrize("algorithm", ["leiden"]) +@pytest.mark.asyncio +@reset_graph +async def test_leiden_clustering_community_structure(nebula_graph_storage, algorithm): + for i in range(10): + await nebula_graph_storage.upsert_node(f"A{i}", {"source_id": f"chunkA{i}"}) + await nebula_graph_storage.upsert_node(f"B{i}", {"source_id": f"chunkB{i}"}) + for i in range(9): + await nebula_graph_storage.upsert_edge(f"A{i}", f"A{i+1}", {"weight": 1.0}) + await nebula_graph_storage.upsert_edge(f"B{i}", f"B{i+1}", {"weight": 1.0}) + + await nebula_graph_storage.clustering(algorithm=algorithm) + community_schema = await nebula_graph_storage.community_schema() + + assert len(community_schema) >= 2, "Should have at least two communities" + + communities = list(community_schema.values()) + a_nodes = set(node for node in communities[0]['nodes'] if node.startswith('A')) + b_nodes = set(node for node in communities[0]['nodes'] if node.startswith('B')) + assert len(a_nodes) == 0 or len(b_nodes) == 0, "Nodes from different groups should be in different communities" + + +@pytest.mark.parametrize("algorithm", ["leiden"]) +@pytest.mark.asyncio +@reset_graph +async def test_leiden_clustering_hierarchical_structure(nebula_graph_storage, algorithm): + await nebula_graph_storage.upsert_node("NODE1", {"source_id": "chunk1", "clusters": json.dumps([{"level": 0, "cluster": "0"}, {"level": 1, "cluster": "1"}])}) + await nebula_graph_storage.upsert_node("NODE2", {"source_id": "chunk2", "clusters": json.dumps([{"level": 0, "cluster": "0"}, {"level": 1, "cluster": "2"}])}) + await 
From b9281ec1cfe8260c3ddce8852a673be1454dc810 Mon Sep 17 00:00:00 2001
From: utopia2077
Date: Mon, 7 Oct 2024 04:18:18 +0800
Subject: [PATCH 2/3] chore: remove Chinese comments

---
 nano_graphrag/_storage/gdb_nebula.py | 43 +++++++++++-----------------
 tests/test_nebula_storage.py         |  4 ---
 2 files changed, 16 insertions(+), 31 deletions(-)

diff --git a/nano_graphrag/_storage/gdb_nebula.py b/nano_graphrag/_storage/gdb_nebula.py
index 5472d50..6001fea 100644
--- a/nano_graphrag/_storage/gdb_nebula.py
+++ b/nano_graphrag/_storage/gdb_nebula.py
@@ -594,10 +594,6 @@ async def clustering(self, algorithm: str):
 
 
     async def _cluster_data_to_graph(self, cluster_data: dict[str, list[dict[str, str]]]):
-        """
-        将社区数据转换为图数据,并写入 NebulaGraph
-        :param cluster_data: 社区数据,字典类型,键为 entity ID,值为社区数据列表
-        """
         # community node: (level, cluster), with id cluster_{cluster}
         # cluster_{cluster} is the key, and value is a list of (level, cluster)
         data: dict[str, list[int, int]] = defaultdict(list)
@@ -628,7 +624,6 @@ async def update_node_clusters(node_id, clusters):
 
         update_tasks = []
         for node_id, clusters in cluster_data.items():
-            # TODO 确认这里的clusters只有数字
            update_tasks.append(update_node_clusters(node_id, clusters))
 
         await asyncio.gather(*update_tasks)
@@ -691,20 +686,19 @@ async def _leiden_clustering(self):
         await self._cluster_data_to_graph(node_communities)
 
     async def community_schema(self) -> dict[str, SingleCommunitySchema]:
-        # 初始化一个默认字典,字典的默认值是一个包含社区信息的字典
         results = defaultdict(
             lambda: dict(
-                level=None,  # 社区的层级
-                title=None,  # 社区的标题
-                edges=set(),  # 社区内的边集合
-                nodes=set(),  # 社区内的节点集合
-                chunk_ids=set(),  # 社区内的分块ID集合
-                occurrence=0.0,  # 社区的出现频率
-                sub_communities=[],  # 子社区列表
+                level=None,
+                title=None,
+                edges=set(),
+                nodes=set(),
+                chunk_ids=set(),
+                occurrence=0.0,
+                sub_communities=[],
             )
         )
-        max_num_ids = 0  # 初始化最大ID数量
-        levels = defaultdict(set)  # 初始化一个默认字典,用于存储每个层级的社区
+        max_num_ids = 0
+        levels = defaultdict(set)
 
         communities_result = self.client.execute_py(
@@ -719,7 +713,6 @@ async def community_schema(self) -> dict[str, SingleCommunitySchema]:
             community_key = str(community_properties.get("cluster"))
             title = community_id
 
-            # 获取community连接的node
             node_query = f"MATCH (n:{self.INIT_VERTEX_TYPE})-[:{self.COMMUNITY_EDGE_TYPE}]->(c:{self.COMMUNITY_VERTEX_TYPE}) WHERE id(c) == '{community_id}' RETURN n;"
             nodes_in_community = self.client.execute_py(node_query)
             nodes = set()
@@ -732,7 +725,6 @@ async def community_schema(self) -> dict[str, SingleCommunitySchema]:
                 node_properties = {k: v.cast() for k, v in node_data.properties().items()}
                 chunk_ids.update(node_properties.get("source_id", "").split(GRAPH_FIELD_SEP))
 
-                # 获取node连接的边, 处理成(src,tgt)的形式
                 edge_query = f"MATCH (n:{self.INIT_VERTEX_TYPE})-[:{self.INIT_EDGE_TYPE}]-(m:{self.INIT_VERTEX_TYPE}) WHERE id(n) == '{escape_bucket(node_id)}' RETURN m;"
                 edges_in_node = self.client.execute_py(edge_query)
                 for edge in edges_in_node.column_values("m"):
@@ -751,28 +743,25 @@ async def community_schema(self) -> dict[str, SingleCommunitySchema]:
             levels[level].add(community_key)
             max_num_ids = max(max_num_ids, len(chunk_ids))
 
-        ordered_levels = sorted(levels.keys())  # 对层级进行排序
-        # 遍历排序后的层级,计算子社区
+        ordered_levels = sorted(levels.keys())
         for i, curr_level in enumerate(ordered_levels[:-1]):
-            next_level = ordered_levels[i + 1]  # 获取下一个层级
-            this_level_comms = levels[curr_level]  # 获取当前层级的社区
-            next_level_comms = levels[next_level]  # 获取下一个层级的社区
+            next_level = ordered_levels[i + 1]
+            this_level_comms = levels[curr_level]
+            next_level_comms = levels[next_level]
 
             for comm in this_level_comms:
-                # 遍历当前层级的社区,计算子社区
                 results[comm]["sub_communities"] = [
                     c
                     for c in next_level_comms
-                    if results[c]["nodes"].issubset(results[comm]["nodes"])  # 如果下一个层级的社区节点是当前层级社区节点的子集,则将其添加为子社区
+                    if results[c]["nodes"].issubset(results[comm]["nodes"])
                 ]
-
-        # 将集合类型的字段转换为列表,并计算社区的出现频率
+
         for k, v in results.items():
             v["edges"] = list(v["edges"])
             v["edges"] = [list(e) for e in v["edges"]]
             v["nodes"] = list(v["nodes"])
             v["chunk_ids"] = list(v["chunk_ids"])
-            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids  # 计算社区的出现频率
+            v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
 
         return dict(results)
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
diff --git a/tests/test_nebula_storage.py b/tests/test_nebula_storage.py
index 595b50e..fcbb64a 100644
--- a/tests/test_nebula_storage.py
+++ b/tests/test_nebula_storage.py
@@ -63,10 +63,7 @@ def test_nebula_storage_init():
 def delete_all_data(nebula_graph_storage):
     """Delete all tags and edges from the Nebula Graph database."""
     TAG_NAMES = [nebula_graph_storage.INIT_VERTEX_TYPE, nebula_graph_storage.COMMUNITY_VERTEX_TYPE]
-    # EDGE_NAMES = [nebula_graph_storage.COMMUNITY_EDGE_TYPE, nebula_graph_storage.COMMUNITY_EDGE_TYPE]
-    # TAG_INDEXS = [idx['index_name'] for idx in nebula_graph_storage.INIT_VERTEX_INDEXES] + [idx['index_name'] for idx in nebula_graph_storage.COMMUNITY_VERTEX_INDEXES]
-    # EDGE_INDEXS =[idx['index_name'] for idx in nebula_graph_storage.INIT_EDGE_INDEXES] + [idx['index_name'] for idx in nebula_graph_storage.COMMUNITY_EDGE_INDEXES]
     try:
         for tag in TAG_NAMES:
             delete_query = f"LOOKUP ON {tag} YIELD id(vertex) AS ID | DELETE VERTEX $-.ID WITH EDGE;"
@@ -147,7 +144,6 @@ async def test_upsert_and_get_edge(nebula_graph_storage):
     has_edge = await nebula_graph_storage.has_edge(source_id, target_id)
     assert has_edge is True
 
-    # 测试不存在的边
     non_existent_edge = await nebula_graph_storage.get_edge("non_existent1", "non_existent2")
     assert non_existent_edge is None
 
From 7be13e1579445f6fccd2f89c92ce994d3d0d2757 Mon Sep 17 00:00:00 2001
From: utopia2077
Date: Mon, 7 Oct 2024 07:39:43 +0800
Subject: [PATCH 3/3] build: add dependencies

---
 requirements.txt | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index be0e993..667c947 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,8 @@ hnswlib
 xxhash
 tenacity
 dspy-ai
-neo4j
\ No newline at end of file
+neo4j
+pymilvus
+redis
+nebula3-python
+ng_nx
\ No newline at end of file
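With these dependencies installed (pip install -r requirements.txt), the backends added in patch 1 can be swapped into GraphRAG. A minimal wiring sketch, assuming the storage-class overrides GraphRAG already exposes for its pluggable backends and the exports added to nano_graphrag._storage earlier in this series; the working directory value is illustrative:

from nano_graphrag import GraphRAG
from nano_graphrag._storage import (
    MilvusVectorStorage,
    NebulaGraphStorage,
    RedisKVStorage,
)

# Hypothetical wiring; each *_cls argument replaces the corresponding
# default storage backend.
rag = GraphRAG(
    working_dir="./nebula_graphrag_cache",  # illustrative path
    graph_storage_cls=NebulaGraphStorage,
    vector_db_storage_cls=MilvusVectorStorage,
    key_string_value_json_storage_cls=RedisKVStorage,
)

NebulaGraph, Milvus, and Redis servers must be reachable at their configured addresses before any of these storages can connect.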