diff --git a/.env.template b/.env.template index 598407456..84ebe84a6 100644 --- a/.env.template +++ b/.env.template @@ -157,12 +157,15 @@ EXECUTE_LOCAL_COMMANDS=False #*******************************************************************# VECTOR_STORE_TYPE=Chroma GRAPH_STORE_TYPE=TuGraph -GRAPH_COMMUNITY_SUMMARY_ENABLED=True KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5 KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3 KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20 KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0 +ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary +ENABLE_TRIPLET_GRAPH=True # enable the graph search for triplets +ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks + ### Chroma vector db config #CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 99eb0e2eb..bbdf59953 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -213,8 +213,8 @@ def __init__(self) -> None: # Vector Store Configuration self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma") - self.GRAPH_COMMUNITY_SUMMARY_ENABLED = ( - os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true" + self.ENABLE_GRAPH_COMMUNITY_SUMMARY = ( + os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true" ) self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1") self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530") diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py index 0aecac6e7..9904ecb36 100644 --- a/dbgpt/app/knowledge/service.py +++ b/dbgpt/app/knowledge/service.py @@ -12,7 +12,6 @@ KnowledgeDocumentEntity, ) from dbgpt.app.knowledge.request.request import ( - ChunkEditRequest, ChunkQueryRequest, DocumentQueryRequest, DocumentRecallTestRequest, @@ -650,12 +649,17 @@ def query_graph(self, space_name, limit): { "id": node.vid, "communityId": node.get_prop("_community_id"), - "name": node.vid, - "type": "", + "name": node.name, + "type": node.get_prop("type") or "", } ) for edge in graph.edges(): res["edges"].append( - {"source": edge.sid, "target": edge.tid, "name": edge.name, "type": ""} + { + "source": edge.sid, + "target": edge.tid, + "name": edge.name, + "type": edge.get_prop("type") or "", + } ) return res diff --git a/dbgpt/datasource/conn_tugraph.py b/dbgpt/datasource/conn_tugraph.py index 18c2dd9cd..191bfea87 100644 --- a/dbgpt/datasource/conn_tugraph.py +++ b/dbgpt/datasource/conn_tugraph.py @@ -1,7 +1,7 @@ """TuGraph Connector.""" import json -from typing import Dict, Generator, List, cast +from typing import Dict, Generator, List, Tuple, cast from .base import BaseConnector @@ -21,8 +21,7 @@ def __init__(self, driver, graph): self._session = None def create_graph(self, graph_name: str) -> None: - """Create a new graph.""" - # run the query to get vertex labels + """Create a new graph in the database if it doesn't already exist.""" try: with self._driver.session(database="default") as session: graph_list = session.run("CALL dbms.graph.listGraphs()").data() @@ -32,10 +31,10 @@ def create_graph(self, graph_name: str) -> None: f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)" ) except Exception as e: - raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") + raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e def delete_graph(self, graph_name: str) -> None: - """Delete a graph.""" + """Delete a graph in the database if it exists.""" with self._driver.session(database="default") as session: graph_list = 
session.run("CALL dbms.graph.listGraphs()").data() exists = any(item["graph_name"] == graph_name for item in graph_list) @@ -61,17 +60,20 @@ def from_uri_db( "`pip install neo4j`" ) from err - def get_table_names(self) -> Dict[str, List[str]]: + def get_table_names(self) -> Tuple[List[str], List[str]]: """Get all table names from the TuGraph by Neo4j driver.""" - # run the query to get vertex labels with self._driver.session(database=self._graph) as session: - v_result = session.run("CALL db.vertexLabels()").data() - v_data = [table_name["label"] for table_name in v_result] + # Run the query to get vertex labels + raw_vertex_labels: Dict[str, str] = session.run( + "CALL db.vertexLabels()" + ).data() + vertex_labels = [table_name["label"] for table_name in raw_vertex_labels] + + # Run the query to get edge labels + raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data() + edge_labels = [table_name["label"] for table_name in raw_edge_labels] - # run the query to get edge labels - e_result = session.run("CALL db.edgeLabels()").data() - e_data = [table_name["label"] for table_name in e_result] - return {"vertex_tables": v_data, "edge_tables": e_data} + return vertex_labels, edge_labels def get_grants(self): """Get grants.""" @@ -100,7 +102,7 @@ def run(self, query: str, fetch: str = "all") -> List: result = session.run(query) return list(result) except Exception as e: - raise Exception(f"Query execution failed: {e}") + raise Exception(f"Query execution failed: {e}\nQuery: {query}") from e def run_stream(self, query: str) -> Generator: """Run GQL.""" @@ -109,11 +111,15 @@ def run_stream(self, query: str) -> Generator: yield from result def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]: - """Get fields about specified graph. + """Retrieve the column for a specified vertex or edge table in the graph db. + + This function queries the schema of a given table (vertex or edge) and returns + detailed information about its columns (properties). Args: table_name (str): table name (graph name) table_type (str): table type (vertex or edge) + Returns: columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str @@ -146,8 +152,8 @@ def get_indexes(self, table_name: str, table_type: str = "vertex") -> List[Dict] """Get table indexes about specified table. 
Args: - table_name:(str) table name - table_type:(str)'vertex' | 'edge' + table_name (str): table name + table_type (str): 'vertex' | 'edge' Returns: List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}] """ diff --git a/dbgpt/rag/transformer/graph_extractor.py b/dbgpt/rag/transformer/graph_extractor.py index 18e867683..12751e89f 100644 --- a/dbgpt/rag/transformer/graph_extractor.py +++ b/dbgpt/rag/transformer/graph_extractor.py @@ -65,7 +65,9 @@ def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph] match = re.match(r"\((.*?)#(.*?)\)", line) if match: name, summary = [part.strip() for part in match.groups()] - graph.upsert_vertex(Vertex(name, description=summary)) + graph.upsert_vertex( + Vertex(name, description=summary, vertex_type="entity") + ) elif current_section == "Relationships": match = re.match(r"\((.*?)#(.*?)#(.*?)#(.*?)\)", line) if match: @@ -74,7 +76,13 @@ def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph] ] edge_count += 1 graph.append_edge( - Edge(source, target, name, description=summary) + Edge( + source, + target, + name, + description=summary, + edge_type="relation", + ) ) if limit and edge_count >= limit: diff --git a/dbgpt/rag/transformer/keyword_extractor.py b/dbgpt/rag/transformer/keyword_extractor.py index dec5f14a0..9a0776e13 100644 --- a/dbgpt/rag/transformer/keyword_extractor.py +++ b/dbgpt/rag/transformer/keyword_extractor.py @@ -1,4 +1,5 @@ """KeywordExtractor class.""" + import logging from typing import List, Optional @@ -39,12 +40,15 @@ def __init__(self, llm_client: LLMClient, model_name: str): def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]: keywords = set() - for part in text.split(";"): - for s in part.strip().split(","): - keyword = s.strip() - if keyword: - keywords.add(keyword) - if limit and len(keywords) >= limit: - return list(keywords) + lines = text.replace(":", "\n").split("\n") + + for line in lines: + for part in line.split(";"): + for s in part.strip().split(","): + keyword = s.strip() + if keyword: + keywords.add(keyword) + if limit and len(keywords) >= limit: + return list(keywords) return list(keywords) diff --git a/dbgpt/serve/rag/connector.py b/dbgpt/serve/rag/connector.py index ae7cf1773..4601ae9ae 100644 --- a/dbgpt/serve/rag/connector.py +++ b/dbgpt/serve/rag/connector.py @@ -128,7 +128,7 @@ def __init__( def __rewrite_index_store_type(self, index_store_type): # Rewrite Knowledge Graph Type - if CFG.GRAPH_COMMUNITY_SUMMARY_ENABLED: + if CFG.ENABLE_GRAPH_COMMUNITY_SUMMARY: if index_store_type == "KnowledgeGraph": return "CommunitySummaryKnowledgeGraph" return index_store_type diff --git a/dbgpt/storage/graph_store/base.py b/dbgpt/storage/graph_store/base.py index 24a4b467b..8c0454425 100644 --- a/dbgpt/storage/graph_store/base.py +++ b/dbgpt/storage/graph_store/base.py @@ -1,11 +1,11 @@ """Graph store base class.""" + import logging from abc import ABC, abstractmethod -from typing import Generator, List, Optional, Tuple +from typing import Optional from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core import Embeddings -from dbgpt.storage.graph_store.graph import Direction, Graph logger = logging.getLogger(__name__) @@ -23,78 +23,36 @@ class GraphStoreConfig(BaseModel): default=None, description="The embedding function of graph store, optional.", ) - summary_enabled: bool = Field( + enable_summary: bool = Field( default=False, description="Enable graph community summary or not.", ) + enable_document_graph: bool = 
Field( + default=True, + description="Enable document graph search or not.", + ) + enable_triplet_graph: bool = Field( + default=True, + description="Enable triplet graph search or not.", + ) class GraphStoreBase(ABC): """Graph store base class.""" + def __init__(self, config: GraphStoreConfig): + """Initialize graph store.""" + self._config = config + self._conn = None + @abstractmethod def get_config(self) -> GraphStoreConfig: """Get the graph store config.""" @abstractmethod - def get_vertex_type(self) -> str: - """Get the vertex type.""" - - @abstractmethod - def get_edge_type(self) -> str: - """Get the edge type.""" - - @abstractmethod - def insert_triplet(self, sub: str, rel: str, obj: str): - """Add triplet.""" - - @abstractmethod - def insert_graph(self, graph: Graph): - """Add graph.""" - - @abstractmethod - def get_triplets(self, sub: str) -> List[Tuple[str, str]]: - """Get triplets.""" - - @abstractmethod - def delete_triplet(self, sub: str, rel: str, obj: str): - """Delete triplet.""" + def _escape_quotes(self, text: str) -> str: + """Escape single and double quotes in a string for queries.""" - @abstractmethod - def truncate(self): - """Truncate Graph.""" - - @abstractmethod - def drop(self): - """Drop graph.""" - - @abstractmethod - def get_schema(self, refresh: bool = False) -> str: - """Get schema.""" - - @abstractmethod - def get_full_graph(self, limit: Optional[int] = None) -> Graph: - """Get full graph.""" - - @abstractmethod - def explore( - self, - subs: List[str], - direct: Direction = Direction.BOTH, - depth: Optional[int] = None, - fan: Optional[int] = None, - limit: Optional[int] = None, - ) -> Graph: - """Explore on graph.""" - - @abstractmethod - def query(self, query: str, **args) -> Graph: - """Execute a query.""" - - def aquery(self, query: str, **args) -> Graph: - """Async execute a query.""" - return self.query(query, **args) - - @abstractmethod - def stream_query(self, query: str) -> Generator[Graph, None, None]: - """Execute stream query.""" + # @abstractmethod + # def _parser(self, entities: List[Vertex]) -> str: + #     """Parse entities to string.""" diff --git a/dbgpt/storage/graph_store/factory.py b/dbgpt/storage/graph_store/factory.py index cf598966a..a7564d0c9 100644 --- a/dbgpt/storage/graph_store/factory.py +++ b/dbgpt/storage/graph_store/factory.py @@ -1,4 +1,5 @@ """Graph store factory.""" + import logging from typing import Tuple, Type diff --git a/dbgpt/storage/graph_store/graph.py b/dbgpt/storage/graph_store/graph.py index 555bcc14d..2807afda1 100644 --- a/dbgpt/storage/graph_store/graph.py +++ b/dbgpt/storage/graph_store/graph.py @@ -1,4 +1,5 @@ """Graph definition.""" + import itertools import json import logging @@ -6,13 +7,41 @@ from abc import ABC, abstractmethod from collections import defaultdict from enum import Enum -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple import networkx as nx logger = logging.getLogger(__name__) +class GraphElemType(Enum): + """Type of element in graph.""" + + DOCUMENT = "document" + CHUNK = "chunk" + ENTITY = "entity" # viewed as a general vertex in the general case + RELATION = "relation" # viewed as a general edge in the general case + INCLUDE = "include" + NEXT = "next" + + DOCUMENT_INCLUDE_CHUNK = "document_include_chunk" + CHUNK_INCLUDE_CHUNK = "chunk_include_chunk" + CHUNK_INCLUDE_ENTITY = "chunk_include_entity" + CHUNK_NEXT_CHUNK = "chunk_next_chunk" + + def is_vertex(self) -> bool: + """Check if the element is a
vertex.""" + return self in [ + GraphElemType.DOCUMENT, + GraphElemType.CHUNK, + GraphElemType.ENTITY, + ] + + def is_edge(self) -> bool: + """Check if the element is an edge.""" + return not self.is_vertex() + + class Direction(Enum): """Direction class.""" @@ -41,7 +70,7 @@ def props(self) -> Dict[str, Any]: def set_prop(self, key: str, value: Any): """Set a property of ELem.""" - self._props[key] = value + self._props[key] = value # note: always update the value def get_prop(self, key: str): """Get one of the properties of Elem.""" @@ -124,6 +153,18 @@ def __init__(self, sid: str, tid: str, name: str, **props): for k, v in props.items(): self.set_prop(k, v) + def __eq__(self, other): + """Check if two edges are equal. + + Let's say two edges are equal if they have the same source vertex ID, + target vertex ID, and edge label. The properties are not considered. + """ + return (self.sid, self.tid, self.name) == (other.sid, other.tid, other.name) + + def __hash__(self): + """Return the hash value of the edge.""" + return hash((self.sid, self.tid, self.name)) + @property def sid(self) -> str: """Return the source vertex ID of the edge.""" @@ -188,11 +229,15 @@ def get_neighbor_edges( """Get neighbor edges.""" @abstractmethod - def vertices(self) -> Iterator[Vertex]: + def vertices( + self, filter_fn: Optional[Callable[[Vertex], bool]] = None + ) -> Iterator[Vertex]: """Get vertex iterator.""" @abstractmethod - def edges(self) -> Iterator[Edge]: + def edges( + self, filter_fn: Optional[Callable[[Edge], bool]] = None + ) -> Iterator[Edge]: """Get edge iterator.""" @abstractmethod @@ -241,7 +286,7 @@ def __init__(self): self._edge_prop_keys = set() self._edge_count = 0 - # init vertices, out edges, in edges index + # vertices index, out edges index, in edges index self._vs: Any = defaultdict() self._oes: Any = defaultdict(lambda: defaultdict(set)) self._ies: Any = defaultdict(lambda: defaultdict(set)) @@ -269,7 +314,7 @@ def upsert_vertex(self, vertex: Vertex): # update metadata self._vertex_prop_keys.update(vertex.props.keys()) - def append_edge(self, edge: Edge): + def append_edge(self, edge: Edge) -> bool: """Append an edge if it doesn't exist; requires edge label.""" sid = edge.sid tid = edge.tid @@ -290,6 +335,34 @@ def append_edge(self, edge: Edge): self._edge_count += 1 return True + def upsert_vertex_and_edge( + self, + src_vid: str, + src_name: str, + src_props: Dict[str, Any], + dst_vid: str, + dst_name: str, + dst_props: Dict[str, Any], + edge_name: str, + edge_type: str, + ): + """Uperst src and dst vertex, and edge.""" + src_vertex = Vertex(src_vid, src_name, **src_props) + dst_vertex = Vertex(dst_vid, dst_name, **dst_props) + edge = Edge(src_vid, dst_vid, edge_name, **{"edge_type": edge_type}) + + self.upsert_vertex(src_vertex) + self.upsert_vertex(dst_vertex) + self.append_edge(edge) + + def upsert_graph(self, graph: "MemoryGraph"): + """Upsert a graph.""" + for vertex in graph.vertices(): + self.upsert_vertex(vertex) + + for edge in graph.edges(): + self.append_edge(edge) + def has_vertex(self, vid: str) -> bool: """Retrieve a vertex by ID.""" return vid in self._vs @@ -335,13 +408,26 @@ def unique_elements(elements): return itertools.islice(es, limit) if limit else es - def vertices(self) -> Iterator[Vertex]: + def vertices( + self, filter_fn: Optional[Callable[[Vertex], bool]] = None + ) -> Iterator[Vertex]: """Return vertices.""" - return iter(self._vs.values()) + # Get all vertices in the graph + all_vertices = self._vs.values() + + return all_vertices if filter_fn is None 
else filter(filter_fn, all_vertices) - def edges(self) -> Iterator[Edge]: + def edges( + self, filter_fn: Optional[Callable[[Edge], bool]] = None + ) -> Iterator[Edge]: """Return edges.""" - return iter(e for nbs in self._oes.values() for es in nbs.values() for e in es) + # Get all edges in the graph + all_edges = (e for nbs in self._oes.values() for es in nbs.values() for e in es) + + if filter_fn is None: + return all_edges + else: + return filter(filter_fn, all_edges) def del_vertices(self, *vids: str): """Delete specified vertices.""" @@ -353,7 +439,7 @@ def del_edges(self, sid: str, tid: str, name: str, **props): """Delete edges.""" old_edge_cnt = len(self._oes[sid][tid]) - def remove_matches(es): + def remove_matches(es: Set[Edge]): return set( filter( lambda e: not ( @@ -452,7 +538,7 @@ def schema(self) -> Dict[str, Any]: ] } - def format(self) -> str: + def format(self, entities_only: Optional[bool] = False) -> str: """Format graph to string.""" vs_str = "\n".join(v.format() for v in self.vertices()) es_str = "\n".join( @@ -461,11 +547,14 @@ def format(self) -> str: f"{self.get_vertex(e.tid).format(concise=True)}" for e in self.edges() ) - return ( - f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}" - if (vs_str or es_str) - else "" - ) + if entities_only: + return f"Entities:\n{vs_str}" if vs_str else "" + else: + return ( + f"Entities:\n{vs_str}\n\nRelationships:\n{es_str}" + if (vs_str or es_str) + else "" + ) def truncate(self): """Truncate graph.""" diff --git a/dbgpt/storage/graph_store/memgraph_store.py b/dbgpt/storage/graph_store/memgraph_store.py index 1a6dd2c42..36d720829 100644 --- a/dbgpt/storage/graph_store/memgraph_store.py +++ b/dbgpt/storage/graph_store/memgraph_store.py @@ -1,11 +1,10 @@ """Memory graph store.""" -import json + import logging -from typing import Generator, List, Optional, Tuple from dbgpt._private.pydantic import ConfigDict from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig -from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph +from dbgpt.storage.graph_store.graph import MemoryGraph logger = logging.getLogger(__name__) @@ -28,77 +27,8 @@ def get_config(self): """Get the graph store config.""" return self._graph_store_config - def get_edge_type(self) -> str: - """Get the edge type.""" - raise NotImplementedError("Memory graph store does not have edge type") - - def get_vertex_type(self) -> str: - """Get the vertex type.""" - raise NotImplementedError("Memory graph store does not have vertex type") - - def insert_triplet(self, sub: str, rel: str, obj: str): - """Insert a triplet into the graph.""" - self._graph.append_edge(Edge(sub, obj, rel)) - - def insert_graph(self, graph: Graph): - """Add graph.""" - for vertex in graph.vertices(): - self._graph.upsert_vertex(vertex) - - for edge in graph.edges(): - self._graph.append_edge(edge) - - def get_triplets(self, sub: str) -> List[Tuple[str, str]]: - """Retrieve triplets originating from a subject.""" - subgraph = self.explore([sub], direct=Direction.OUT, depth=1) - return [(e.name, e.tid) for e in subgraph.edges()] - - def delete_triplet(self, sub: str, rel: str, obj: str): - """Delete a specific triplet from the graph.""" - self._graph.del_edges(sub, obj, rel) - - def truncate(self): - """Truncate graph.""" - self._graph.truncate() - - def drop(self): - """Drop graph.""" - self._graph = None - - def get_schema(self, refresh: bool = False) -> str: - """Return the graph schema as a JSON string.""" - return json.dumps(self._graph.schema()) - - def 
get_full_graph(self, limit: Optional[int] = None) -> Graph: - """Return self.""" - if not limit: - return self._graph - - subgraph = MemoryGraph() - for count, edge in enumerate(self._graph.edges()): - if count >= limit: - break - subgraph.upsert_vertex(self._graph.get_vertex(edge.sid)) - subgraph.upsert_vertex(self._graph.get_vertex(edge.tid)) - subgraph.append_edge(edge) - count += 1 - return subgraph - - def explore( - self, - subs: List[str], - direct: Direction = Direction.BOTH, - depth: Optional[int] = None, - fan: Optional[int] = None, - limit: Optional[int] = None, - ) -> MemoryGraph: - """Explore the graph from given subjects up to a depth.""" - return self._graph.search(subs, direct, depth, fan, limit) - - def query(self, query: str, **args) -> Graph: - """Execute a query on graph.""" - raise NotImplementedError("Query memory graph not allowed") - - def stream_query(self, query: str) -> Generator[Graph, None, None]: - """Execute stream query.""" - raise NotImplementedError("Stream query memory graph not allowed") + def _escape_quotes(self, text: str) -> str: + """Escape single and double quotes in a string for queries.""" + raise NotImplementedError( + "_escape_quotes is not implemented by MemoryGraphStore" + ) diff --git a/dbgpt/storage/graph_store/tugraph_store.py b/dbgpt/storage/graph_store/tugraph_store.py index 3fdd2df8b..2eb7af91b 100644 --- a/dbgpt/storage/graph_store/tugraph_store.py +++ b/dbgpt/storage/graph_store/tugraph_store.py @@ -1,14 +1,15 @@ """TuGraph store.""" + import base64 import json import logging import os -from typing import Any, Generator, Iterator, List, Optional, Tuple +from typing import List from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.datasource.conn_tugraph import TuGraphConnector from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig -from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph, Vertex +from dbgpt.storage.graph_store.graph import GraphElemType logger = logging.getLogger(__name__) @@ -35,12 +36,28 @@ class TuGraphStoreConfig(GraphStoreConfig): description="login password", ) vertex_type: str = Field( - default="entity", - description="The type of vertex, `entity` by default.", + default=GraphElemType.ENTITY.value, + description="The type of entity vertex, `entity` by default.", + ) + document_type: str = Field( + default=GraphElemType.DOCUMENT.value, + description="The type of document vertex, `document` by default.", + ) + chunk_type: str = Field( + default=GraphElemType.CHUNK.value, + description="The type of chunk vertex, `chunk` by default.", ) edge_type: str = Field( - default="relation", - description="The type of edge, `relation` by default.", + default=GraphElemType.RELATION.value, + description="The type of relation edge, `relation` by default.", + ) + include_type: str = Field( + default=GraphElemType.INCLUDE.value, + description="The type of include edge, `include` by default.", + ) + next_type: str = Field( + default=GraphElemType.NEXT.value, + description="The type of next edge, `next` by default.", ) plugin_names: List[str] = Field( default=["leiden"], @@ -62,17 +79,24 @@ def __init__(self, config: TuGraphStoreConfig) -> None: self._port = int(os.getenv("TUGRAPH_PORT", config.port)) self._username = os.getenv("TUGRAPH_USERNAME", config.username) self._password = os.getenv("TUGRAPH_PASSWORD", config.password) - self._summary_enabled = ( - os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true" - or config.summary_enabled + self._enable_summary =
( + os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true" + or config.enable_summary + ) + self._enable_document_graph = ( + os.getenv("ENABLE_DOCUMENT_GRAPH", "").lower() == "true" + or config.enable_document_graph + ) + self._enable_triplet_graph = ( + os.getenv("ENABLE_TRIPLET_GRAPH", "").lower() == "true" + or config.enable_triplet_graph ) self._plugin_names = ( os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",") or config.plugin_names ) + self._graph_name = config.name - self._vertex_type = os.getenv("TUGRAPH_VERTEX_TYPE", config.vertex_type) - self._edge_type = os.getenv("TUGRAPH_EDGE_TYPE", config.edge_type) self.conn = TuGraphConnector.from_uri_db( host=self._host, @@ -82,34 +106,29 @@ def __init__(self, config: TuGraphStoreConfig) -> None: db_name=config.name, ) - self._create_graph(config.name) - - def get_vertex_type(self) -> str: - """Get the vertex type.""" - return self._vertex_type - - def get_edge_type(self) -> str: - """Get the edge type.""" - return self._edge_type - - def _create_graph(self, graph_name: str): - self.conn.create_graph(graph_name=graph_name) - self._create_schema() - if self._summary_enabled: - self._upload_plugin() - - def _check_label(self, elem_type: str): - result = self.conn.get_table_names() - if elem_type == "vertex": - return self._vertex_type in result["vertex_tables"] - if elem_type == "edge": - return self._edge_type in result["edge_tables"] + def get_config(self) -> TuGraphStoreConfig: + """Get the TuGraph store config.""" + return self._config def _add_vertex_index(self, field_name): - gql = f"CALL db.addIndex('{self._vertex_type}', '{field_name}', false)" + """Add an index to the vertex table.""" + # TODO: Not used in the current implementation. + gql = f"CALL db.addIndex('{GraphElemType.ENTITY.value}', '{field_name}', false)" self.conn.run(gql) def _upload_plugin(self): + """Upload missing plugins to the TuGraph database. + + This method checks for the presence of required plugins in the database and + uploads any missing plugins. It performs the following steps: + 1. Lists existing plugins in the database. + 2. Identifies missing plugins by comparing with the required plugin list. + 3. For each missing plugin, reads its binary content, encodes it, and uploads it to + the database. + + The method uses the 'leiden' plugin as an example, but can be extended for other + plugins.
+ """ gql = "CALL db.plugin.listPlugin('CPP','v1')" result = self.conn.run(gql) result_names = [ @@ -122,8 +141,8 @@ def _upload_plugin(self): if len(missing_plugins): for name in missing_plugins: try: - from dbgpt_tugraph_plugins import ( # type: ignore # noqa - get_plugin_binary_path, + from dbgpt_tugraph_plugins import ( + get_plugin_binary_path, # type:ignore[import-untyped] ) except ImportError: logger.error( @@ -136,375 +155,12 @@ def _upload_plugin(self): content = f.read() content = base64.b64encode(content).decode() gql = ( - f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', " - "'SO', '{name} Plugin', false, 'v1')" + f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', 'SO', " + f"'{name} Plugin', false, 'v1')" ) self.conn.run(gql) - def _create_schema(self): - if not self._check_label("vertex"): - if self._summary_enabled: - create_vertex_gql = ( - f"CALL db.createLabel(" - f"'vertex', '{self._vertex_type}', " - f"'id', ['id',string,false]," - f"['name',string,false]," - f"['_document_id',string,true]," - f"['_chunk_id',string,true]," - f"['_community_id',string,true]," - f"['description',string,true])" - ) - self.conn.run(create_vertex_gql) - self._add_vertex_index("_community_id") - else: - create_vertex_gql = ( - f"CALL db.createLabel(" - f"'vertex', '{self._vertex_type}', " - f"'id', ['id',string,false]," - f"['name',string,false])" - ) - self.conn.run(create_vertex_gql) - - if not self._check_label("edge"): - create_edge_gql = f"""CALL db.createLabel( - 'edge', '{self._edge_type}', - '[["{self._vertex_type}", - "{self._vertex_type}"]]', - ["id",STRING,false], - ["name",STRING,false])""" - if self._summary_enabled: - create_edge_gql = f"""CALL db.createLabel( - 'edge', '{self._edge_type}', - '[["{self._vertex_type}", - "{self._vertex_type}"]]', - ["id",STRING,false], - ["name",STRING,false], - ["description",STRING,true])""" - self.conn.run(create_edge_gql) - - def _format_query_data(self, data, white_prop_list: List[str]): - nodes_list = [] - rels_list: List[Any] = [] - _white_list = white_prop_list - from neo4j import graph - - def get_filtered_properties(properties, white_list): - return { - key: value - for key, value in properties.items() - if (not key.startswith("_") and key not in ["id", "name"]) - or key in white_list - } - - def process_node(node: graph.Node): - node_id = node._properties.get("id") - node_name = node._properties.get("name") - node_properties = get_filtered_properties(node._properties, _white_list) - nodes_list.append( - {"id": node_id, "name": node_name, "properties": node_properties} - ) - - def process_relationship(rel: graph.Relationship): - name = rel._properties.get("name", "") - rel_nodes = rel.nodes - src_id = rel_nodes[0]._properties.get("id") - dst_id = rel_nodes[1]._properties.get("id") - for node in rel_nodes: - process_node(node) - edge_properties = get_filtered_properties(rel._properties, _white_list) - if not any( - existing_edge.get("name") == name - and existing_edge.get("src_id") == src_id - and existing_edge.get("dst_id") == dst_id - for existing_edge in rels_list - ): - rels_list.append( - { - "src_id": src_id, - "dst_id": dst_id, - "name": name, - "properties": edge_properties, - } - ) - - def process_path(path: graph.Path): - for rel in path.relationships: - process_relationship(rel) - - def process_other(value): - if not any( - existing_node.get("id") == "json_node" for existing_node in nodes_list - ): - nodes_list.append( - { - "id": "json_node", - "name": "json_node", - "properties": {"description": value}, - } - 
) - - for record in data: - for key in record.keys(): - value = record[key] - if isinstance(value, graph.Node): - process_node(value) - elif isinstance(value, graph.Relationship): - process_relationship(value) - elif isinstance(value, graph.Path): - process_path(value) - else: - process_other(value) - nodes = [ - Vertex(node["id"], node["name"], **node["properties"]) - for node in nodes_list - ] - rels = [ - Edge(edge["src_id"], edge["dst_id"], edge["name"], **edge["properties"]) - for edge in rels_list - ] - return {"nodes": nodes, "edges": rels} - - def get_config(self): - """Get the graph store config.""" - return self._config - - def get_triplets(self, subj: str) -> List[Tuple[str, str]]: - """Get triplets.""" - query = ( - f"MATCH (n1:{self._vertex_type})-[r]->(n2:{self._vertex_type}) " - f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;' - ) - data = self.conn.run(query) - return [(record["rel"], record["obj"]) for record in data] - - def insert_triplet(self, subj: str, rel: str, obj: str) -> None: - """Add triplet.""" - - def escape_quotes(value: str) -> str: - """Escape single and double quotes in a string for queries.""" - return value.replace("'", "\\'").replace('"', '\\"') - - subj_escaped = escape_quotes(subj) - rel_escaped = escape_quotes(rel) - obj_escaped = escape_quotes(obj) - - node_query = f"""CALL db.upsertVertex( - '{self._vertex_type}', - [{{id:'{subj_escaped}',name:'{subj_escaped}'}}, - {{id:'{obj_escaped}',name:'{obj_escaped}'}}])""" - edge_query = f"""CALL db.upsertEdge( - '{self._edge_type}', - {{type:"{self._vertex_type}",key:"sid"}}, - {{type:"{self._vertex_type}", key:"tid"}}, - [{{sid:"{subj_escaped}", - tid: "{obj_escaped}", - id:"{rel_escaped}", - name: "{rel_escaped}"}}])""" - self.conn.run(query=node_query) - self.conn.run(query=edge_query) - - def insert_graph(self, graph: Graph) -> None: - """Add graph.""" - - def escape_quotes(value: str) -> str: - """Escape single and double quotes in a string for queries.""" - if value is not None: - return value.replace("'", "").replace('"', "") - - nodes: Iterator[Vertex] = graph.vertices() - edges: Iterator[Edge] = graph.edges() - node_list = [] - edge_list = [] - - def parser(node_list): - formatted_nodes = [ - "{" - + ", ".join( - f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}" - for k, v in node.items() - ) - + "}" - for node in node_list - ] - return f"""{', '.join(formatted_nodes)}""" - - for node in nodes: - node_list.append( - { - "id": escape_quotes(node.vid), - "name": escape_quotes(node.name), - "description": escape_quotes(node.get_prop("description")) or "", - "_document_id": "0", - "_chunk_id": "0", - "_community_id": "0", - } - ) - node_query = ( - f"""CALL db.upsertVertex("{self._vertex_type}", [{parser(node_list)}])""" - ) - for edge in edges: - edge_list.append( - { - "sid": escape_quotes(edge.sid), - "tid": escape_quotes(edge.tid), - "id": escape_quotes(edge.name), - "name": escape_quotes(edge.name), - "description": escape_quotes(edge.get_prop("description")), - } - ) - - edge_query = f"""CALL db.upsertEdge( - "{self._edge_type}", - {{type:"{self._vertex_type}", key:"sid"}}, - {{type:"{self._vertex_type}", key:"tid"}}, - [{parser(edge_list)}])""" - self.conn.run(query=node_query) - self.conn.run(query=edge_query) - - def truncate(self): - """Truncate Graph.""" - gql = "MATCH (n) DELETE n" - self.conn.run(gql) - - def drop(self): - """Delete Graph.""" - self.conn.delete_graph(self._graph_name) - - def delete_triplet(self, sub: str, rel: str, obj: str) -> None: - """Delete triplet.""" 
- del_query = ( - f"MATCH (n1:{self._vertex_type} {{id:'{sub}'}})" - f"-[r:{self._edge_type} {{id:'{rel}'}}]->" - f"(n2:{self._vertex_type} {{id:'{obj}'}}) DELETE n1,n2,r" - ) - self.conn.run(query=del_query) - - def get_schema(self, refresh: bool = False) -> str: - """Get the schema of the graph store.""" - query = "CALL dbms.graph.getGraphSchema()" - data = self.conn.run(query=query) - schema = data[0]["schema"] - return schema - - def get_full_graph(self, limit: Optional[int] = None) -> Graph: - """Get full graph.""" - if not limit: - raise Exception("limit must be set") - graph_result = self.query( - f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}", - white_list=["_community_id"], - ) - all_graph = MemoryGraph() - for vertex in graph_result.vertices(): - all_graph.upsert_vertex(vertex) - for edge in graph_result.edges(): - all_graph.append_edge(edge) - return all_graph - - def explore( - self, - subs: List[str], - direct: Direction = Direction.BOTH, - depth: Optional[int] = None, - fan: Optional[int] = None, - limit: Optional[int] = None, - ) -> Graph: - """Explore the graph from given subjects up to a depth.""" - if not subs: - return MemoryGraph() - - if fan is not None: - raise ValueError("Fan functionality is not supported at this time.") - else: - depth_string = f"1..{depth}" - if depth is None: - depth_string = ".." - - limit_string = f"LIMIT {limit}" - if limit is None: - limit_string = "" - if direct.name == "OUT": - rel = f"-[r:{self._edge_type}*{depth_string}]->" - elif direct.name == "IN": - rel = f"<-[r:{self._edge_type}*{depth_string}]-" - else: - rel = f"-[r:{self._edge_type}*{depth_string}]-" - query = ( - f"MATCH p=(n:{self._vertex_type})" - f"{rel}(m:{self._vertex_type}) " - f"WHERE n.id IN {subs} RETURN p {limit_string}" - ) - return self.query(query) - - def query(self, query: str, **args) -> MemoryGraph: - """Execute a query on graph.""" - result = self.conn.run(query=query) - white_list = args.get("white_list", []) - graph = self._format_query_data(result, white_list) - mg = MemoryGraph() - for vertex in graph["nodes"]: - mg.upsert_vertex(vertex) - for edge in graph["edges"]: - mg.append_edge(edge) - return mg - - def stream_query(self, query: str) -> Generator[Graph, None, None]: - """Execute a stream query.""" - from neo4j import graph - - for record in self.conn.run_stream(query): - mg = MemoryGraph() - for key in record.keys(): - value = record[key] - if isinstance(value, graph.Node): - node_id = value._properties["id"] - description = value._properties["description"] - vertex = Vertex(node_id, name=node_id, description=description) - mg.upsert_vertex(vertex) - elif isinstance(value, graph.Relationship): - rel_nodes = value.nodes - prop_id = value._properties["id"] - src_id = rel_nodes[0]._properties["id"] - dst_id = rel_nodes[1]._properties["id"] - description = value._properties["description"] - edge = Edge(src_id, dst_id, name=prop_id, description=description) - mg.append_edge(edge) - elif isinstance(value, graph.Path): - nodes = list(record["p"].nodes) - rels = list(record["p"].relationships) - formatted_path = [] - for i in range(len(nodes)): - formatted_path.append( - { - "id": nodes[i]._properties["id"], - "description": nodes[i]._properties["description"], - } - ) - if i < len(rels): - formatted_path.append( - { - "id": rels[i]._properties["id"], - "description": rels[i]._properties["description"], - } - ) - for i in range(0, len(formatted_path), 2): - mg.upsert_vertex( - Vertex( - formatted_path[i]["id"], - name=formatted_path[i]["id"], - 
description=formatted_path[i]["description"], - ) - ) - if i + 2 < len(formatted_path): - mg.append_edge( - Edge( - formatted_path[i]["id"], - formatted_path[i + 2]["id"], - name=formatted_path[i + 1]["id"], - description=formatted_path[i + 1]["description"], - ) - ) - else: - vertex = Vertex("json_node", name="json_node", description=value) - mg.upsert_vertex(vertex) - yield mg + def _escape_quotes(self, value: str) -> str: + """Escape single and double quotes in a string for queries.""" + if value is not None: + return value.replace("'", "").replace('"', "") diff --git a/dbgpt/storage/knowledge_graph/base.py b/dbgpt/storage/knowledge_graph/base.py index e47094bba..8080cbd0f 100644 --- a/dbgpt/storage/knowledge_graph/base.py +++ b/dbgpt/storage/knowledge_graph/base.py @@ -1,4 +1,5 @@ """Knowledge graph base class.""" + import logging from abc import ABC, abstractmethod from typing import List, Optional @@ -27,6 +28,6 @@ def get_config(self) -> KnowledgeGraphConfig: def query_graph(self, limit: Optional[int] = None) -> Graph: """Get graph data.""" + @abstractmethod def delete_by_ids(self, ids: str) -> List[str]: """Delete document by ids.""" - raise Exception("Delete document not supported by knowledge graph") diff --git a/dbgpt/storage/knowledge_graph/community/base.py b/dbgpt/storage/knowledge_graph/community/base.py index 9dcb17bc1..6679c0448 100644 --- a/dbgpt/storage/knowledge_graph/community/base.py +++ b/dbgpt/storage/knowledge_graph/community/base.py @@ -1,11 +1,19 @@ """Define Classes about Community.""" + import logging from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional +from typing import AsyncGenerator, Iterator, List, Optional from dbgpt.storage.graph_store.base import GraphStoreBase -from dbgpt.storage.graph_store.graph import Graph +from dbgpt.storage.graph_store.graph import ( + Direction, + Edge, + Graph, + GraphElemType, + MemoryGraph, + Vertex, +) logger = logging.getLogger(__name__) @@ -24,7 +32,7 @@ class CommunityTree: """Represents a community tree.""" -class CommunityStoreAdapter(ABC): - """Community Store Adapter.""" +class GraphStoreAdapter(ABC): + """Graph store adapter.""" def __init__(self, graph_store: GraphStoreBase): @@ -44,6 +52,113 @@ async def discover_communities(self, **kwargs) -> List[str]: async def get_community(self, community_id: str) -> Community: """Get community.""" + @abstractmethod + def get_graph_config(self): + """Get config.""" + + @abstractmethod + def get_vertex_type(self) -> str: + """Get vertex type.""" + + @abstractmethod + def get_edge_type(self) -> str: + """Get edge type.""" + + @abstractmethod + def get_triplets(self, sub: str) -> List[tuple[str, str]]: + """Get triplets.""" + + @abstractmethod + def get_document_vertex(self, doc_name: str) -> Vertex: + """Get document vertex.""" + + @abstractmethod + def get_schema(self, refresh: bool = False) -> str: + """Get schema.""" + + @abstractmethod + def get_full_graph(self, limit: Optional[int] = None) -> Graph: + """Get full graph.""" + + @abstractmethod + def upsert_entities(self, entities: Iterator[Vertex]) -> None: + """Upsert entities.""" + + @abstractmethod + def upsert_edge( + self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str + ): + """Upsert edges.""" + + @abstractmethod + def upsert_chunks(self, chunks: Iterator[Vertex]) -> None: + """Upsert chunks.""" + + @abstractmethod + def upsert_documents(self, documents: Iterator[Vertex]) -> None: + """Upsert documents.""" + + @abstractmethod + def upsert_relations(self, relations: Iterator[Edge])
-> None: + """Upsert relations.""" + + @abstractmethod + def insert_triplet(self, sub: str, rel: str, obj: str) -> None: + """Insert triplet.""" + + @abstractmethod + def upsert_graph(self, graph: Graph) -> None: + """Upsert graph.""" + + @abstractmethod + def delete_document(self, chunk_ids: str) -> None: + """Delete document in graph store.""" + + @abstractmethod + def delete_triplet(self, sub: str, rel: str, obj: str) -> None: + """Delete triplet.""" + + @abstractmethod + def drop(self) -> None: + """Drop graph.""" + + @abstractmethod + def create_graph(self, graph_name: str) -> None: + """Create graph.""" + + @abstractmethod + def create_graph_label(self) -> None: + """Create a graph label. + + The graph label is used to identify and distinguish different types of nodes + (vertices) and edges in the graph. + """ + + @abstractmethod + def truncate(self) -> None: + """Truncate graph.""" + + @abstractmethod + def check_label(self, graph_elem_type: GraphElemType) -> bool: + """Check if the label exists in the graph.""" + + @abstractmethod + def explore( + self, + subs: List[str], + direct: Direction = Direction.BOTH, + depth: Optional[int] = None, + ) -> MemoryGraph: + """Explore the graph from given subjects up to a depth.""" + + @abstractmethod + def query(self, query: str, **kwargs) -> MemoryGraph: + """Execute a query on graph.""" + + @abstractmethod + async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]: + """Execute a stream query.""" + class CommunityMetastore(ABC): """Community metastore class.""" diff --git a/dbgpt/storage/knowledge_graph/community/community_store.py b/dbgpt/storage/knowledge_graph/community/community_store.py index 41bb494a3..3a5eb2474 100644 --- a/dbgpt/storage/knowledge_graph/community/community_store.py +++ b/dbgpt/storage/knowledge_graph/community/community_store.py @@ -4,10 +4,7 @@ from typing import List from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer -from dbgpt.storage.knowledge_graph.community.base import ( - Community, - CommunityStoreAdapter, -) +from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter from dbgpt.storage.knowledge_graph.community.community_metastore import ( BuiltinCommunityMetastore, ) @@ -21,23 +18,23 @@ class CommunityStore: def __init__( self, - community_store_adapter: CommunityStoreAdapter, + graph_store_adapter: GraphStoreAdapter, community_summarizer: CommunitySummarizer, vector_store: VectorStoreBase, ): """Initialize the CommunityStore class.""" - self._community_store_adapter = community_store_adapter + self._graph_store_adapter = graph_store_adapter self._community_summarizer = community_summarizer self._meta_store = BuiltinCommunityMetastore(vector_store) async def build_communities(self): """Discover communities.""" - community_ids = await self._community_store_adapter.discover_communities() + community_ids = await self._graph_store_adapter.discover_communities() # summarize communities communities = [] for community_id in community_ids: - community = await self._community_store_adapter.get_community(community_id) + community = await self._graph_store_adapter.get_community(community_id) graph = community.data.format() if not graph: break @@ -65,7 +62,7 @@ def truncate(self): self._community_summarizer.truncate() logger.info("Truncate graph") - self._community_store_adapter.graph_store.truncate() + self._graph_store_adapter.truncate() def drop(self): """Drop community store.""" @@ -76,4 +73,4 @@ def drop(self):
self._community_summarizer.drop() logger.info("Remove graph") - self._community_store_adapter.graph_store.drop() + self._graph_store_adapter.drop() diff --git a/dbgpt/storage/knowledge_graph/community/factory.py b/dbgpt/storage/knowledge_graph/community/factory.py index 4bafa74cd..e238935ae 100644 --- a/dbgpt/storage/knowledge_graph/community/factory.py +++ b/dbgpt/storage/knowledge_graph/community/factory.py @@ -1,28 +1,29 @@ -"""CommunityStoreAdapter factory.""" +"""GraphStoreAdapter factory.""" + import logging from dbgpt.storage.graph_store.base import GraphStoreBase from dbgpt.storage.graph_store.tugraph_store import TuGraphStore -from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter -from dbgpt.storage.knowledge_graph.community.tugraph_adapter import ( - TuGraphCommunityStoreAdapter, +from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter +from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import ( + TuGraphStoreAdapter, ) logger = logging.getLogger(__name__) -class CommunityStoreAdapterFactory: - """Factory for community store adapter.""" +class GraphStoreAdapterFactory: + """Factory for graph store adapter.""" @staticmethod - def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter: - """Create a CommunityStoreAdapter instance. + def create(graph_store: GraphStoreBase) -> GraphStoreAdapter: + """Create a GraphStoreAdapter instance. Args: - graph_store: graph store instance (Memory, TuGraph, Neo4j) """ if isinstance(graph_store, TuGraphStore): - return TuGraphCommunityStoreAdapter(graph_store) + return TuGraphStoreAdapter(graph_store) else: raise Exception( - "create community store adapter for %s failed", + "create graph store adapter for %s failed", diff --git a/dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py b/dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py new file mode 100644 index 000000000..d26b9f519 --- /dev/null +++ b/dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py @@ -0,0 +1,185 @@ +"""MemGraph Community Store Adapter.""" + +import json +import logging +from typing import AsyncGenerator, Iterator, List, Optional, Tuple + +from dbgpt.storage.graph_store.graph import ( + Direction, + Edge, + Graph, + GraphElemType, + MemoryGraph, + Vertex, +) +from dbgpt.storage.graph_store.memgraph_store import ( + MemoryGraphStore, + MemoryGraphStoreConfig, +) +from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter + +logger = logging.getLogger(__name__) + + +class MemGraphStoreAdapter(GraphStoreAdapter): + """MemGraph Community Store Adapter.""" + + MAX_HIERARCHY_LEVEL = 3 + + def __init__(self, enable_summary: bool = False): + """Initialize MemGraph Community Store Adapter.""" + self._graph_store = MemoryGraphStore(MemoryGraphStoreConfig()) + self._enable_summary = enable_summary + + super().__init__(self._graph_store) + + # Create the graph + self.create_graph(self._graph_store.get_config().name) + + async def discover_communities(self, **kwargs) -> List[str]: + """Run community discovery with leiden.""" + pass + + async def get_community(self, community_id: str) -> Community: + """Get community.""" + pass + + def get_graph_config(self): + """Get the graph store config.""" + return self._graph_store.get_config() + + def get_vertex_type(self) -> str: + """Get the vertex type.""" + # raise NotImplementedError("Memory graph store does not have vertex type") + return "" + + def get_edge_type(self) -> str: + """Get the edge type.""" + # raise NotImplementedError("Memory graph store does not have edge type") + return "" + + def
get_triplets(self, subj: str) -> List[Tuple[str, str]]: + """Get triplets.""" + subgraph = self.explore([subj], direct=Direction.OUT, depth=1) + return [(e.name, e.tid) for e in subgraph.edges()] + + def get_document_vertex(self, doc_name: str) -> Vertex: + """Get the document vertex in the graph.""" + raise NotImplementedError("Memory graph store does not have document vertex") + + def get_schema(self, refresh: bool = False) -> str: + """Get the schema of the graph store.""" + return json.dumps(self._graph_store._graph.schema()) + + def get_full_graph(self, limit: Optional[int] = None) -> Graph: + """Get full graph.""" + if not limit: + return self._graph_store._graph + + subgraph = MemoryGraph() + for count, edge in enumerate(self._graph_store._graph.edges()): + if count >= limit: + break + subgraph.upsert_vertex(self._graph_store._graph.get_vertex(edge.sid)) + subgraph.upsert_vertex(self._graph_store._graph.get_vertex(edge.tid)) + subgraph.append_edge(edge) + return subgraph + + def upsert_entities(self, entities: Iterator[Vertex]) -> None: + """Upsert entities.""" + pass + + def upsert_edge( + self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str + ) -> None: + """Upsert edges.""" + pass + + def upsert_chunks(self, chunks: Iterator[Vertex]) -> None: + """Upsert chunks.""" + pass + + def upsert_documents(self, documents: Iterator[Vertex]) -> None: + """Upsert documents.""" + pass + + def upsert_relations(self, relations: Iterator[Edge]) -> None: + """Upsert relations.""" + pass + + def insert_triplet(self, subj: str, rel: str, obj: str) -> None: + """Add triplet.""" + self._graph_store._graph.append_edge(Edge(subj, obj, rel)) + + def upsert_graph(self, graph: Graph) -> None: + """Add graph to the graph store. + + Args: + graph (Graph): The graph to be added. + """ + for vertex in graph.vertices(): + self._graph_store._graph.upsert_vertex(vertex) + + for edge in graph.edges(): + self._graph_store._graph.append_edge(edge) + + def delete_document(self, chunk_ids: str) -> None: + """Delete document in the graph.""" + pass + + def delete_triplet(self, sub: str, rel: str, obj: str) -> None: + """Delete triplet.""" + self._graph_store._graph.del_edges(sub, obj, rel) + + def drop(self): + """Delete Graph.""" + self._graph_store._graph = None + + def create_graph(self, graph_name: str): + """Create a graph.""" + pass + + def create_graph_label( + self, + ) -> None: + """Create a graph label. + + The graph label is used to identify and distinguish different types of nodes + (vertices) and edges in the graph. + """ + pass + + def truncate(self): + """Truncate Graph.""" + self._graph_store._graph.truncate() + + def check_label(self, graph_elem_type: GraphElemType) -> bool: + """Check if the label exists in the graph. + + Args: + graph_elem_type (GraphElemType): The type of the graph element. + + Returns: + True if the label exists in the specified graph element type, otherwise + False.
+ """ + pass + + def explore( + self, + subs: List[str], + direct: Direction = Direction.BOTH, + depth: int | None = None, + fan: int | None = None, + limit: int | None = None, + ) -> MemoryGraph: + """Explore the graph from given subjects up to a depth.""" + return self._graph_store._graph.search(subs, direct, depth, fan, limit) + + def query(self, query: str, **kwargs) -> MemoryGraph: + """Execute a query on graph.""" + pass + + async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]: + """Execute a stream query.""" + pass diff --git a/dbgpt/storage/knowledge_graph/community/tugraph_adapter.py b/dbgpt/storage/knowledge_graph/community/tugraph_adapter.py deleted file mode 100644 index 9dcbbe046..000000000 --- a/dbgpt/storage/knowledge_graph/community/tugraph_adapter.py +++ /dev/null @@ -1,52 +0,0 @@ -"""TuGraph Community Store Adapter.""" -import json -import logging -from typing import List - -from dbgpt.storage.graph_store.graph import MemoryGraph -from dbgpt.storage.knowledge_graph.community.base import ( - Community, - CommunityStoreAdapter, -) - -logger = logging.getLogger(__name__) - - -class TuGraphCommunityStoreAdapter(CommunityStoreAdapter): - """TuGraph Community Store Adapter.""" - - MAX_HIERARCHY_LEVEL = 3 - - async def discover_communities(self, **kwargs) -> List[str]: - """Run community discovery with leiden.""" - mg = self._graph_store.query( - "CALL db.plugin.callPlugin" - "('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)" - ) - result = mg.get_vertex("json_node").get_prop("description") - community_ids = json.loads(result)["community_id_list"] - logger.info(f"Discovered {len(community_ids)} communities.") - return community_ids - - async def get_community(self, community_id: str) -> Community: - """Get community.""" - query = ( - f"MATCH (n:{self._graph_store.get_vertex_type()})" - f"WHERE n._community_id = '{community_id}' RETURN n" - ) - edge_query = ( - f"MATCH (n:{self._graph_store.get_vertex_type()})-" - f"[r:{self._graph_store.get_edge_type()}]-" - f"(m:{self._graph_store.get_vertex_type()})" - f"WHERE n._community_id = '{community_id}' RETURN n,r,m" - ) - - all_vertex_graph = self._graph_store.aquery(query) - all_edge_graph = self._graph_store.aquery(edge_query) - all_graph = MemoryGraph() - for vertex in all_vertex_graph.vertices(): - all_graph.upsert_vertex(vertex) - for edge in all_edge_graph.edges(): - all_graph.append_edge(edge) - - return Community(id=community_id, data=all_graph) diff --git a/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py b/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py new file mode 100644 index 000000000..ae1d587a0 --- /dev/null +++ b/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py @@ -0,0 +1,808 @@ +"""TuGraph Community Store Adapter.""" + +import json +import logging +from typing import ( + Any, + AsyncGenerator, + Dict, + Iterator, + List, + Literal, + Optional, + Tuple, + Union, +) + +from dbgpt.storage.graph_store.graph import ( + Direction, + Edge, + Graph, + GraphElemType, + MemoryGraph, + Vertex, +) +from dbgpt.storage.graph_store.tugraph_store import TuGraphStore +from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter + +logger = logging.getLogger(__name__) + + +class TuGraphStoreAdapter(GraphStoreAdapter): + """TuGraph Community Store Adapter.""" + + MAX_QUERY_LIMIT = 1000 + MAX_HIERARCHY_LEVEL = 3 + + def __init__(self, graph_store: TuGraphStore): + """Initialize TuGraph Community Store Adapter.""" + 
super().__init__(graph_store) + + # Create the graph + self.create_graph(self.graph_store.get_config().name) + + async def discover_communities(self, **kwargs) -> List[str]: + """Run community discovery with leiden.""" + mg = self.query( + "CALL db.plugin.callPlugin('CPP'," + "'leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)" + ) + result = mg.get_vertex("json_node").get_prop("description") + community_ids = json.loads(result)["community_id_list"] + logger.info(f"Discovered {len(community_ids)} communities.") + return community_ids + + async def get_community(self, community_id: str) -> Community: + """Get community.""" + query = ( + f"MATCH (n:{self.get_vertex_type()}) WHERE n._community_id = " + f"'{community_id}' RETURN n" + ) + edge_query = ( + f"MATCH (n:{self.get_vertex_type()})-" + f"[r:{self.get_edge_type()}]-" + f"(m:{self.get_vertex_type()})" + f"WHERE n._community_id = '{community_id}' RETURN n,r,m" + ) + + all_vertex_graph = self.query(query) + all_edge_graph = self.query(edge_query) + all_graph = MemoryGraph() + for vertex in all_vertex_graph.vertices(): + all_graph.upsert_vertex(vertex) + for edge in all_edge_graph.edges(): + all_graph.append_edge(edge) + + return Community(id=community_id, data=all_graph) + + @property + def graph_store(self) -> TuGraphStore: + """Get the graph store.""" + return self._graph_store + + def get_graph_config(self): + """Get the graph store config.""" + return self.graph_store.get_config() + + def get_vertex_type(self) -> str: + """Get the vertex type.""" + return GraphElemType.ENTITY.value + + def get_edge_type(self) -> str: + """Get the edge type.""" + return GraphElemType.RELATION.value + + def get_triplets(self, subj: str) -> List[Tuple[str, str]]: + """Get triplets.""" + triplet_query = ( + f"MATCH (n1:{GraphElemType.ENTITY.value})-[r]->(n2:" + f"{GraphElemType.ENTITY.value}) " + f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;' + ) + data = self.graph_store.conn.run(triplet_query) + return [(record["rel"], record["obj"]) for record in data] + + def get_document_vertex(self, doc_name: str) -> Vertex: + """Get the document vertex in the graph.""" + gql = f"""MATCH (n) WHERE n.id = '{doc_name}' RETURN n""" + graph = self.query(gql) + vertex = graph.get_vertex(doc_name) + return vertex + + def get_schema(self, refresh: bool = False) -> str: + """Get the schema of the graph store.""" + query = "CALL dbms.graph.getGraphSchema()" + data = self.graph_store.conn.run(query=query) + schema = data[0]["schema"] + return schema + + def get_full_graph(self, limit: Optional[int] = None) -> Graph: + """Get full graph.""" + if not limit: + limit = self.MAX_QUERY_LIMIT + if limit <= 0: + raise ValueError("Limit must be greater than 0.") + graph_result = self.query( + f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}", + white_list=["_community_id"], + ) + full_graph = MemoryGraph() + for vertex in graph_result.vertices(): + full_graph.upsert_vertex(vertex) + for edge in graph_result.edges(): + full_graph.append_edge(edge) + return full_graph + + def upsert_entities(self, entities: Iterator[Vertex]) -> None: + """Upsert entities.""" + entity_list = [ + { + "id": self.graph_store._escape_quotes(entity.vid), + "name": self.graph_store._escape_quotes(entity.name), + "description": self.graph_store._escape_quotes( + entity.get_prop("description") + ) + or "", + "_document_id": "0", + "_chunk_id": "0", + "_community_id": "0", + } + for entity in entities + ] + entity_query = ( + f"CALL db.upsertVertex(" + f'"{GraphElemType.ENTITY.value}", ' +
f"[{self._parser(entity_list)}])" + ) + self.graph_store.conn.run(query=entity_query) + + def upsert_edge( + self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str + ) -> None: + """Upsert edges.""" + edge_list = [ + { + "sid": self.graph_store._escape_quotes(edge.sid), + "tid": self.graph_store._escape_quotes(edge.tid), + "id": self.graph_store._escape_quotes(edge.name), + "name": self.graph_store._escape_quotes(edge.name), + "description": self.graph_store._escape_quotes( + edge.get_prop("description") + ) + or "", + "_chunk_id": self.graph_store._escape_quotes(edge.get_prop("_chunk_id")) + or "", + } + for edge in edges + ] + relation_query = f"""CALL db.upsertEdge("{edge_type}", + {{type:"{src_type}", key:"sid"}}, + {{type:"{dst_type}", key:"tid"}}, + [{self._parser(edge_list)}])""" + self.graph_store.conn.run(query=relation_query) + + def upsert_chunks(self, chunks: Iterator[Vertex]) -> None: + """Upsert chunks.""" + chunk_list = [ + { + "id": self.graph_store._escape_quotes(chunk.vid), + "name": self.graph_store._escape_quotes(chunk.name), + "content": self.graph_store._escape_quotes(chunk.get_prop("content")), + } + for chunk in chunks + ] + chunk_query = ( + f"CALL db.upsertVertex(" + f'"{GraphElemType.CHUNK.value}", ' + f"[{self._parser(chunk_list)}])" + ) + self.graph_store.conn.run(query=chunk_query) + + def upsert_documents(self, documents: Iterator[Vertex]) -> None: + """Upsert documents.""" + document_list = [ + { + "id": self.graph_store._escape_quotes(document.vid), + "name": self.graph_store._escape_quotes(document.name), + "content": self.graph_store._escape_quotes(document.get_prop("content")) + or "", + } + for document in documents + ] + document_query = ( + "CALL db.upsertVertex(" + f'"{GraphElemType.DOCUMENT.value}", ' + f"[{self._parser(document_list)}])" + ) + self.graph_store.conn.run(query=document_query) + + def upsert_relations(self, relations: Iterator[Edge]) -> None: + """Upsert relations.""" + pass + + def insert_triplet(self, subj: str, rel: str, obj: str) -> None: + """Add triplet.""" + subj_escaped = subj.replace("'", "\\'").replace('"', '\\"') + rel_escaped = rel.replace("'", "\\'").replace('"', '\\"') + obj_escaped = obj.replace("'", "\\'").replace('"', '\\"') + + vertex_query = f"""CALL db.upsertVertex( + '{GraphElemType.ENTITY.value}', + [{{id:'{subj_escaped}',name:'{subj_escaped}'}}, + {{id:'{obj_escaped}',name:'{obj_escaped}'}}])""" + edge_query = f"""CALL db.upsertEdge( + '{GraphElemType.RELATION.value}', + {{type:"{GraphElemType.ENTITY.value}",key:"sid"}}, + {{type:"{GraphElemType.ENTITY.value}", key:"tid"}}, + [{{sid:"{subj_escaped}", + tid: "{obj_escaped}", + id:"{rel_escaped}", + name: "{rel_escaped}"}}])""" + + self.graph_store.conn.run(query=vertex_query) + self.graph_store.conn.run(query=edge_query) + + def upsert_graph(self, graph: MemoryGraph) -> None: + """Add graph to the graph store. + + Args: + graph (Graph): The graph to be added. 
+ """ + # Get the iterators of all the vertices and the edges from the graph + documents: Iterator[Vertex] = graph.vertices( + filter_fn=lambda x: x.get_prop("vertex_type") + == GraphElemType.DOCUMENT.value + ) + chunks: Iterator[Vertex] = graph.vertices( + filter_fn=lambda x: x.get_prop("vertex_type") == GraphElemType.CHUNK.value + ) + entities: Iterator[Vertex] = graph.vertices( + filter_fn=lambda x: x.get_prop("vertex_type") == GraphElemType.ENTITY.value + ) + doc_include_chunk: Iterator[Edge] = graph.edges( + filter_fn=lambda x: x.get_prop("edge_type") + == GraphElemType.DOCUMENT_INCLUDE_CHUNK.value + ) + chunk_include_chunk: Iterator[Edge] = graph.edges( + filter_fn=lambda x: x.get_prop("edge_type") + == GraphElemType.CHUNK_INCLUDE_CHUNK.value + ) + chunk_include_entity: Iterator[Edge] = graph.edges( + filter_fn=lambda x: x.get_prop("edge_type") + == GraphElemType.CHUNK_INCLUDE_ENTITY.value + ) + chunk_next_chunk: Iterator[Edge] = graph.edges( + filter_fn=lambda x: x.get_prop("edge_type") + == GraphElemType.CHUNK_NEXT_CHUNK.value + ) + relation: Iterator[Edge] = graph.edges( + filter_fn=lambda x: x.get_prop("edge_type") == GraphElemType.RELATION.value + ) + + # Upsert the vertices and the edges to the graph store + self.upsert_entities(entities) + self.upsert_chunks(chunks) + self.upsert_documents(documents) + self.upsert_edge( + doc_include_chunk, + GraphElemType.INCLUDE.value, + GraphElemType.DOCUMENT.value, + GraphElemType.CHUNK.value, + ) + self.upsert_edge( + chunk_include_chunk, + GraphElemType.INCLUDE.value, + GraphElemType.CHUNK.value, + GraphElemType.CHUNK.value, + ) + self.upsert_edge( + chunk_include_entity, + GraphElemType.INCLUDE.value, + GraphElemType.CHUNK.value, + GraphElemType.ENTITY.value, + ) + self.upsert_edge( + chunk_next_chunk, + GraphElemType.NEXT.value, + GraphElemType.CHUNK.value, + GraphElemType.CHUNK.value, + ) + self.upsert_edge( + relation, + GraphElemType.RELATION.value, + GraphElemType.ENTITY.value, + GraphElemType.ENTITY.value, + ) + + def delete_document(self, chunk_ids: str) -> None: + """Delete document in the graph.""" + chunkids_list = [uuid.strip() for uuid in chunk_ids.split(",")] + del_chunk_gql = ( + f"MATCH(m:{GraphElemType.DOCUMENT.value})-[r]->" + f"(n:{GraphElemType.CHUNK.value}) WHERE n.id IN {chunkids_list} DELETE n" + ) + del_relation_gql = ( + f"MATCH(m:{GraphElemType.ENTITY.value})-[r:" + f"{GraphElemType.RELATION.value}]-(n:{GraphElemType.ENTITY.value}) " + f"WHERE r._chunk_id IN {chunkids_list} DELETE r" + ) + delete_only_vertex = "MATCH (n) WHERE NOT EXISTS((n)-[]-()) DELETE n" + self.graph_store.conn.run(del_chunk_gql) + self.graph_store.conn.run(del_relation_gql) + self.graph_store.conn.run(delete_only_vertex) + + def delete_triplet(self, sub: str, rel: str, obj: str) -> None: + """Delete triplet.""" + del_query = ( + f"MATCH (n1:{GraphElemType.ENTITY.value} {{id:'{sub}'}})" + f"-[r:{GraphElemType.RELATION.value} {{id:'{rel}'}}]->" + f"(n2:{GraphElemType.ENTITY.value} {{id:'{obj}'}}) DELETE n1,n2,r" + ) + self.graph_store.conn.run(query=del_query) + + def drop(self): + """Delete Graph.""" + self.graph_store.conn.delete_graph(self.get_graph_config().name) + + def create_graph(self, graph_name: str): + """Create a graph.""" + self.graph_store.conn.create_graph(graph_name=graph_name) + + # Create the graph schema + def _format_graph_propertity_schema( + name: str, + type: str = "STRING", + optional: bool = False, + index: Optional[bool] = None, + **kwargs, + ) -> Dict[str, str | bool]: + """Format the property for TuGraph. 
+
+            Args:
+                name: The name of the property.
+                type: The type of the property.
+                optional: Whether the property is optional.
+                index: Whether to create an index on the property.
+                kwargs: Additional keyword arguments.
+
+            Returns:
+                The formatted property.
+            """
+            property: Dict[str, str | bool] = {
+                "name": name,
+                "type": type,
+                "optional": optional,
+            }
+
+            if index is not None:
+                property["index"] = index
+
+            # Add any additional keyword arguments to the property dictionary
+            property.update(kwargs)
+            return property
+
+        # Create the graph label for document vertex
+        document_properties: List[Dict[str, Union[str, bool]]] = [
+            _format_graph_property_schema("id", "STRING", False),
+            _format_graph_property_schema("name", "STRING", False),
+            _format_graph_property_schema("_community_id", "STRING", True, True),
+        ]
+        self.create_graph_label(
+            graph_elem_type=GraphElemType.DOCUMENT, graph_properties=document_properties
+        )
+
+        # Create the graph label for chunk vertex
+        chunk_properties: List[Dict[str, Union[str, bool]]] = [
+            _format_graph_property_schema("id", "STRING", False),
+            _format_graph_property_schema("name", "STRING", False),
+            _format_graph_property_schema("_community_id", "STRING", True, True),
+            _format_graph_property_schema("content", "STRING", True, True),
+        ]
+        self.create_graph_label(
+            graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_properties
+        )
+
+        # Create the graph label for entity vertex
+        vertex_properties: List[Dict[str, Union[str, bool]]] = [
+            _format_graph_property_schema("id", "STRING", False),
+            _format_graph_property_schema("name", "STRING", False),
+            _format_graph_property_schema("_community_id", "STRING", True, True),
+            _format_graph_property_schema("description", "STRING", True, True),
+        ]
+        self.create_graph_label(
+            graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_properties
+        )
+
+        # Create the graph label for relation edge
+        edge_properties: List[Dict[str, Union[str, bool]]] = [
+            _format_graph_property_schema("id", "STRING", False),
+            _format_graph_property_schema("name", "STRING", False),
+            _format_graph_property_schema("_chunk_id", "STRING", True, True),
+            _format_graph_property_schema("description", "STRING", True, True),
+        ]
+        self.create_graph_label(
+            graph_elem_type=GraphElemType.RELATION, graph_properties=edge_properties
+        )
+
+        # Create the graph label for include edge
+        include_properties: List[Dict[str, Union[str, bool]]] = [
+            _format_graph_property_schema("id", "STRING", False),
+            _format_graph_property_schema("name", "STRING", False),
+            _format_graph_property_schema("description", "STRING", True),
+        ]
+        self.create_graph_label(
+            graph_elem_type=GraphElemType.INCLUDE, graph_properties=include_properties
+        )
+
+        # Create the graph label for next edge
+        next_properties: List[Dict[str, Union[str, bool]]] = [
+            _format_graph_property_schema("id", "STRING", False),
+            _format_graph_property_schema("name", "STRING", False),
+            _format_graph_property_schema("description", "STRING", True),
+        ]
+        self.create_graph_label(
+            graph_elem_type=GraphElemType.NEXT, graph_properties=next_properties
+        )
+
+        if self.graph_store._enable_summary:
+            self.graph_store._upload_plugin()
+
+    def create_graph_label(
+        self,
+        graph_elem_type: GraphElemType,
+        graph_properties: List[Dict[str, Union[str, bool]]],
+    ) -> None:
+        """Create a graph label.
+
+        The graph label is used to identify and distinguish different types of nodes
+        (vertices) and edges in the graph.
+ """ + if graph_elem_type.is_vertex(): # vertex + data = json.dumps({ + "label": graph_elem_type.value, + "type": "VERTEX", + "primary": "id", + "properties": graph_properties, + }) + gql = f"""CALL db.createVertexLabelByJson('{data}')""" + + gql_check_exist = ( + f"""CALL db.getLabelSchema('VERTEX', '{graph_elem_type.value}')""" + ) + else: # edge + + def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]: + """Define the edge direction. + + `include` edge: document -> chunk, chunk -> entity + `next` edge: chunk -> chunk + `relation` edge: entity -> entity + """ + if graph_elem_type.is_vertex(): + raise ValueError("The graph element type must be an edge.") + if graph_elem_type == GraphElemType.INCLUDE: + return [ + [GraphElemType.DOCUMENT.value, GraphElemType.CHUNK.value], + [GraphElemType.CHUNK.value, GraphElemType.ENTITY.value], + [GraphElemType.CHUNK.value, GraphElemType.CHUNK.value], + ] + elif graph_elem_type == GraphElemType.NEXT: + return [[GraphElemType.CHUNK.value, GraphElemType.CHUNK.value]] + elif graph_elem_type == GraphElemType.RELATION: + return [[GraphElemType.ENTITY.value, GraphElemType.ENTITY.value]] + else: + raise ValueError("Invalid graph element type.") + + data = json.dumps({ + "label": graph_elem_type.value, + "type": "EDGE", + "constraints": edge_direction(graph_elem_type), + "properties": graph_properties, + }) + gql = f"""CALL db.createEdgeLabelByJson('{data}')""" + + gql_check_exist = ( + f"""CALL db.getLabelSchema('EDGE', '{graph_elem_type.value}')""" + ) + + # Make sure the graph label is identical + try: + self.graph_store.conn.run( + gql_check_exist + ) # if not exist, qurying raises an exception + except Exception: + self.graph_store.conn.run(gql) # create the graph label + return + + logger.info(f"Graph label {graph_elem_type.value} already exists.") + + def truncate(self): + """Truncate Graph.""" + gql = "MATCH (n) DELETE n" + self.graph_store.conn.run(gql) + + def check_label(self, graph_elem_type: GraphElemType) -> bool: + """Check if the label exists in the graph. + + Args: + graph_elem_type (GraphElemType): The type of the graph element. + + Returns: + True if the label exists in the specified graph element type, otherwise + False. + """ + vertex_tables, edge_tables = self.graph_store.conn.get_table_names() + + if graph_elem_type.is_vertex(): + return graph_elem_type in vertex_tables + else: + return graph_elem_type in edge_tables + + def explore( + self, + subs: List[str], + direct: Direction = Direction.BOTH, + depth: Optional[int] = None, + limit: Optional[int] = None, + search_scope: Optional[ + Literal["knowledge_graph", "document_graph"] + ] = "knowledge_graph", + ) -> MemoryGraph: + """Explore the graph from given subjects up to a depth.""" + if not subs: + return MemoryGraph() + + if depth is None or depth < 0 or depth > self.MAX_HIERARCHY_LEVEL: + # TODO: to be discussed, be none or MAX_HIERARCHY_LEVEL + # depth_string = ".." 
+            depth = self.MAX_HIERARCHY_LEVEL
+        depth_string = f"1..{depth}"
+
+        if limit is None:
+            limit_string = ""
+        else:
+            limit_string = f"LIMIT {limit}"
+
+        if search_scope == "knowledge_graph":
+            if direct.name == "OUT":
+                rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]->"
+            elif direct.name == "IN":
+                rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
+            else:
+                rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
+            query = (
+                f"MATCH p=(n:{GraphElemType.ENTITY.value})"
+                f"{rel}(m:{GraphElemType.ENTITY.value}) "
+                f"WHERE n.id IN {subs} RETURN p {limit_string}"
+            )
+            return self.query(query)
+        else:
+            graph = MemoryGraph()
+
+            for sub in subs:
+                query = (
+                    f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
+                    f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]-"
+                    f"(m:{GraphElemType.CHUNK.value}) WHERE m.content CONTAINS '{sub}' "
+                    f"RETURN p {limit_string}"
+                )  # match chunks whose content contains the subject
+                result = self.query(query)
+                for vertex in result.vertices():
+                    graph.upsert_vertex(vertex)
+                for edge in result.edges():
+                    graph.append_edge(edge)
+
+            return graph
+
+    def query(self, query: str, **kwargs) -> MemoryGraph:
+        """Execute a query on graph.
+
+        Keyword Args:
+            white_list (List[str]): The properties to keep in the result;
+                properties outside the white list are filtered out.
+        """
+        query_result = self.graph_store.conn.run(query=query)
+        white_list: List[str] = kwargs.get(
+            "white_list",
+            [
+                "id",
+                "name",
+                "description",
+                "_document_id",
+                "_chunk_id",
+                "_community_id",
+            ],
+        )
+        vertices, edges = self._get_nodes_edges_from_queried_data(
+            query_result, white_list
+        )
+        mg = MemoryGraph()
+        for vertex in vertices:
+            mg.upsert_vertex(vertex)
+        for edge in edges:
+            mg.append_edge(edge)
+        return mg
+
+    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
+        """Execute a stream query."""
+        from neo4j import graph
+
+        async for record in self.graph_store.conn.run_stream(query):
+            mg = MemoryGraph()
+            for key in record.keys():
+                value = record[key]
+                if isinstance(value, graph.Node):
+                    node_id = value._properties["id"]
+                    description = value._properties["description"]
+                    vertex = Vertex(vid=node_id, name=node_id, description=description)
+                    mg.upsert_vertex(vertex)
+                elif isinstance(value, graph.Relationship):
+                    edge_nodes = value.nodes
+                    prop_id = value._properties["id"]
+                    assert edge_nodes and edge_nodes[0] and edge_nodes[1]
+                    src_id = edge_nodes[0]._properties["id"]
+                    dst_id = edge_nodes[1]._properties["id"]
+                    description = value._properties["description"]
+                    edge = Edge(
+                        sid=src_id, tid=dst_id, name=prop_id, description=description
+                    )
+                    mg.append_edge(edge)
+                elif isinstance(value, graph.Path):
+                    nodes = list(record["p"].nodes)
+                    rels = list(record["p"].relationships)
+                    formatted_path = []
+                    for i in range(len(nodes)):
+                        formatted_path.append({
+                            "id": nodes[i]._properties["id"],
+                            "description": nodes[i]._properties["description"],
+                        })
+                        if i < len(rels):
+                            formatted_path.append({
+                                "id": rels[i]._properties["id"],
+                                "description": rels[i]._properties["description"],
+                            })
+                    for i in range(0, len(formatted_path), 2):
+                        mg.upsert_vertex(
+                            Vertex(
+                                vid=formatted_path[i]["id"],
+                                name=formatted_path[i]["id"],
+                                description=formatted_path[i]["description"],
+                            )
+                        )
+                        if i + 2 < len(formatted_path):
+                            mg.append_edge(
+                                Edge(
+                                    sid=formatted_path[i]["id"],
+                                    tid=formatted_path[i + 2]["id"],
+                                    name=formatted_path[i + 1]["id"],
+                                    description=formatted_path[i + 1]["description"],
+                                )
+                            )
+                else:
+                    vertex =
Vertex( + vid="json_node", name="json_node", description=value + ) + mg.upsert_vertex(vertex) + yield mg + + def _get_nodes_edges_from_queried_data( + self, + data: List[Dict[str, Any]], + white_prop_list: List[str], + ) -> Tuple[List[Vertex], List[Edge]]: + """Format the query data. + + Args: + data: The data to be formatted. + white_prop_list: The white list of properties. + + Returns: + Tuple[List[Vertex], List[Edge]]: The formatted vertices and edges. + """ + vertex_list: List[Vertex] = [] + edge_list: List[Edge] = [] + + # Remove id, src_id, dst_id and name from the white list + # to avoid duplication in the initialisation of the vertex and edge + _white_list = [ + prop + for prop in white_prop_list + if prop not in ["id", "src_id", "dst_id", "name"] + ] + + from neo4j import graph + + def filter_properties( + properties: dict[str, Any], white_list: List[str] + ) -> Dict[str, Any]: + """Filter the properties. + + It will remove the properties that are not in the white list. + The expected propertities are: + entity_properties = ["id", "name", "description", "_document_id", + "_chunk_id", "_community_id"] + edge_properties = ["id", "name", "description", "_chunk_id"] + """ + return { + key: value + for key, value in properties.items() + if (not key.startswith("_") and key not in ["id", "name"]) + or key in white_list + } + + # Parse the data to nodes and relationships + for record in data: + for value in record.values(): + if isinstance(value, graph.Node): + assert value._properties.get("id") + vertex = Vertex( + vid=value._properties.get("id", ""), + name=value._properties.get("name"), + **filter_properties(value._properties, _white_list), + ) + if vertex not in vertex_list: + # TODO: Do we really need to check it every time? + vertex_list.append(vertex) + elif isinstance(value, graph.Relationship): + for node in value.nodes: # num of nodes is 2 + assert node and node._properties + vertex = Vertex( + vid=node._properties.get("id", ""), + name=node._properties.get("name"), + **filter_properties(node._properties, _white_list), + ) + if vertex not in vertex_list: + vertex_list.append(vertex) + + assert value.nodes and value.nodes[0] and value.nodes[1] + edge = Edge( + sid=value.nodes[0]._properties.get("id", ""), + tid=value.nodes[1]._properties.get("id", ""), + name=value._properties.get("name", ""), + **filter_properties(value._properties, _white_list), + ) + if edge not in edge_list: + edge_list.append(edge) + elif isinstance(value, graph.Path): + for rel in value.relationships: + for node in rel.nodes: # num of nodes is 2 + assert node and node._properties + vertex = Vertex( + vid=node._properties.get("id", ""), + name=node._properties.get("name"), + **filter_properties(node._properties, _white_list), + ) + if vertex not in vertex_list: + vertex_list.append(vertex) + + assert rel.nodes and rel.nodes[0] and rel.nodes[1] + edge = Edge( + sid=rel.nodes[0]._properties.get("id", ""), + tid=rel.nodes[1]._properties.get("id", ""), + name=rel._properties.get("name", ""), + **filter_properties(rel._properties, _white_list), + ) + if edge not in edge_list: + edge_list.append(edge) + + else: # json_node + vertex = Vertex( + vid="json_node", + name="json_node", + **filter_properties({"description": value}, _white_list), + ) + if vertex not in vertex_list: + vertex_list.append(vertex) + return vertex_list, edge_list + + def _parser(self, entity_list: List[Dict[str, Any]]) -> str: + """Parse entities to string.""" + formatted_nodes = [ + "{" + + ", ".join( + f'{k}: "{v}"' if isinstance(v, str) 
else f"{k}: {v}" + for k, v in node.items() + ) + + "}" + for node in entity_list + ] + return f"""{", ".join(formatted_nodes)}""" diff --git a/dbgpt/storage/knowledge_graph/community_summary.py b/dbgpt/storage/knowledge_graph/community_summary.py index a5bf272ac..cab298a00 100644 --- a/dbgpt/storage/knowledge_graph/community_summary.py +++ b/dbgpt/storage/knowledge_graph/community_summary.py @@ -2,14 +2,16 @@ import logging import os +import uuid from typing import List, Optional from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.core import Chunk from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer from dbgpt.rag.transformer.graph_extractor import GraphExtractor +from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore -from dbgpt.storage.knowledge_graph.community.factory import CommunityStoreAdapterFactory +from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory from dbgpt.storage.knowledge_graph.knowledge_graph import ( BuiltinKnowledgeGraph, BuiltinKnowledgeGraphConfig, @@ -27,7 +29,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig): model_config = ConfigDict(arbitrary_types_allowed=True) vector_store_type: str = Field( - default="Chroma", description="The type of vector store." + default="Chroma", + description="The type of vector store.", ) user: Optional[str] = Field( default=None, @@ -36,7 +39,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig): password: Optional[str] = Field( default=None, description=( - "The password of vector store, if not set, will use the default password." + "The password of vector store, " + "if not set, will use the default password." 
), ) extract_topk: int = Field( @@ -120,7 +124,7 @@ def community_store_configure(name: str, cfg: VectorStoreConfig): cfg.score_threshold = self._community_score_threshold self._community_store = CommunityStore( - CommunityStoreAdapterFactory.create(self._graph_store), + GraphStoreAdapterFactory.create(self._graph_store), CommunitySummarizer(self._llm_client, self._model_name), VectorStoreFactory.create( self._vector_store_type, @@ -135,21 +139,165 @@ def get_config(self) -> BuiltinKnowledgeGraphConfig: async def aload_document(self, chunks: List[Chunk]) -> List[str]: """Extract and persist graph.""" - # todo add doc node - for chunk in chunks: - # todo add chunk node - # todo add relation doc-chunk + data_list = self._parse_chunks(chunks) # parse the chunks by def _lod_doc_graph + graph_of_all = MemoryGraph() - # extract graphs and save - graphs = await self._graph_extractor.extract(chunk.content) - for graph in graphs: - self._graph_store.insert_graph(graph) + # Support graph search by the document and the chunks + if self._graph_store.get_config().enable_document_graph: + doc_vid = str(uuid.uuid4()) + doc_name = os.path.basename(chunks[0].metadata["source"] or "Text_Node") + for chunk_index, chunk in enumerate(data_list): + if chunk["parent_id"] != "document": + # chunk -> include -> chunk + graph_of_all.upsert_vertex_and_edge( + src_vid=chunk["parent_id"], + src_name=chunk["parent_title"], + src_props={ + "vertex_type": GraphElemType.CHUNK.value, + "content": chunk["content"], + }, + dst_vid=chunk["id"], + dst_name=chunk["title"], + dst_props={ + "vertex_type": GraphElemType.CHUNK.value, + "content": chunk["content"], + }, + edge_name=GraphElemType.INCLUDE.value, + edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value, + ) + else: + # document -> include -> chunk + graph_of_all.upsert_vertex_and_edge( + src_vid=doc_vid, + src_name=doc_name, + src_props={ + "vertex_type": GraphElemType.DOCUMENT.value, + "content": "", + }, + dst_vid=chunk["id"], + dst_name=chunk["title"], + dst_props={ + "vertex_type": GraphElemType.CHUNK.value, + "content": chunk["content"], + }, + edge_name=GraphElemType.INCLUDE.value, + edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value, + ) + + # chunk -> next -> chunk + if chunk_index >= 1: + graph_of_all.upsert_vertex_and_edge( + src_vid=data_list[chunk_index - 1]["id"], + src_name=data_list[chunk_index - 1]["title"], + src_props={ + "vertex_type": GraphElemType.CHUNK.value, + "content": data_list[chunk_index - 1]["content"], + }, + dst_vid=chunk["id"], + dst_name=chunk["title"], + dst_props={ + "vertex_type": GraphElemType.CHUNK.value, + "content": chunk["content"], + }, + edge_name=GraphElemType.NEXT.value, + edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value, + ) + + # Support knowledge graph search by the entities and the relationships + if self._graph_store.get_config().enable_triplet_graph: + for chunk_index, chunk in enumerate(data_list): + # TODO: Use asyncio to extract graph to accelerate the process + # (attention to the CAP of the graph db) + + graphs: List[MemoryGraph] = await self._graph_extractor.extract( + chunk["content"] + ) + + for graph in graphs: + graph_of_all.upsert_graph(graph) + + # chunk -> include -> entity + if self._graph_store.get_config().enable_document_graph: + for vertex in graph.vertices(): + graph_of_all.upsert_vertex_and_edge( + src_vid=chunk["id"], + src_name=chunk["title"], + src_props={ + "vertex_type": GraphElemType.CHUNK.value, + "content": chunk["content"], + }, + dst_vid=vertex.vid, + dst_name=vertex.name, + dst_props={ + 
"vertex_type": GraphElemType.ENTITY.value, + "description": vertex.props.get("description", ""), + }, # note: description is only used for the entity + edge_name=GraphElemType.INCLUDE.value, + edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value, + ) + + self._graph_store_apdater.upsert_graph(graph_of_all) + + # use asyncio.gather + # tasks = [self._graph_extractor.extract(chunk.content) for chunk in chunks] + # results = await asyncio.gather(*tasks) + # for result in results: + # self._graph_store_apdater.upsert_graph(result[0]) # build communities and save + await self._community_store.build_communities() return [chunk.chunk_id for chunk in chunks] + def _parse_chunks(slef, chunks: List[Chunk]): + """Parse the chunks by anlyzing the markdown chunks.""" + # TODO: Need to refact. + data = [] + for chunk_index, chunk in enumerate(chunks): + parent = None + directory_keys = list(chunk.metadata.keys())[:-1] + parent_level = directory_keys[-2] if len(directory_keys) > 1 else None + current_level = directory_keys[-1] if directory_keys else "Header0" + + chunk_data = { + "id": chunk.chunk_id, + "title": chunk.metadata.get(current_level, "none_header_chunk"), + "directory_keys": directory_keys, + "level": current_level, + "content": chunk.content, + "parent_id": None, + "parent_title": None, + "type": "chunk", + "chunk_index": chunk_index, + } + + # Find the parent chunk + if parent_level: + for parent_direct in reversed(directory_keys[:-1]): + parent_titile = chunk.metadata.get(parent_direct, None) + for n in range(chunk_index - 1, -1, -1): + metadata = chunks[n].metadata + keys = list(metadata.keys())[:-1] + if ( + metadata + and parent_direct == keys[-1] + and parent_titile == metadata.get(parent_direct) + ): + parent = chunks[n] + chunk_data["parent_id"] = parent.chunk_id + chunk_data["parent_title"] = parent_titile + break + if chunk_index - n > len(directory_keys): + break + if chunk_data["parent_id"]: + break + + if not chunk_data["parent_id"]: + chunk_data["parent_id"] = "document" + data.append(chunk_data) + return data + async def asimilar_search_with_scores( self, text, @@ -158,7 +306,7 @@ async def asimilar_search_with_scores( filters: Optional[MetadataFilters] = None, ) -> List[Chunk]: """Retrieve relevant community summaries.""" - # global search: retrieve relevant community summaries + # Global search: retrieve relevant community summaries communities = await self._community_store.search_communities(text) summaries = [ f"Section {i + 1}:\n{community.summary}" @@ -166,16 +314,53 @@ async def asimilar_search_with_scores( ] context = "\n".join(summaries) if summaries else "" - # local search: extract keywords and explore subgraph - keywords = await self._keyword_extractor.extract(text) - subgraph = self._graph_store.explore(keywords, limit=topk).format() - logger.info(f"Search subgraph from {len(keywords)} keywords") + keywords: List[str] = await self._keyword_extractor.extract(text) - if not summaries and not subgraph: + # Local search: extract keywords and explore subgraph + subgraph = MemoryGraph() + subgraph_for_doc = MemoryGraph() + + enable_triplet_graph = self._graph_store.get_config().enable_triplet_graph + enable_document_graph = self._graph_store.get_config().enable_document_graph + + if enable_triplet_graph: + subgraph: MemoryGraph = self._graph_store_apdater.explore( + subs=keywords, limit=10, search_scope="knowledge_graph" + ) + + if enable_document_graph: + keywords_for_document_graph = keywords + for vertex in subgraph.vertices(): + 
keywords_for_document_graph.append(vertex.name) + + subgraph_for_doc = self._graph_store_apdater.explore( + subs=keywords_for_document_graph, + limit=5, + search_scope="document_graph", + ) + else: + if enable_document_graph: + subgraph_for_doc = self._graph_store_apdater.explore( + subs=keywords, + limit=10, + search_scope="document_graph", + ) + + knowledge_graph_str = subgraph.format() + knowledge_graph_for_doc_str = subgraph_for_doc.format() + + logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}") + + if not (summaries or knowledge_graph_str or knowledge_graph_for_doc_str): return [] # merge search results into context - content = HYBRID_SEARCH_PT_CN.format(context=context, graph=subgraph) + content = HYBRID_SEARCH_PT_CN.format( + context=context, + knowledge_graph=knowledge_graph_str, + knowledge_graph_for_doc=knowledge_graph_for_doc_str, + ) + logger.info(f"Final GraphRAG queried prompt:\n{content}") return [Chunk(content=content)] def truncate(self) -> List[str]: @@ -200,174 +385,179 @@ def delete_vector_name(self, index_name: str): self._graph_extractor.drop() -HYBRID_SEARCH_PT_CN = ( - "## 角色\n" - "你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息," - "准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。" - "\n" - "## 技能\n" - "### 技能 1: 上下文理解\n" - "- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。\n" - "- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。\n" - "- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。" - "### 技能 2: 知识图谱理解\n" - "- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息" - "和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:\n" - "```" - "* 实体信息格式:\n" - "- (实体名)\n" - "- (实体名:实体描述)\n" - "- (实体名:实体属性表)\n" - "- (文本块ID:文档块内容)\n" - "- (目录ID:目录名)\n" - "- (文档ID:文档名称)\n" - "\n" - "* 关系信息的格式:\n" - "- (来源实体名)-[关系名]->(目标实体名)\n" - "- (来源实体名)-[关系名:关系描述]->(目标实体名)\n" - "- (来源实体名)-[关系名:关系属性表]->(目标实体名)\n" - "- (文本块实体)-[包含]->(实体名)\n" - "- (目录ID)-[包含]->(文本块实体)\n" - "- (目录ID)-[包含]->(子目录ID)\n" - "- (文档ID)-[包含]->(文本块实体)\n" - "- (文档ID)-[包含]->(目录ID)\n" - "```" - "- 正确地将关系信息中的实体名/ID与实体信息关联,还原出图结构。" - "- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。\n" - "\n" - "## 约束条件\n" - "- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。\n" - "- 若[知识图谱]没有提供信息,此时应根据[上下文]提供的信息回答问题。" - "- 确保以第三人称书写,从客观角度结合[上下文]和[知识图谱]表达的信息回答问题。\n" - "- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。\n" - "- 避免使用停用词和过于常见的词汇。\n" - "\n" - "## 参考案例\n" - "```\n" - "[上下文]:\n" - "Section 1:\n" - "菲尔・贾伯的大儿子叫雅各布・贾伯。\n" - "Section 2:\n" - "菲尔・贾伯的小儿子叫比尔・贾伯。\n" - "[知识图谱]:\n" - "Entities:\n" - "(菲尔・贾伯#菲尔兹咖啡创始人)\n" - "(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)\n" - "(雅各布・贾伯#菲尔・贾伯的儿子)\n" - "(美国多地#菲尔兹咖啡的扩展地区)\n" - "\n" - "Relationships:\n" - "(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)\n" - "(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)\n" - "(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)\n" - "(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)\n" - "(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)\n" - "```\n" - "\n" - "----\n" - "\n" - "接下来的[上下文]和[知识图谱]的信息,可以帮助你回答更好地用户的问题。\n" - "\n" - "[上下文]:\n" - "{context}\n" - "\n" - "[知识图谱]:\n" - "{graph}\n" - "\n" -) +HYBRID_SEARCH_PT_CN = """## 角色 +你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息, +准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。 -HYBRID_SEARCH_PT_EN = ( - "## Role\n" - "You excel at combining the information provided in the [Context] with " - "information from the [KnowledgeGraph] to accurately and appropriately " - "answer user questions, ensuring that you do not output information " - "unrelated to the context and knowledge graph.\n" - "\n" - "## Skills\n" - "### Skill 1: Context Understanding\n" - "- Accurately understand the information provided in the [Context], " - "which may be divided into several sections.\n" - "- Each section in the context will start with [Section] " - "and may be numbered as needed.\n" - "- 
The context provides a summary description most relevant to the user’s " - "question, and it should be used wisely." - "### Skill 2: Knowledge Graph Understanding\n" - "- Accurately identify entity information in the [Entities:] section and " - "relationship information in the [Relationships:] section " - "of the [KnowledgeGraph]. The general format for entity " - "and relationship information is:\n" - "```" - "* Entity Information Format:\n" - "- (entity_name)\n" - "- (entity_name: entity_description)\n" - "- (entity_name: entity_property_map)\n" - "- (chunk_id: chunk_content)\n" - "- (catalog_id: catalog_name)\n" - "- (document_id: document_name)\n" - "\n" - "* Relationship Information Format:\n" - "- (source_entity_name)-[relationship_name]->(target_entity_name)\n" - "- (source_entity_name)-[relationship_name: relationship_description]->" - "(target_entity_name)\n" - "- (source_entity_name)-[relationship_name: relationship_property_map]->" - "(target_entity_name)\n" - "- (chunk_id)-[Contains]->(entity_name)\n" - "- (catalog_id)-[Contains]->(chunk_id)\n" - "- (catalog_id)-[Contains]->(sub_catalog_id)\n" - "- (document_id)-[Contains]->(chunk_id)\n" - "- (document_id)-[Contains]->(catalog_id)\n" - "```" - "- Correctly associate entity names/IDs in the relationship information " - "with entity information to restore the graph structure." - "- Use the information expressed by the graph structure as detailed " - "context for the user's query to assist in generating better answers.\n" - "\n" - "## Constraints\n" - "- Don't describe your thought process in the answer, provide the answer " - "to the user's question directly without generating irrelevant information." - "- If the [KnowledgeGraph] does not provide information, you should answer " - "the question based on the information provided in the [Context]." 
- "- Ensure to write in the third person, responding to questions from " - "an objective perspective based on the information combined from the " - "[Context] and the [KnowledgeGraph].\n" - "- If the provided information is contradictory, resolve the " - "contradictions and provide a single, coherent description.\n" - "- Avoid using stop words and overly common vocabulary.\n" - "\n" - "## Reference Example\n" - "```\n" - "[Context]:\n" - "Section 1:\n" - "Phil Schiller's eldest son is Jacob Schiller.\n" - "Section 2:\n" - "Phil Schiller's youngest son is Bill Schiller.\n" - "[KnowledgeGraph]:\n" - "Entities:\n" - "(Phil Jaber#Founder of Philz Coffee)\n" - "(Philz Coffee#Coffee brand founded in Berkeley, California)\n" - "(Jacob Jaber#Son of Phil Jaber)\n" - "(Multiple locations in the USA#Expansion regions of Philz Coffee)\n" - "\n" - "Relationships:\n" - "(Phil Jaber#Created#Philz Coffee" - "#Founded in Berkeley, California in 1978)\n" - "(Philz Coffee#Located in#Berkeley, California" - "#Founding location of Philz Coffee)\n" - "(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)\n" - "(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)\n" - "(Philz Coffee#Expanded to#Multiple locations in the USA" - "#Expansion regions of Philz Coffee)\n" - "```\n" - "\n" - "----\n" - "\n" - "The following information from the [Context] and [KnowledgeGraph] can " - "help you better answer user questions.\n" - "\n" - "[Context]:\n" - "{context}\n" - "\n" - "[KnowledgeGraph]:\n" - "{graph}\n" - "\n" -) +## 技能 +### 技能 1: 上下文理解 +- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。 +- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。 +- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。 +### 技能 2: 知识图谱理解 +- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为: +``` +* 实体信息格式: +- (实体名) +- (实体名:实体描述) +- (实体名:实体属性表) +- (文本块ID:文档块内容) +- (目录ID:目录名) +- (文档ID:文档名称) + +* 关系信息的格式: +- (来源实体名)-[关系名]->(目标实体名) +- (来源实体名)-[关系名:关系描述]->(目标实体名) +- (来源实体名)-[关系名:关系属性表]->(目标实体名) +- (文本块实体)-[包含]->(实体名) +- (目录ID)-[包含]->(文本块实体) +- (目录ID)-[包含]->(子目录ID) +- (文档ID)-[包含]->(文本块实体) +- (文档ID)-[包含]->(目录ID) +``` +- 正确地将关系信息中的实体名/ID与实体信息关联,还原出图结构。 +- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。 + + +## 约束条件 +- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。 +- 若[知识图谱]或者[知识库原文]没有提供信息,此时应根据[上下文]提供的信息回答问题。 +- 确保以第三人称书写,从客观角度结合[上下文]、[知识图谱]和[知识库原文]表达的信息回答问题。 +- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。 +- 避免使用停用词和过于常见的词汇。 + +## 参考案例 +``` +[上下文]: +Section 1: +菲尔・贾伯的大儿子叫雅各布・贾伯。 +Section 2: +菲尔・贾伯的小儿子叫比尔・贾伯。 + +[知识图谱]: +Entities: +(菲尔・贾伯#菲尔兹咖啡创始人) +(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌) +(雅各布・贾伯#菲尔・贾伯的儿子) +(美国多地#菲尔兹咖啡的扩展地区) + +Relationships: +(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立) +(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点) +(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子) +(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官) +(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围) + +[知识库原文]: +... +``` + +---- + +接下来的[上下文]、[知识图谱]和[知识库原文]的信息,可以帮助你回答更好地用户的问题。 + +[上下文]: +{context} + +[知识图谱]: +{knowledge_graph} + +[知识库原文] +{knowledge_graph_for_doc} +""" # noqa: E501 + +HYBRID_SEARCH_PT_EN = """## Role +You excel at combining the information provided in the [Context] with +information from the [KnowledgeGraph] to accurately and appropriately +answer user questions, ensuring that you do not output information +unrelated to the context and knowledge graph. + +## Skills +### Skill 1: Context Understanding +- Accurately understand the information provided in the [Context], +which may be divided into several sections. +- Each section in the context will start with [Section] +and may be numbered as needed. 
+- The context provides a summary description most relevant to the user's +question, and it should be used wisely. +### Skill 2: Knowledge Graph Understanding +- Accurately identify entity information in the [Entities:] section and +relationship information in the [Relationships:] section +of the [KnowledgeGraph]. The general format for entity +and relationship information is: +``` +* Entity Information Format: +- (entity_name) +- (entity_name: entity_description) +- (entity_name: entity_property_map) +- (chunk_id: chunk_content) +- (catalog_id: catalog_name) +- (document_id: document_name) + +* Relationship Information Format: +- (source_entity_name)-[relationship_name]->(target_entity_name) +- (source_entity_name)-[relationship_name: relationship_description]->(target_entity_name) +- (source_entity_name)-[relationship_name: relationship_property_map]->(target_entity_name) +- (chunk_id)-[Contains]->(entity_name) +- (catalog_id)-[Contains]->(chunk_id) +- (catalog_id)-[Contains]->(sub_catalog_id) +- (document_id)-[Contains]->(chunk_id) +- (document_id)-[Contains]->(catalog_id) +``` +- Correctly associate entity names/IDs in the relationship information +with entity information to restore the graph structure. +- Use the information expressed by the graph structure as detailed +context for the user's query to assist in generating better answers. + +## Constraints +- Don't describe your thought process in the answer, provide the answer +to the user's question directly without generating irrelevant information. +- If the [KnowledgeGraph] or [Knowledge base original text] does not provide information, you should answer +the question based on the information provided in the [Context]. +- Ensure to write in the third person, responding to questions from +an objective perspective based on the information combined from the +[Context], the [KnowledgeGraph] and the [Knowledge base original text]. +- If the provided information is contradictory, resolve the +contradictions and provide a single, coherent description. +- Avoid using stop words and overly common vocabulary. + +## Reference Example +``` +[Context]: +Section 1: +Phil Schiller's eldest son is Jacob Schiller. +Section 2: +Phil Schiller's youngest son is Bill Schiller. + +[KnowledgeGraph]: +Entities: +(Phil Jaber#Founder of Philz Coffee) +(Philz Coffee#Coffee brand founded in Berkeley, California) +(Jacob Jaber#Son of Phil Jaber) +(Multiple locations in the USA#Expansion regions of Philz Coffee) + +Relationships: +(Phil Jaber#Created#Philz Coffee#Founded in Berkeley, California in 1978) +(Philz Coffee#Located in#Berkeley, California#Founding location of Philz Coffee) +(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber) +(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005) +(Philz Coffee#Expanded to#Multiple locations in the USA#Expansion regions of Philz Coffee) + +[Knowledge base original text] +... +``` + +---- + +The following information from the [Context], [KnowledgeGraph] and [Knowledge base original text] +can help you better answer user questions. 
+ +[Context]: +{context} + +[KnowledgeGraph]: +{knowledge_graph} + +[Knowledge base original text] +{knowledge_graph_for_doc} +""" # noqa: E501 diff --git a/dbgpt/storage/knowledge_graph/knowledge_graph.py b/dbgpt/storage/knowledge_graph/knowledge_graph.py index 066e2667d..10d9134aa 100644 --- a/dbgpt/storage/knowledge_graph/knowledge_graph.py +++ b/dbgpt/storage/knowledge_graph/knowledge_graph.py @@ -1,4 +1,5 @@ """Knowledge graph class.""" + import asyncio import logging import os @@ -12,6 +13,8 @@ from dbgpt.storage.graph_store.factory import GraphStoreFactory from dbgpt.storage.graph_store.graph import Graph from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig +from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter +from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory from dbgpt.storage.vector_store.filters import MetadataFilters logger = logging.getLogger(__name__) @@ -46,9 +49,10 @@ def __init__(self, config: BuiltinKnowledgeGraphConfig): self._model_name = config.model_name self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name) self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name) - self._graph_store = self.__init_graph_store(config) + self._graph_store: GraphStoreBase = self.__init_graph_store(config) + self._graph_store_apdater: GraphStoreAdapter = self.__init_graph_store_adapter() - def __init_graph_store(self, config) -> GraphStoreBase: + def __init_graph_store(self, config: BuiltinKnowledgeGraphConfig) -> GraphStoreBase: def configure(cfg: GraphStoreConfig): cfg.name = config.name cfg.embedding_fn = config.embedding_fn @@ -56,6 +60,9 @@ def configure(cfg: GraphStoreConfig): graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.graph_store_type return GraphStoreFactory.create(graph_store_type, configure) + def __init_graph_store_adapter(self): + return GraphStoreAdapterFactory.create(self._graph_store) + def get_config(self) -> BuiltinKnowledgeGraphConfig: """Get the knowledge graph config.""" return self._config @@ -63,10 +70,10 @@ def get_config(self) -> BuiltinKnowledgeGraphConfig: def load_document(self, chunks: List[Chunk]) -> List[str]: """Extract and persist triplets to graph store.""" - async def process_chunk(chunk): + async def process_chunk(chunk: Chunk): triplets = await self._triplet_extractor.extract(chunk.content) for triplet in triplets: - self._graph_store.insert_triplet(*triplet) + self._graph_store_apdater.insert_triplet(*triplet) logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}") return chunk.chunk_id @@ -89,7 +96,7 @@ async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignor for chunk in chunks: triplets = await self._triplet_extractor.extract(chunk.content) for triplet in triplets: - self._graph_store.insert_triplet(*triplet) + self._graph_store_apdater.insert_triplet(*triplet) logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}") return [chunk.chunk_id for chunk in chunks] @@ -116,7 +123,8 @@ async def asimilar_search_with_scores( # extract keywords and explore graph store keywords = await self._keyword_extractor.extract(text) - subgraph = self._graph_store.explore(keywords, limit=topk).format() + subgraph = self._graph_store_apdater.explore(keywords, limit=topk).format() + logger.info(f"Search subgraph from {len(keywords)} keywords") if not subgraph: @@ -147,12 +155,12 @@ async def asimilar_search_with_scores( def query_graph(self, limit: Optional[int] 
= None) -> Graph: """Query graph.""" - return self._graph_store.get_full_graph(limit) + return self._graph_store_apdater.get_full_graph(limit) def truncate(self) -> List[str]: """Truncate knowledge graph.""" logger.info(f"Truncate graph {self._config.name}") - self._graph_store.truncate() + self._graph_store_apdater.truncate() logger.info("Truncate keyword extractor") self._keyword_extractor.truncate() @@ -165,10 +173,15 @@ def truncate(self) -> List[str]: def delete_vector_name(self, index_name: str): """Delete vector name.""" logger.info(f"Drop graph {index_name}") - self._graph_store.drop() + self._graph_store_apdater.drop() logger.info("Drop keyword extractor") self._keyword_extractor.drop() logger.info("Drop triplet extractor") self._triplet_extractor.drop() + + def delete_by_ids(self, ids: str) -> List[str]: + """Delete by ids.""" + self._graph_store_apdater.delete_document(chunk_ids=ids) + return [] diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py index 663ee779d..427bba685 100644 --- a/dbgpt/storage/metadata/db_manager.py +++ b/dbgpt/storage/metadata/db_manager.py @@ -1,4 +1,5 @@ """The database manager.""" + from __future__ import annotations import logging diff --git a/docs/docs/cookbook/rag/graph_rag_app_develop.md b/docs/docs/cookbook/rag/graph_rag_app_develop.md index a6b72273f..b0b00bf8f 100644 --- a/docs/docs/cookbook/rag/graph_rag_app_develop.md +++ b/docs/docs/cookbook/rag/graph_rag_app_develop.md @@ -10,7 +10,7 @@ You can refer to the python example file `DB-GPT/examples/rag/graph_rag_example. First, you need to install the `dbgpt` library. ```bash -pip install "dbgpt[rag]>=0.6.0" +pip install "dbgpt[graph_rag]>=0.6.1" ```` ### Prepare Graph Database @@ -112,7 +112,9 @@ TUGRAPH_HOST=127.0.0.1 TUGRAPH_PORT=7687 TUGRAPH_USERNAME=admin TUGRAPH_PASSWORD=73@TuGraph -GRAPH_COMMUNITY_SUMMARY_ENABLED=True +ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary +ENABLE_TRIPLET_GRAPH=True # enable the graph search for the triplets +ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks ``` @@ -250,23 +252,23 @@ Performance testing is based on the `gpt-4o-mini` model. #### Indexing Performance -| | DB-GPT | GraphRAG(microsoft) | -|----------|----------|------------------------| -| Document Tokens | 42631 | 42631 | -| Graph Size | 808 nodes, 1170 edges | 779 nodes, 967 edges | -| Prompt Tokens | 452614 | 744990 | -| Completion Tokens | 48325 | 227230 | -| Total Tokens | 500939 | 972220 | +| | DB-GPT | GraphRAG(microsoft) | +| ----------------- | --------------------- | -------------------- | +| Document Tokens | 42631 | 42631 | +| Graph Size | 808 nodes, 1170 edges | 779 nodes, 967 edges | +| Prompt Tokens | 452614 | 744990 | +| Completion Tokens | 48325 | 227230 | +| Total Tokens | 500939 | 972220 | #### Querying Performance **Global Search** -| | DB-GPT | GraphRAG(microsoft) | -|----------|----------|------------------------| -| Time | 8s | 40s | -| Tokens| 7432 | 63317 | +| | DB-GPT | GraphRAG(microsoft) | +| ------ | ------ | ------------------- | +| Time | 8s | 40s | +| Tokens | 7432 | 63317 | **Question** ``` @@ -304,10 +306,10 @@ Performance testing is based on the `gpt-4o-mini` model. 
**Local Search**

-| | DB-GPT | GraphRAG(microsoft) |
-|----------|----------|------------------------|
-| Time | 15s | 15s |
-| Tokens| 9230 | 11619 |
+| | DB-GPT | GraphRAG(microsoft) |
+| ------ | ------ | ------------------- |
+| Time | 15s | 15s |
+| Tokens | 9230 | 11619 |

**Question**

@@ -352,3 +354,28 @@ DB-GPT社区与TuGraph社区的比较
 总结
 总体而言,DB-GPT社区和TuGraph社区在社区贡献、生态系统和开发者参与等方面各具特色。DB-GPT社区更侧重于AI应用的多样性和组织间的合作,而TuGraph社区则专注于图数据的高效管理和分析。两者的共同点在于都强调了开源和社区合作的重要性,推动了各自领域的技术进步和应用发展。
 ```
+
+### Latest Updates

+In version 0.6.1 of DB-GPT, we have added a new feature:
+- Retrieval of triplets combined with **retrieval of the document structure**
+
+We have expanded the scope of the 'Graph' definition in GraphRAG:
+```
+Knowledge Graph = Triplets Graph + Document Structure Graph
+```
+
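+To make the expanded definition concrete, here is a minimal sketch (the IDs, file name, and contents are invented for illustration) that assembles both sub-graphs in one `MemoryGraph`, using the `upsert_vertex_and_edge` API that `aload_document` calls in this patch:
+
+```python
+from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph
+
+graph = MemoryGraph()
+# Document structure graph: document -[include]-> chunk (hypothetical IDs)
+graph.upsert_vertex_and_edge(
+    src_vid="doc-1",
+    src_name="example.md",
+    src_props={"vertex_type": GraphElemType.DOCUMENT.value, "content": ""},
+    dst_vid="chunk-1",
+    dst_name="Introduction",
+    dst_props={"vertex_type": GraphElemType.CHUNK.value, "content": "TuGraph is ..."},
+    edge_name=GraphElemType.INCLUDE.value,
+    edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
+)
+# Bridge to the triplets graph: chunk -[include]-> entity
+graph.upsert_vertex_and_edge(
+    src_vid="chunk-1",
+    src_name="Introduction",
+    src_props={"vertex_type": GraphElemType.CHUNK.value, "content": "TuGraph is ..."},
+    dst_vid="TuGraph",
+    dst_name="TuGraph",
+    dst_props={"vertex_type": GraphElemType.ENTITY.value, "description": "a graph database"},
+    edge_name=GraphElemType.INCLUDE.value,
+    edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
+)
+print(graph.format())
+```
+
+Because entities stay linked to the chunks they were extracted from, a retrieval hit on either sub-graph can be expanded into the other.
+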

+![Knowledge Graph = Triplets Graph + Document Structure Graph](./image_graphrag_0_6_1.png)
+
+
+How?
+
+We decompose files in standard formats (Markdown currently has the best support) into a directed graph based on their hierarchy and layout information, and store it in a graph database. In this graph:
+- Each node represents a chunk of the file
+- Each edge represents the structural relationship between chunks in the original document
+- The document structure graph is then merged into the triplets graph
+
+What is next?
+
+We aim to construct a more comprehensive graph that covers more information, to support more sophisticated retrieval algorithms in our GraphRAG.
\ No newline at end of file
diff --git a/docs/docs/cookbook/rag/image_graphrag_0_6_1.png b/docs/docs/cookbook/rag/image_graphrag_0_6_1.png
new file mode 100644
index 000000000..566bf4bd0
Binary files /dev/null and b/docs/docs/cookbook/rag/image_graphrag_0_6_1.png differ
diff --git a/docs/static/img/chat_knowledge/graph_rag/image_graphrag_0_6_1.png b/docs/static/img/chat_knowledge/graph_rag/image_graphrag_0_6_1.png
new file mode 100644
index 000000000..566bf4bd0
Binary files /dev/null and b/docs/static/img/chat_knowledge/graph_rag/image_graphrag_0_6_1.png differ
diff --git a/tests/intetration_tests/datasource/test_conn_tugraph.py b/tests/intetration_tests/datasource/test_conn_tugraph.py
index eafed88a6..f323c48f2 100644
--- a/tests/intetration_tests/datasource/test_conn_tugraph.py
+++ b/tests/intetration_tests/datasource/test_conn_tugraph.py
@@ -1,4 +1,5 @@
 import pytest
+
 from dbgpt.datasource.conn_tugraph import TuGraphConnector

 # Set database connection parameters.
@@ -21,10 +22,10 @@ def connector():

 def test_get_table_names(connector):
     """Test retrieving table names from the graph database."""
-    table_names = connector.get_table_names()
+    vertex_tables, edge_tables = connector.get_table_names()
     # Verify the quantity of vertex and edge tables.
- assert len(table_names["vertex_tables"]) == 5 - assert len(table_names["edge_tables"]) == 8 + assert len(vertex_tables) == 5 + assert len(edge_tables) == 8 def test_get_columns(connector): diff --git a/tests/intetration_tests/graph_store/test_memgraph_store.py b/tests/intetration_tests/graph_store/test_memgraph_store.py index 0bf066f0a..1c2893985 100644 --- a/tests/intetration_tests/graph_store/test_memgraph_store.py +++ b/tests/intetration_tests/graph_store/test_memgraph_store.py @@ -4,6 +4,9 @@ MemoryGraphStore, MemoryGraphStoreConfig, ) +from dbgpt.storage.knowledge_graph.community.memgraph_store_adapter import ( + MemGraphStoreAdapter, +) @pytest.fixture @@ -11,31 +14,37 @@ def graph_store(): yield MemoryGraphStore(MemoryGraphStoreConfig()) -def test_graph_store(graph_store): - graph_store.insert_triplet("A", "0", "A") - graph_store.insert_triplet("A", "1", "A") - graph_store.insert_triplet("A", "2", "B") - graph_store.insert_triplet("B", "3", "C") - graph_store.insert_triplet("B", "4", "D") - graph_store.insert_triplet("C", "5", "D") - graph_store.insert_triplet("B", "6", "E") - graph_store.insert_triplet("F", "7", "E") - graph_store.insert_triplet("E", "8", "F") - - subgraph = graph_store.explore(["A"]) +@pytest.fixture +def graph_store_adapter(graph_store: MemoryGraphStore): + memgraph_store_adapter = MemGraphStoreAdapter(graph_store) + yield memgraph_store_adapter + + +def test_graph_store(graph_store_adapter: MemGraphStoreAdapter): + graph_store_adapter.insert_triplet("A", "0", "A") + graph_store_adapter.insert_triplet("A", "1", "A") + graph_store_adapter.insert_triplet("A", "2", "B") + graph_store_adapter.insert_triplet("B", "3", "C") + graph_store_adapter.insert_triplet("B", "4", "D") + graph_store_adapter.insert_triplet("C", "5", "D") + graph_store_adapter.insert_triplet("B", "6", "E") + graph_store_adapter.insert_triplet("F", "7", "E") + graph_store_adapter.insert_triplet("E", "8", "F") + + subgraph = graph_store_adapter.explore(["A"]) print(f"\n{subgraph.format()}") assert subgraph.edge_count == 9 - graph_store.delete_triplet("A", "0", "A") - graph_store.delete_triplet("B", "4", "D") - subgraph = graph_store.explore(["A"]) + graph_store_adapter.delete_triplet("A", "0", "A") + graph_store_adapter.delete_triplet("B", "4", "D") + subgraph = graph_store_adapter.explore(["A"]) print(f"\n{subgraph.format()}") assert subgraph.edge_count == 7 - triplets = graph_store.get_triplets("B") + triplets = graph_store_adapter.get_triplets("B") print(f"\nTriplets of B: {triplets}") assert len(triplets) == 2 - schema = graph_store.get_schema() + schema = graph_store_adapter.get_schema() print(f"\nSchema: {schema}") assert len(schema) == 86 diff --git a/tests/intetration_tests/graph_store/test_tugraph_store.py b/tests/intetration_tests/graph_store/test_tugraph_store.py index 570869eb5..d02a2ca90 100644 --- a/tests/intetration_tests/graph_store/test_tugraph_store.py +++ b/tests/intetration_tests/graph_store/test_tugraph_store.py @@ -1,43 +1,52 @@ -# test_tugraph_store.py +# test_tugraph_tugraph_store_adapter.py import pytest from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig +from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import ( + TuGraphStoreAdapter, +) @pytest.fixture(scope="module") def store(): - config = TuGraphStoreConfig(name="TestGraph", summary_enabled=False) + config = TuGraphStoreConfig(name="TestGraph", enable_summary=False) store = TuGraphStore(config=config) yield store store.conn.close() -def 
test_insert_and_get_triplets(store): - store.insert_triplet("A", "0", "A") - store.insert_triplet("A", "1", "A") - store.insert_triplet("A", "2", "B") - store.insert_triplet("B", "3", "C") - store.insert_triplet("B", "4", "D") - store.insert_triplet("C", "5", "D") - store.insert_triplet("B", "6", "E") - store.insert_triplet("F", "7", "E") - store.insert_triplet("E", "8", "F") - triplets = store.get_triplets("A") +@pytest.fixture(scope="module") +def tugraph_store_adapter(store: TuGraphStore): + tugraph_store_adapter = TuGraphStoreAdapter(store) + yield tugraph_store_adapter + + +def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter): + tugraph_store_adapter.insert_triplet("A", "0", "A") + tugraph_store_adapter.insert_triplet("A", "1", "A") + tugraph_store_adapter.insert_triplet("A", "2", "B") + tugraph_store_adapter.insert_triplet("B", "3", "C") + tugraph_store_adapter.insert_triplet("B", "4", "D") + tugraph_store_adapter.insert_triplet("C", "5", "D") + tugraph_store_adapter.insert_triplet("B", "6", "E") + tugraph_store_adapter.insert_triplet("F", "7", "E") + tugraph_store_adapter.insert_triplet("E", "8", "F") + triplets = tugraph_store_adapter.get_triplets("A") assert len(triplets) == 2 - triplets = store.get_triplets("B") + triplets = tugraph_store_adapter.get_triplets("B") assert len(triplets) == 3 - triplets = store.get_triplets("C") + triplets = tugraph_store_adapter.get_triplets("C") assert len(triplets) == 1 - triplets = store.get_triplets("D") + triplets = tugraph_store_adapter.get_triplets("D") assert len(triplets) == 0 - triplets = store.get_triplets("E") + triplets = tugraph_store_adapter.get_triplets("E") assert len(triplets) == 1 - triplets = store.get_triplets("F") + triplets = tugraph_store_adapter.get_triplets("F") assert len(triplets) == 1 -def test_query(store): +def test_query(store: TuGraphStore): query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3" result = store.query(query) v_c = result.vertex_count @@ -45,18 +54,18 @@ def test_query(store): assert v_c == 3 and e_c == 3 -def test_explore(store): +def test_explore(tugraph_store_adapter: TuGraphStoreAdapter): subs = ["A", "B"] - result = store.explore(subs, depth=2, fan=None, limit=10) + result = tugraph_store_adapter.explore(subs, depth=2, fan=None, limit=10) v_c = result.vertex_count e_c = result.edge_count assert v_c == 5 and e_c == 5 -def test_delete_triplet(store): +def test_delete_triplet(tugraph_store_adapter: TuGraphStoreAdapter): subj = "A" rel = "0" obj = "B" - store.delete_triplet(subj, rel, obj) - triplets = store.get_triplets(subj) + tugraph_store_adapter.delete_triplet(subj, rel, obj) + triplets = tugraph_store_adapter.get_triplets(subj) assert len(triplets) == 0 diff --git a/tests/intetration_tests/graph_store/test_tugraph_store_with_summary.py b/tests/intetration_tests/graph_store/test_tugraph_store_with_summary.py index 0ca3de588..a1e8b31d0 100644 --- a/tests/intetration_tests/graph_store/test_tugraph_store_with_summary.py +++ b/tests/intetration_tests/graph_store/test_tugraph_store_with_summary.py @@ -1,18 +1,27 @@ import pytest +from dbgpt.storage.graph_store.graph import Edge, MemoryGraph, Vertex from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig -from dbgpt.storage.graph_store.graph import MemoryGraph, Edge, Vertex +from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import ( + TuGraphStoreAdapter, +) @pytest.fixture(scope="module") def store(): - config = TuGraphStoreConfig(name="TestSummaryGraph", summary_enabled=True) + config 
= TuGraphStoreConfig(name="TestSummaryGraph", enable_summary=True) store_instance = TuGraphStore(config=config) yield store_instance store_instance.conn.close() -def test_insert_graph(store): +@pytest.fixture(scope="module") +def graph_store_adapter(store: TuGraphStore): + tugraph_store_adapter = TuGraphStoreAdapter(store) + yield tugraph_store_adapter + + +def test_upsert_graph(tugraph_store_adapter: TuGraphStoreAdapter): graph = MemoryGraph() vertex_list = [ Vertex("A", "A", description="Vertex A", _document_id="Test doc"), @@ -35,22 +44,22 @@ def test_insert_graph(store): graph.upsert_vertex(vertex) for edge in edge_list: graph.append_edge(edge) - store.insert_graph(graph) + tugraph_store_adapter.upsert_graph(graph) -def test_leiden_query(store): +def test_leiden_query(store: TuGraphStore): query = "CALL db.plugin.callPlugin('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)" result = store.query(query) assert result.vertex_count == 1 -def test_query_node_and_edge(store): +def test_query_node_and_edge(store: TuGraphStore): query = 'MATCH (n)-[r]->(m) WHERE n._community_id = "0" RETURN n,r,m' result = store.query(query) assert result.vertex_count == 7 and result.edge_count == 6 -def test_stream_query_path(store): +def test_stream_query_path(store: TuGraphStore): query = 'MATCH p=(n)-[r:relation*2]->(m) WHERE n._community_id = "0" RETURN p' result = store.query(query) for v in result.vertices():