Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance the triplets extraction in the knowledge graph by the batch size #2091

Merged
merged 20 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
89db275
feat: Improve triplet extraction batch size and handling
Appointat Oct 23, 2024
f8e3ed1
feat: Improve triplet extraction batch size and handling
Appointat Oct 23, 2024
a57029e
refactor: Add batch_extract method to ExtractorBase
Appointat Oct 28, 2024
3fc7640
refactor: refactor: Add batch_extract method to GraphExtractor
Appointat Oct 28, 2024
fee90cc
refactor: Add batch_extract method to LLMExtractor
Appointat Oct 28, 2024
ccd2cdf
refactor: Refactor CommunitySummaryKnowledgeGraph batch extraction me…
Appointat Oct 28, 2024
3f65e49
refactor: Update knowledge graph extraction batch size
Appointat Oct 29, 2024
a253542
refactor: Update knowledge graph extraction batch size
Appointat Oct 29, 2024
c565600
Refactor batch extraction methods in GraphExtractor and LLMExtractor
Appointat Oct 29, 2024
a4e602e
Refactor knowledge graph extraction batch size and method in Communit…
Appointat Oct 29, 2024
7d4d7f4
refactor: Refactor batch extraction methods in GraphExtractor and LLM…
Appointat Oct 29, 2024
5aaa393
feat: Refactor knowledge graph extraction batch size and method in Tu…
Appointat Oct 29, 2024
e8b82db
refactor: Update knowledge graph extraction batch size and method in …
Appointat Oct 29, 2024
0b87218
Refactor method signature in TuGraphStoreAdapter
Appointat Oct 29, 2024
e6f6d33
Refactor markdown format in community_summary.py
Appointat Oct 29, 2024
1ff3184
fix: Refactor graph store configuration and enable/disable graph search
Appointat Oct 30, 2024
a8f9321
chore: format the code
Appointat Oct 30, 2024
7e3c3c7
fix: Refactor TuGraphStoreAdapter to improve graph retrieval logic
Appointat Oct 30, 2024
0c263bf
fix
Appointat Oct 30, 2024
f0216d7
Refactor markdown format in community_summary.py
Appointat Oct 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks

KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
Expand Down
10 changes: 10 additions & 0 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Transformer base class."""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional
Expand Down Expand Up @@ -37,6 +38,15 @@ class ExtractorBase(TransformerBase, ABC):
async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract results from text."""

@abstractmethod
async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List:
    """Extract results from several texts.

    Batch counterpart of ``extract``; concrete extractors must override
    this coroutine.

    Args:
        texts: Input texts to extract from.
        batch_size: Batch size, defaults to 1. How it is applied is left to
            the implementation (presumably the number of texts handled
            together -- confirm against the concrete extractor).
        limit: Optional limit, defaults to None (presumably the same
            meaning as ``limit`` in ``extract`` -- confirm).

    Returns:
        A list of extraction results.
    """


class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""
98 changes: 80 additions & 18 deletions dbgpt/rag/transformer/graph_extractor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""GraphExtractor class."""

import asyncio
import logging
import re
from typing import List, Optional
from typing import Dict, List, Optional

from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
Expand All @@ -23,35 +24,96 @@ def __init__(
self._chunk_history = chunk_history

config = self._chunk_history.get_config()

self._vector_space = config.name
self._max_chunks_once_load = config.max_chunks_once_load
self._max_threads = config.max_threads
self._topk = config.topk
self._score_threshold = config.score_threshold

async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Load similar chunks."""
# load similar chunks
chunks = await self._chunk_history.asimilar_search_with_scores(
text, self._topk, self._score_threshold
)
history = [
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
]
context = "\n".join(history) if history else ""

try:
# extract with chunk history
return await super()._extract(text, context, limit)

finally:
# save chunk to history
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
    """Build a history context for each text and record the text in history.

    Texts are handled strictly one after another: the chunk history is
    searched for chunks similar to the text, and only then is the text
    itself stored in the history, so it can show up in the search results
    of the texts that follow it.

    Args:
        texts: Texts to look up and to store.

    Returns:
        Mapping from each text to its context: the contents of the similar
        chunks, each prefixed with ``Section <n>:`` and joined by newlines,
        or an empty string when nothing similar was found. The mapping is
        keyed by the text itself, so a repeated text keeps only the context
        computed for its last occurrence.
    """
    contexts: Dict[str, str] = {}

    for text in texts:
        # Search first: the lookup must not see the text we are about to add.
        similar_chunks = await self._chunk_history.asimilar_search_with_scores(
            text, self._topk, self._score_threshold
        )
        sections = []
        for number, chunk in enumerate(similar_chunks, start=1):
            sections.append(f"Section {number}:\n{chunk.content}")

        # Store the text, remembering how many similar chunks it matched.
        await self._chunk_history.aload_document_with_limit(
            [Chunk(content=text, metadata={"relevant_cnt": len(sections)})],
            self._max_chunks_once_load,
            self._max_threads,
        )

        # Joining an empty list already yields "", no special case needed.
        contexts[text] = "\n".join(sections)

    return contexts

async def extract(self, text: str, limit: Optional[int] = None) -> List:
    """Extract graphs from a single text using its chunk-history context.

    The context is obtained through ``aload_chunk_context``, which also
    stores the text in the chunk history before the extraction runs.

    Suggestion: to extract triplets in batches, call `batch_extract`.
    """
    contexts = await self.aload_chunk_context([text])

    # Delegate to the parent extractor, passing the context as history.
    return await super()._extract(text, contexts[text], limit)

async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List[List[Graph]]:
    """Extract graphs from texts, at most ``batch_size`` texts concurrently.

    The chunk context of every text is loaded first (see
    ``aload_chunk_context``); the texts are then extracted batch by batch,
    each batch running concurrently and finishing before the next starts.

    Args:
        texts: Texts to extract graphs from.
        batch_size: Maximum number of extractions run concurrently.
        limit: Optional limit forwarded unchanged to ``_extract``.

    Returns:
        One extraction result per input text, in the same order as
        ``texts`` (text <-> graphs). An empty ``texts`` yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.

    Note:
        The context is looked up by the text value, so identical texts
        share a single context entry.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # 1. Load chunk context for all texts before any extraction starts.
    text_context_map = await self.aload_chunk_context(texts)

    graphs_list: List[List[Graph]] = []
    for start in range(0, len(texts), batch_size):
        # 2. Run one batch concurrently. gather() returns its results in the
        #    order the awaitables were passed, so appending batch after
        #    batch keeps the output aligned with `texts` without any index
        #    bookkeeping or post-hoc assertion.
        batch_results = await asyncio.gather(
            *(
                self._extract(text, text_context_map[text], limit)
                for text in texts[start : start + batch_size]
            )
        )
        graphs_list.extend(batch_results)

    return graphs_list

def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
graph = MemoryGraph()
edge_count = 0
Expand Down
28 changes: 28 additions & 0 deletions dbgpt/rag/transformer/llm_extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""TripletExtractor class."""

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
Expand All @@ -22,6 +24,32 @@ async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract by LLM."""
return await self._extract(text, None, limit)

async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List:
    """Extract from several texts by LLM, at most ``batch_size`` at a time.

    The texts are split into consecutive batches. The extractions of one
    batch run concurrently and must all finish before the next batch
    starts. No history is passed to ``_extract``.

    Args:
        texts: Texts to extract from.
        batch_size: Maximum number of extractions run concurrently.
        limit: Optional limit forwarded unchanged to ``_extract``.

    Returns:
        One ``_extract`` result per input text, in the same order as
        ``texts``. An empty ``texts`` yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results: List = []
    for start in range(0, len(texts), batch_size):
        # gather() keeps the order of the awaitables it is given, so the
        # accumulated results stay aligned with `texts`. An exception raised
        # by any extraction propagates to the caller.
        batch_results = await asyncio.gather(
            *(
                self._extract(text, None, limit)
                for text in texts[start : start + batch_size]
            )
        )
        results.extend(batch_results)

    return results

async def _extract(
self, text: str, history: str = None, limit: Optional[int] = None
) -> List:
Expand Down
3 changes: 2 additions & 1 deletion dbgpt/rag/transformer/triplet_extractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TripletExtractor class."""

import logging
import re
from typing import Any, List, Optional, Tuple
Expand All @@ -12,7 +13,7 @@
"Some text is provided below. Given the text, "
"extract up to knowledge triplets as more as possible "
"in the form of (subject, predicate, object).\n"
"Avoid stopwords.\n"
"Avoid stopwords. The subject, predicate, object can not be none.\n"
"---------------------\n"
"Example:\n"
"Text: Alice is Bob's mother.\n"
Expand Down
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ class GraphStoreConfig(BaseModel):
default=False,
description="Enable graph community summary or not.",
)
document_graph_enabled: bool = Field(
default=True,
description="Enable document graph search or not.",
)
triplet_graph_enabled: bool = Field(
default=True,
description="Enable knowledge graph search or not.",
)


class GraphStoreBase(ABC):
Expand Down
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,6 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary
)
self._enable_document_graph = (
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
or config.document_graph_enabled
)
self._enable_triplet_graph = (
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
or config.triplet_graph_enabled
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
Expand Down
Loading
Loading