Skip to content

Commit

Permalink
feat: add document structure into GraphRAG (#2033)
Browse files Browse the repository at this point in the history
Co-authored-by: Appointat <[email protected]>
Co-authored-by: tpoisonooo <[email protected]>
Co-authored-by: vritser <[email protected]>
  • Loading branch information
4 people authored Oct 18, 2024
1 parent 811ce63 commit 88e3d12
Show file tree
Hide file tree
Showing 29 changed files with 1,910 additions and 936 deletions.
5 changes: 4 additions & 1 deletion .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,15 @@ EXECUTE_LOCAL_COMMANDS=False
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph
GRAPH_COMMUNITY_SUMMARY_ENABLED=True
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0

ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
ENABLE_TRIPLET_GRAPH=True # enable the graph search for triplets
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data

Expand Down
4 changes: 2 additions & 2 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def __init__(self) -> None:

# Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
self.ENABLE_GRAPH_COMMUNITY_SUMMARY = (
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
)
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
Expand Down
12 changes: 8 additions & 4 deletions dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.request.request import (
ChunkEditRequest,
ChunkQueryRequest,
DocumentQueryRequest,
DocumentRecallTestRequest,
Expand Down Expand Up @@ -650,12 +649,17 @@ def query_graph(self, space_name, limit):
{
"id": node.vid,
"communityId": node.get_prop("_community_id"),
"name": node.vid,
"type": "",
"name": node.name,
"type": node.get_prop("type") or "",
}
)
for edge in graph.edges():
res["edges"].append(
{"source": edge.sid, "target": edge.tid, "name": edge.name, "type": ""}
{
"source": edge.sid,
"target": edge.tid,
"name": edge.name,
"type": edge.get_prop("type") or "",
}
)
return res
40 changes: 23 additions & 17 deletions dbgpt/datasource/conn_tugraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""TuGraph Connector."""

import json
from typing import Dict, Generator, List, cast
from typing import Dict, Generator, List, Tuple, cast

from .base import BaseConnector

Expand All @@ -21,8 +21,7 @@ def __init__(self, driver, graph):
self._session = None

def create_graph(self, graph_name: str) -> None:
"""Create a new graph."""
# run the query to get vertex labels
"""Create a new graph in the database if it doesn't already exist."""
try:
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
Expand All @@ -32,10 +31,10 @@ def create_graph(self, graph_name: str) -> None:
f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)"
)
except Exception as e:
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}")
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e

def delete_graph(self, graph_name: str) -> None:
"""Delete a graph."""
"""Delete a graph in the database if it exists."""
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
exists = any(item["graph_name"] == graph_name for item in graph_list)
Expand All @@ -61,17 +60,20 @@ def from_uri_db(
"`pip install neo4j`"
) from err

def get_table_names(self) -> Dict[str, List[str]]:
def get_table_names(self) -> Tuple[List[str], List[str]]:
"""Get all table names from the TuGraph by Neo4j driver."""
# run the query to get vertex labels
with self._driver.session(database=self._graph) as session:
v_result = session.run("CALL db.vertexLabels()").data()
v_data = [table_name["label"] for table_name in v_result]
# Run the query to get vertex labels
raw_vertex_labels: Dict[str, str] = session.run(
"CALL db.vertexLabels()"
).data()
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]

# Run the query to get edge labels
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
edge_labels = [table_name["label"] for table_name in raw_edge_labels]

# run the query to get edge labels
e_result = session.run("CALL db.edgeLabels()").data()
e_data = [table_name["label"] for table_name in e_result]
return {"vertex_tables": v_data, "edge_tables": e_data}
return vertex_labels, edge_labels

def get_grants(self):
"""Get grants."""
Expand Down Expand Up @@ -100,7 +102,7 @@ def run(self, query: str, fetch: str = "all") -> List:
result = session.run(query)
return list(result)
except Exception as e:
raise Exception(f"Query execution failed: {e}")
raise Exception(f"Query execution failed: {e}\nQuery: {query}") from e

def run_stream(self, query: str) -> Generator:
"""Run GQL."""
Expand All @@ -109,11 +111,15 @@ def run_stream(self, query: str) -> Generator:
yield from result

def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get fields about specified graph.
"""Retrieve the column for a specified vertex or edge table in the graph db.
This function queries the schema of a given table (vertex or edge) and returns
detailed information about its columns (properties).
Args:
table_name (str): table name (graph name)
table_type (str): table type (vertex or edge)
Returns:
columns: List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
Expand Down Expand Up @@ -146,8 +152,8 @@ def get_indexes(self, table_name: str, table_type: str = "vertex") -> List[Dict]
"""Get table indexes about specified table.
Args:
table_name:(str) table name
table_type:(str'vertex' | 'edge'
table_name (str): table name
table_type (str): 'vertex' | 'edge'
Returns:
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
Expand Down
12 changes: 10 additions & 2 deletions dbgpt/rag/transformer/graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]
match = re.match(r"\((.*?)#(.*?)\)", line)
if match:
name, summary = [part.strip() for part in match.groups()]
graph.upsert_vertex(Vertex(name, description=summary))
graph.upsert_vertex(
Vertex(name, description=summary, vertex_type="entity")
)
elif current_section == "Relationships":
match = re.match(r"\((.*?)#(.*?)#(.*?)#(.*?)\)", line)
if match:
Expand All @@ -74,7 +76,13 @@ def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]
]
edge_count += 1
graph.append_edge(
Edge(source, target, name, description=summary)
Edge(
source,
target,
name,
description=summary,
edge_type="relation",
)
)

if limit and edge_count >= limit:
Expand Down
18 changes: 11 additions & 7 deletions dbgpt/rag/transformer/keyword_extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""KeywordExtractor class."""

import logging
from typing import List, Optional

Expand Down Expand Up @@ -39,12 +40,15 @@ def __init__(self, llm_client: LLMClient, model_name: str):
def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]:
keywords = set()

for part in text.split(";"):
for s in part.strip().split(","):
keyword = s.strip()
if keyword:
keywords.add(keyword)
if limit and len(keywords) >= limit:
return list(keywords)
lines = text.replace(":", "\n").split("\n")

for line in lines:
for part in line.split(";"):
for s in part.strip().split(","):
keyword = s.strip()
if keyword:
keywords.add(keyword)
if limit and len(keywords) >= limit:
return list(keywords)

return list(keywords)
2 changes: 1 addition & 1 deletion dbgpt/serve/rag/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(

def __rewrite_index_store_type(self, index_store_type):
# Rewrite Knowledge Graph Type
if CFG.GRAPH_COMMUNITY_SUMMARY_ENABLED:
if CFG.ENABLE_GRAPH_COMMUNITY_SUMMARY:
if index_store_type == "KnowledgeGraph":
return "CommunitySummaryKnowledgeGraph"
return index_store_type
Expand Down
84 changes: 21 additions & 63 deletions dbgpt/storage/graph_store/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Graph store base class."""

import logging
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple
from typing import Optional

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.storage.graph_store.graph import Direction, Graph

logger = logging.getLogger(__name__)

Expand All @@ -23,78 +23,36 @@ class GraphStoreConfig(BaseModel):
default=None,
description="The embedding function of graph store, optional.",
)
summary_enabled: bool = Field(
enable_summary: bool = Field(
default=False,
description="Enable graph community summary or not.",
)
enable_document_graph: bool = Field(
default=True,
description="Enable document graph search or not.",
)
enable_triplet_graph: bool = Field(
default=True,
description="Enable knowledge graph search or not.",
)


class GraphStoreBase(ABC):
"""Graph store base class."""

def __init__(self, config: GraphStoreConfig):
"""Initialize graph store."""
self._config = config
self._conn = None

@abstractmethod
def get_config(self) -> GraphStoreConfig:
"""Get the graph store config."""

@abstractmethod
def get_vertex_type(self) -> str:
"""Get the vertex type."""

@abstractmethod
def get_edge_type(self) -> str:
"""Get the edge type."""

@abstractmethod
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Add triplet."""

@abstractmethod
def insert_graph(self, graph: Graph):
"""Add graph."""

@abstractmethod
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Get triplets."""

@abstractmethod
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete triplet."""
def _escape_quotes(self, text: str) -> str:
"""Escape single and double quotes in a string for queries."""

@abstractmethod
def truncate(self):
"""Truncate Graph."""

@abstractmethod
def drop(self):
"""Drop graph."""

@abstractmethod
def get_schema(self, refresh: bool = False) -> str:
"""Get schema."""

@abstractmethod
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Get full graph."""

@abstractmethod
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> Graph:
"""Explore on graph."""

@abstractmethod
def query(self, query: str, **args) -> Graph:
"""Execute a query."""

def aquery(self, query: str, **args) -> Graph:
"""Async execute a query."""
return self.query(query, **args)

@abstractmethod
def stream_query(self, query: str) -> Generator[Graph, None, None]:
"""Execute stream query."""
# @abstractmethod
# def _paser(self, entities: List[Vertex]) -> str:
# """Parse entities to string."""
1 change: 1 addition & 0 deletions dbgpt/storage/graph_store/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Graph store factory."""

import logging
from typing import Tuple, Type

Expand Down
Loading

0 comments on commit 88e3d12

Please sign in to comment.