Skip to content

Commit

Permalink
style:fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
Aries-ckt committed Aug 30, 2024
1 parent 03a40ca commit ab0b741
Showing 1 changed file with 13 additions and 22 deletions.
35 changes: 13 additions & 22 deletions examples/rag/graph_rag_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
import pytest

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.core import ModelMessage, HumanPromptTemplate, ModelRequest, Chunk
from dbgpt.core import Chunk, HumanPromptTemplate, ModelMessage, ModelRequest
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever import RetrieverStrategy
from dbgpt.storage.knowledge_graph.community_summary import \
CommunitySummaryKnowledgeGraph, CommunitySummaryKnowledgeGraphConfig
from dbgpt.storage.knowledge_graph.community_summary import (
CommunitySummaryKnowledgeGraph,
CommunitySummaryKnowledgeGraphConfig,
)
from dbgpt.storage.knowledge_graph.knowledge_graph import (
BuiltinKnowledgeGraph,
BuiltinKnowledgeGraphConfig,
Expand Down Expand Up @@ -44,7 +46,7 @@ async def test_naive_graph_rag():
knowledge_file="examples/test_files/graphrag-mini.md",
chunk_strategy="CHUNK_BY_SIZE",
knowledge_graph=__create_naive_kg_connector(),
question="What's the relationship between TuGraph and DB-GPT ?"
question="What's the relationship between TuGraph and DB-GPT ?",
)


Expand All @@ -54,7 +56,7 @@ async def test_community_graph_rag():
knowledge_file="examples/test_files/graphrag-mini.md",
chunk_strategy="CHUNK_BY_MARKDOWN_HEADER",
knowledge_graph=__create_community_kg_connector(),
question="What's the relationship between TuGraph and DB-GPT ?"
question="What's the relationship between TuGraph and DB-GPT ?",
)


Expand All @@ -66,7 +68,7 @@ def __create_naive_kg_connector():
embedding_fn=None,
llm_client=llm_client,
model_name=model_name,
graph_store_type='MemoryGraph'
graph_store_type="MemoryGraph",
),
)

Expand All @@ -79,21 +81,17 @@ def __create_community_kg_connector():
embedding_fn=DefaultEmbeddingFactory.openai(),
llm_client=llm_client,
model_name=model_name,
graph_store_type='TuGraphGraph'
graph_store_type="TuGraphGraph",
),
)


async def ask_chunk(chunk: Chunk, question) -> str:
rag_template = (
"Based on the following [Context] {context}, "
"answer [Question] {question}."
"Based on the following [Context] {context}, " "answer [Question] {question}."
)
template = HumanPromptTemplate.from_template(rag_template)
messages = template.format_messages(
context=chunk.content,
question=question
)
messages = template.format_messages(context=chunk.content, question=question)
model_messages = ModelMessage.from_base_messages(messages)
request = ModelRequest(model=model_name, messages=model_messages)
response = await llm_client.generate(request=request)
Expand All @@ -106,12 +104,7 @@ async def ask_chunk(chunk: Chunk, question) -> str:
return response.text


async def __run_graph_rag(
knowledge_file,
chunk_strategy,
knowledge_graph,
question
):
async def __run_graph_rag(knowledge_file, chunk_strategy, knowledge_graph, question):
file_path = os.path.join(ROOT_PATH, knowledge_file).format()
knowledge = KnowledgeFactory.from_file_path(file_path)
try:
Expand All @@ -128,9 +121,7 @@ async def __run_graph_rag(

# get embeddings retriever
retriever = assembler.as_retriever(1)
chunks = await retriever.aretrieve_with_scores(
question, score_threshold=0.3
)
chunks = await retriever.aretrieve_with_scores(question, score_threshold=0.3)

# chat
print(f"{await ask_chunk(chunks[0], question)}")
Expand Down

0 comments on commit ab0b741

Please sign in to comment.