
Commit 2bf5782

fixed bug and todo;

KingSkyLi committed Aug 22, 2024
1 parent e24455e, commit 2bf5782

Showing 5 changed files with 17 additions and 29 deletions.
3 changes: 0 additions & 3 deletions dbgpt/datasource/conn_tugraph.py
@@ -23,9 +23,6 @@ def __init__(self, driver, graph):
     def create_graph(self, graph_name: str) -> None:
         """Create a new graph."""
         # run the query to get vertex labels
-
-        if not re.match(r"^[a-zA-Z\u4e00-\u9fff]", graph_name):
-            raise ValueError("Graph name must start with a letter or Chinese character, and cannot begin with a number or special character.")
         with self._driver.session(database="default") as session:
             graph_list = session.run("CALL dbms.graph.listGraphs()").data()
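For reference, the removed guard only inspects the first character of the name. A standalone sketch of its behavior, using only the standard-library re module:

    import re

    def starts_with_letter_or_cjk(graph_name: str) -> bool:
        # Matches when the first character is an ASCII letter or a CJK
        # character in U+4E00..U+9FFF; everything after it is unchecked.
        return bool(re.match(r"^[a-zA-Z\u4e00-\u9fff]", graph_name))

    assert starts_with_letter_or_cjk("my_graph")
    assert starts_with_letter_or_cjk("知识图谱")
    assert not starts_with_letter_or_cjk("1graph")   # leading digit rejected
    assert not starts_with_letter_or_cjk("_graph")   # leading underscore rejected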
1 change: 1 addition & 0 deletions dbgpt/storage/graph_store/community_store.py
@@ -69,6 +69,7 @@ async def __retrieve_community(
 
     async def __summarize_community(self, graph: Graph):
         """Generate summary for a given graph using an LLM."""
+        # todo remove (chunk and doc) vertex
         nodes = "\n".join(
             [
                 f"- {v.vid}: {v.get_prop('description')}"
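__summarize_community renders the graph's vertices as a bullet list before prompting the LLM. A minimal standalone sketch of that formatting step, with a stand-in vertex type (the real Graph and Vertex classes are not shown in this hunk):

    from dataclasses import dataclass, field

    @dataclass
    class FakeVertex:
        # Stand-in for the real Vertex: an id plus a property bag.
        vid: str
        props: dict = field(default_factory=dict)

        def get_prop(self, key):
            return self.props.get(key)

    vertices = [
        FakeVertex("Alice", {"description": "a software engineer"}),
        FakeVertex("Bob", {"description": "Alice's manager"}),
    ]
    nodes = "\n".join([f"- {v.vid}: {v.get_prop('description')}" for v in vertices])
    print(nodes)
    # - Alice: a software engineer
    # - Bob: Alice's manager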
30 changes: 8 additions & 22 deletions dbgpt/storage/graph_store/tugraph_store.py
@@ -131,7 +131,6 @@ def _create_schema(self):
             f"['name',string,false],"
             f"['_document_id',string,true],"
             f"['_community_id',string,true],"
-            f"['_level_id',string,true],"
             f"['_weight',double,true],"
             f"['description',string,true])"
         )
@@ -208,7 +207,6 @@ def parser(node_list):
                 'description': escape_quotes(node.get_prop('description')) or '',
                 '_document_id': '0',
                 '_community_id': '1',
-                '_level_id': '0',
                 '_weight': 10
             })
         node_query = f"""CALL db.upsertVertex("{self._node_label}", [{parser(node_list)}])"""
@@ -296,8 +294,7 @@ def _format_paths(paths):
                 formatted_path.append({
                     "id":nodes[i]._properties["id"],
                     "community_id":nodes[i]._properties.get("_community_id",""),
-                    "description":nodes[i]._properties.get("description",""),
-                    "level_id":nodes[i]._properties.get("_level_id","")
+                    "description":nodes[i]._properties.get("description","")
                 })
                 if i < len(rels):
                     formatted_path.append({
@@ -318,14 +315,11 @@ def _format_query_data(data):
                 if isinstance(value, graph.Node):
                     node_id = value._properties["id"]
                     description = value._properties.get("description","")
-                    level_id = value._properties.get("level_id","")
                     community_id = value._properties.get("community_id","")
                     if not any(existing_node.get("id") == node_id for existing_node in nodes_list):
                         nodes_list.append({
                             "id":node_id,
-                            "description":description,
-                            "level_id":level_id,
-                            "community_id":community_id
+                            "description":description
                         })
                 elif isinstance(value, graph.Relationship):
                     rel_nodes = value.nodes
@@ -348,9 +342,7 @@ def _format_query_data(data):
                     if not any(existing_node.get("id") == formatted_path[i]['id'] for existing_node in nodes_list):
                         nodes_list.append({
                             "id":formatted_path[i]['id'],
-                            "description":formatted_path[i]['description'],
-                            "level_id":formatted_path[i]['level_id'],
-                            "community_id":formatted_path[i]['community_id']
+                            "description":formatted_path[i]['description']
                         })
                     if i + 2 < len(formatted_path):
                         src_id = formatted_path[i]['id']
@@ -367,12 +359,10 @@ def _format_query_data(data):
                     if not any(existing_node.get("id") == node_id for existing_node in nodes_list):
                         nodes_list.append({
                             "id":"json_node",
-                            "description":value,
-                            "level_id":"",
-                            "community_id":""
+                            "description":value
                         })
 
-        nodes = [Vertex(node['id'],description=node['description'],community_id=node['community_id'],level_id=node['level_id']) for node in nodes_list]
+        nodes = [Vertex(node['id'],description=node['description']) for node in nodes_list]
         rels = [
             Edge(edge["src_id"], edge["dst_id"], label=edge["prop_id"],description=edge["description"])
             for edge in rels_list
@@ -397,9 +387,7 @@ def stream_query(self, query: str) -> Generator[MemoryGraph, None, None]:
                 if isinstance(value, graph.Node):
                     node_id = value._properties["id"]
                     description = value._properties["description"]
-                    community_id = value._properties["_community_id"]
-                    level_id = value._properties["_level_id"]
-                    vertex = Vertex(node_id,description=description,community_id=community_id,level_id=level_id)
+                    vertex = Vertex(node_id,description=description)
                     mg.upsert_vertex(vertex)
                 elif isinstance(value, graph.Relationship):
                     rel_nodes = value.nodes
@@ -416,17 +404,15 @@ def stream_query(self, query: str) -> Generator[MemoryGraph, None, None]:
                 for i in range(len(nodes)):
                     formatted_path.append({
                         "id":nodes[i]._properties["id"],
-                        "_community_id":nodes[i]._properties["_community_id"],
-                        "description":nodes[i]._properties["description"],
-                        "_level_id":nodes[i]._properties["_level_id"],
+                        "description":nodes[i]._properties["description"]
                     })
                     if i < len(rels):
                         formatted_path.append({
                             "id":rels[i]._properties["id"],
                             "description":rels[i]._properties["description"],
                         })
                 for i in range(0, len(formatted_path), 2):
-                    mg.upsert_vertex(Vertex(formatted_path[i]['id'],description=formatted_path[i]['description'],community_id=formatted_path[i]['_community_id'],level_id=formatted_path[i]['_level_id']))
+                    mg.upsert_vertex(Vertex(formatted_path[i]['id'],description=formatted_path[i]['description']))
                     if i + 2 < len(formatted_path):
                         mg.append_edge(Edge(formatted_path[i]['id'], formatted_path[i + 2]['id'], label = formatted_path[i + 1]['id'], description=formatted_path[i + 1]['description']))
                 else:
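stream_query yields one MemoryGraph per result batch. A hedged usage sketch; the store construction, the Cypher text, and the vertices() accessor are placeholders assumed for illustration, not taken from this diff:

    # Hypothetical usage; TuGraphStore constructor arguments are not shown here.
    store = TuGraphStore(config)  # placeholder config object
    for mg in store.stream_query("MATCH p=(n)-[r]->(m) RETURN p LIMIT 10"):
        for v in mg.vertices():  # assumed MemoryGraph iterator
            print(v.vid, v.get_prop("description"))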
2 changes: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ async def search(self, query: str) -> List[Community]:
             query, self._topk, self._score_threshold
         )
         return [
-            Community(id=chunk.id, summary=chunk.content) for chunk in chunks
+            Community(id=chunk.chunk_id, summary=chunk.content) for chunk in chunks
         ]
 
     async def save(self, communities: List[Community]):
10 changes: 7 additions & 3 deletions dbgpt/storage/knowledge_graph/community_summary.py
@@ -118,10 +118,13 @@ def get_config(self) -> BuiltinKnowledgeGraphConfig:
 
     async def aload_document(self, chunks: List[Chunk]) -> List[str]:
         # Load documents as chunks
+        memory_graph = MemoryGraph()
+        # todo add doc node
         for chunk in chunks:
             # Extract triplets from each chunk
             triplets = await self._triplet_extractor.extract(chunk.content)
-            memory_graph = MemoryGraph()
+            # todo add chunk node
+            # todo add relation doc-chunk
             for triplet in triplets:
                 # Insert each triplet into the graph store
                 # if triplet.get("type") == "triplet":
@@ -133,6 +136,7 @@ async def aload_document(self, chunks: List[Chunk]) -> List[str]:
                     desc = data.get("description")
                     vertex = Vertex(id,description=desc)
                     memory_graph.upsert_vertex(vertex)
+                    # todo add relation chunk-vertex
                 elif triplet.get("type") == "edge":
                     data = triplet.get("data")
                     edge_data = data.get("triplet")
@@ -142,10 +146,10 @@ async def aload_document(self, chunks: List[Chunk]) -> List[str]:
                     label = edge_data[1]
                     edge = Edge(sid,tid,label=label,description = desc)
                     memory_graph.append_edge(edge)
-                    self._graph_store.insert_graph(memory_graph)
             logger.info(
                 f"load {len(triplets)} triplets from chunk {chunk.chunk_id}")
         # Build communities after loading all triplets
+        self._graph_store.insert_graph(memory_graph)
         await self._community_store.build_communities()
         return [chunk.chunk_id for chunk in chunks]

@@ -164,7 +168,7 @@ async def asimilar_search_with_scores(
 
         # Combine results, keeping original order and scores
         combined_results = global_results + local_results
-
+        # Add a source field to distinguish between global and local results
         for chunk in combined_results[: len(global_results)]:
             chunk.metadata["source"] = "global"
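Only the global half of the tagging is visible in this hunk; presumably the remaining slice is tagged "local". A runnable sketch of the full pass under that assumption, with stand-in chunk objects:

    from types import SimpleNamespace

    # Stand-ins for retrieved chunks; the real objects come from the two searches.
    global_results = [SimpleNamespace(content="g1", metadata={})]
    local_results = [SimpleNamespace(content="l1", metadata={})]

    combined_results = global_results + local_results
    for chunk in combined_results[: len(global_results)]:
        chunk.metadata["source"] = "global"
    for chunk in combined_results[len(global_results):]:  # assumed counterpart
        chunk.metadata["source"] = "local"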
