Skip to content

Commit

Permalink
optimize apis and CommunityStore
Browse files Browse the repository at this point in the history
  • Loading branch information
fanzhidongyzby committed Aug 15, 2024
1 parent 7229734 commit f430bd1
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 43 deletions.
2 changes: 1 addition & 1 deletion dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(self) -> None:

# Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "False")
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", False)
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
Expand Down
14 changes: 13 additions & 1 deletion dbgpt/storage/graph_store/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Graph store base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Generator

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
Expand All @@ -23,11 +23,19 @@ class GraphStoreConfig(BaseModel):
default=None,
description="The embedding function of graph store, optional.",
)
summary_enabled: bool = Field(
default=False,
description="Enable graph community summary or not.",
)


class GraphStoreBase(ABC):
"""Graph store base class."""

@abstractmethod
def get_config(self) -> GraphStoreConfig:
    """Get the graph store config.

    Returns:
        GraphStoreConfig: The configuration object of this graph store.
    """

@abstractmethod
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Add triplet."""
Expand Down Expand Up @@ -66,3 +74,7 @@ def explore(
@abstractmethod
def query(self, query: str, **args) -> Graph:
    """Execute a query.

    Args:
        query: The query text to execute.
        **args: Extra keyword arguments; their meaning is left to the
            concrete store (not visible here, confirm per implementation).

    Returns:
        Graph: The query result.
    """

@abstractmethod
def stream_query(self, query: str) -> Generator[Graph, None, None]:
    """Execute stream query.

    Args:
        query: The query text to execute.

    Returns:
        Generator[Graph, None, None]: A generator of ``Graph`` objects
        produced by the query.
    """
78 changes: 43 additions & 35 deletions dbgpt/storage/graph_store/community_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import MemoryGraph, Vertex
from dbgpt.storage.graph_store.graph import Vertex, Graph

client = OpenAI(api_key="")

Expand Down Expand Up @@ -61,32 +61,39 @@ def __del__(self):
self.connector.close()


class Community:
    """Record describing one graph community.

    The class body only declares annotations and defines no constructor, so
    ``id``, ``level`` and ``data`` exist on an instance only after they have
    been assigned; ``summary`` falls back to the class-level default.
    """

    # Identifier of the community.
    id: str
    # Level of the community; presumably its depth in the community
    # hierarchy (TODO confirm against the producer of these records).
    level: str
    # Graph data attached to the community.
    data: Graph
    # Summary text; None until a summary has been assigned.
    # NOTE(review): annotated ``str`` although the default is None; consider
    # ``Optional[str]`` after confirming ``Optional`` is imported in this file.
    summary: str = None


class CommunityStore:
def __init__(self, graph_store: GraphStoreBase, enable_persistence: bool = True):
# Initialize with a graph store and maximum hierarchical level for Leiden algorithm
self._graph_store = graph_store
self._max_hierarchical_level = 3
self._community_summary = {}
self._max_hierarchy_level = 3
self._enable_persistence = enable_persistence
self._orm = SQLiteORM() if enable_persistence else None
self._executor = ThreadPoolExecutor(max_workers=10)

async def build_communities(self):
# Build hierarchical communities using the Leiden algorithm
LEIDEN_QUERY = "" # TODO: create leiden query in TuGraph
community_hierarchical_clusters = self._graph_store.stream_query(LEIDEN_QUERY)
community_info = await self._retrieve_community_info(
community_hierarchical_clusters
)
await self._summarize_communities(community_info)

async def _retrieve_community_info(
self, clusters: Generator[MemoryGraph, None, None]
) -> Dict[str, List[str]]:
# discover communities
graph_name = self._graph_store.get_config().name
query = f"CALL {graph_name}.leiden({self._max_hierarchy_level})"
communities_metadata = self._graph_store.stream_query(query)

# summarize communities
communities = await self._retrieve_communities(communities_metadata)
await self._summarize_communities(communities)

async def _retrieve_communities(
self, communities_metadata: Generator[Graph, None, None]
) -> List[Community]:
"""Collect detailed information for each node based on their community.
# community_hierarchical_clusters structure: Generator[MemoryGraph, None, None]
Each MemoryGraph contains:
# community_hierarchical_clusters structure: Generator[Graph, None, None]
Each Graph contains:
vertices: A set of Vertex objects, each representing a node in the graph.
edges: A set of Edge objects, each representing an edge in the graph.
Vertex objects may include the following attributes:
Expand All @@ -113,10 +120,10 @@ async def _retrieve_community_info(
"""

community_info: Dict[str, List[str]] = {}
community_info: List[Community] = []
tasks = []

for memory_graph in clusters:
for memory_graph in communities_metadata:
for vertex in memory_graph.vertices:
task = asyncio.create_task(
self._process_vertex(memory_graph, vertex, community_info)
Expand All @@ -134,51 +141,52 @@ async def _retrieve_community_info(

async def _process_vertex(
self,
memory_graph: MemoryGraph,
memory_graph: Graph,
vertex: Vertex,
community_info: Dict[str, List[str]],
communities: List[Community],
):
cluster_id = vertex.properties.get("community_id", "unknown")
if cluster_id not in community_info:
community_info[cluster_id] = []
if cluster_id not in communities:
communities[cluster_id] = []

for edge in memory_graph.edges:
if edge.src_id == vertex.id:
neighbor_vertex = memory_graph.get_vertex(edge.dst_id)
if neighbor_vertex:
detail = f"{vertex.id} -> {neighbor_vertex.id} -> {edge.label} -> {edge.properties.get('description', 'No description')}"
community_info[cluster_id].append(detail)
communities[cluster_id].append(detail)

async def _summarize_communities(self, community_info: Dict[str, List[str]]):
async def _summarize_communities(self, communities: List[Community]):
"""Generate and store summaries for each community."""
tasks = []
for community_id, details in community_info.items():
task = asyncio.create_task(self._summarize_community(community_id, details))
for community in communities:
task = asyncio.create_task(self._summarize_community(community))
tasks.append(task)
await asyncio.gather(*tasks)

async def _summarize_community(self, community_id: str, details: List[str]):
details_text = f"{' '.join(details)}."
summary = await self._generate_community_summary(details_text)
self._community_summary[community_id] = summary
async def _summarize_community(self, community: Community):
summary = await self._generate_community_summary(community.data)
community.summary = summary

if self._enable_persistence and self._orm:
await asyncio.get_event_loop().run_in_executor(
self._executor,
self._orm.update_community_summary,
community_id,
community.id,
summary,
)

async def summarize_communities(self) -> Dict[str, str]:
async def search_communities(self, query: str) -> List[Community]:
# TODO: search communities relevant with query (by RDB / Vector / index)
if self._enable_persistence and self._orm:
return await asyncio.get_event_loop().run_in_executor(
self._executor, self._orm.fetch_all_communities
)
else:
return self._community_summary
# TODO: in-memory search cache can be used here
return

async def _generate_community_summary(self, text):
async def _generate_community_summary(self, graph: Graph):
"""Generate summary for a given text using an LLM."""
response = client.chat.completions.create(
model="gpt-4o",
Expand All @@ -195,7 +203,7 @@ async def _generate_community_summary(self, text):
"""
),
},
{"role": "user", "content": text},
{"role": "user", "content": graph},
],
)
return response.choices[0].message.content
11 changes: 10 additions & 1 deletion dbgpt/storage/graph_store/memgraph_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Graph store base class."""
import json
import logging
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Generator

from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
Expand All @@ -26,9 +26,14 @@ class MemoryGraphStore(GraphStoreBase):

def __init__(self, graph_store_config: MemoryGraphStoreConfig):
    """Set up the store from its config and create a new memory graph.

    Args:
        graph_store_config: Store configuration. It is kept on the instance
            and its ``edge_name_key`` is used as the graph's edge label.
    """
    self._graph_store_config = graph_store_config
    edge_key = graph_store_config.edge_name_key
    self._edge_name_key = edge_key
    self._graph = MemoryGraph(edge_label=edge_key)

def get_config(self) -> GraphStoreConfig:
    """Get the graph store config.

    Returns:
        GraphStoreConfig: The object held in ``self._graph_store_config``.
    """
    return self._graph_store_config

def insert_triplet(self, sub: str, rel: str, obj: str):
"""Insert a triplet into the graph."""
self._graph.append_edge(Edge(sub, obj, **{self._edge_name_key: rel}))
Expand Down Expand Up @@ -79,3 +84,7 @@ def explore(
def query(self, query: str, **args) -> Graph:
    """Execute a query on graph.

    Querying is not supported by this store; the arguments are ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Query memory graph not allowed")

def stream_query(self, query: str) -> Generator[Graph, None, None]:
    """Execute stream query.

    Stream querying is not supported by this store; ``query`` is ignored.

    Raises:
        NotImplementedError: Always, at call time (the body contains no
            ``yield``, so this is a plain function, not a generator).
    """
    raise NotImplementedError("Stream query memory graph not allowed")
9 changes: 8 additions & 1 deletion dbgpt/storage/graph_store/neo4j_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Neo4j vector store."""
import logging
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Generator

from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
Expand All @@ -24,6 +24,9 @@ def __init__(self, graph_store_config: Neo4jStoreConfig):
"""Initialize the Neo4jStore with connection details."""
pass

def get_config(self):
    """Get the graph store config."""
    # NOTE(review): the body is only the docstring, so this call returns None
    # rather than a config object; callers that dereference the result
    # (e.g. ``get_config().name``) would fail. Needs a real implementation.

def insert_triplet(self, sub: str, rel: str, obj: str):
"""Insert triplets."""
pass
Expand Down Expand Up @@ -62,3 +65,7 @@ def explore(
def query(self, query: str, **args) -> Graph:
    """Execute a query on graph.

    Args:
        query: The query text; not used by the current body.
        **args: Extra keyword arguments; not used by the current body.

    Returns:
        Graph: A newly constructed ``MemoryGraph()``.
    """
    return MemoryGraph()

def stream_query(self, query: str) -> Generator[Graph, None, None]:
    """Execute stream query.

    This store is a stub, so the stream is empty.

    Args:
        query: The query text; not used by the current body.

    Returns:
        Generator[Graph, None, None]: A generator that yields nothing.
    """
    # ``yield from`` over an empty tuple turns this into a generator function,
    # so callers receive an (empty) iterable instead of None and can loop over
    # the result without a TypeError.
    yield from ()

9 changes: 5 additions & 4 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,14 @@ class TuGraphStoreConfig(GraphStoreConfig):
default="label",
description="The label of edge name, `label` by default.",
)
summary_enabled: bool = Field(
default=False,
description=""
)


class TuGraphStore(GraphStoreBase):
"""TuGraph graph store."""

def __init__(self, config: TuGraphStoreConfig) -> None:
"""Initialize the TuGraphStore with connection details."""
self._config = config
self._host = os.getenv("TUGRAPH_HOST", "127.0.0.1") or config.host
self._port = int(os.getenv("TUGRAPH_PORT", 7687)) or config.port
self._username = os.getenv("TUGRAPH_USERNAME", "admin") or config.username
Expand Down Expand Up @@ -120,6 +117,10 @@ def _create_schema(self):
"{self._node_label}"]]', ["id",STRING,false],["description",STRING,true])"""
self.conn.run(create_edge_gql)

def get_config(self) -> GraphStoreConfig:
    """Get the graph store config.

    Returns:
        GraphStoreConfig: The object held in ``self._config``.
    """
    return self._config

def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
"""Get triplets."""
query = (
Expand Down

0 comments on commit f430bd1

Please sign in to comment.