Commit cda40e9

Merge pull request #110 from TogetherCrew/feat/109-improve-performance
feat: Add a node score threshold of 0.4
amindadgar authored Dec 17, 2024
2 parents c1b2e5d + ec93c8b commit cda40e9
Showing 7 changed files with 51 additions and 14 deletions.
11 changes: 7 additions & 4 deletions bot/retrievers/custom_retriever.py
@@ -10,6 +10,7 @@
 )
 from llama_index.core.schema import Node, NodeWithScore, ObjectType
 from llama_index.core.vector_stores.types import VectorStoreQueryResult
+from utils.globals import RETRIEVER_THRESHOLD


 class CustomVectorStoreRetriever(VectorIndexRetriever):
@@ -50,10 +51,12 @@ def _build_node_list_from_query_result(
             score: float | None = None
             if query_result.similarities is not None:
                 score = query_result.similarities[ind]
-            # This is the part we updated
-            node_new = Node.from_dict(node.to_dict())
-            node_with_score = NodeWithScore(node=node_new, score=score)
-
-            node_with_scores.append(node_with_score)
+            if score is not None and score >= RETRIEVER_THRESHOLD:
+                # This is the part we updated
+                node_new = Node.from_dict(node.to_dict())
+                node_with_score = NodeWithScore(node=node_new, score=score)
+
+                node_with_scores.append(node_with_score)

         return node_with_scores
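
For illustration, a minimal sketch of the new filtering behavior outside the retriever class; the standalone helper and the sample nodes are hypothetical, and only the 0.4 threshold comes from this PR:

from llama_index.core.schema import NodeWithScore, TextNode

RETRIEVER_THRESHOLD = 0.4  # mirrors utils/globals.py introduced in this PR


def keep_scoring_nodes(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
    # drop unscored nodes and nodes below the threshold, as the updated
    # _build_node_list_from_query_result now does
    return [
        n for n in nodes if n.score is not None and n.score >= RETRIEVER_THRESHOLD
    ]


nodes = [
    NodeWithScore(node=TextNode(text="relevant passage"), score=0.72),
    NodeWithScore(node=TextNode(text="weak match"), score=0.31),
]
print([n.score for n in keep_scoring_nodes(nodes)])  # [0.72]
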
6 changes: 6 additions & 0 deletions subquery.py
@@ -5,7 +5,9 @@
 from llama_index.core.tools import QueryEngineTool, ToolMetadata
 from llama_index.llms.openai import OpenAI
 from llama_index.question_gen.guidance import GuidanceQuestionGenerator
+from tc_hivemind_backend.db.utils.preprocess_text import BasePreprocessor
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
+from utils.globals import INVALID_QUERY_RESPONSE
 from utils.qdrant_utils import QDrantUtils
 from utils.query_engine import (
     DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
@@ -198,6 +200,10 @@ def query_multiple_source(
                 metadata=tool_metadata,
             )
         )
+    if not BasePreprocessor().extract_main_content(text=query):
+        response = INVALID_QUERY_RESPONSE
+        source_nodes = []
+        return response, source_nodes

     embed_model = CohereEmbedding()
     llm = OpenAI("gpt-4o-mini")
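
The guard above returns early when preprocessing strips the query down to nothing. A hedged sketch of the pattern, where clean is a stand-in since BasePreprocessor.extract_main_content lives in tc_hivemind_backend:

INVALID_QUERY_RESPONSE = (
    "We're unable to process your query. Please refine it and try again."
)


def clean(text: str) -> str:
    # stand-in for BasePreprocessor().extract_main_content(text=text);
    # the real preprocessor removes noise before judging emptiness
    return text.strip()


def answer(query: str) -> tuple[str, list]:
    if not clean(query):
        # nothing meaningful survived preprocessing: skip the LLM entirely
        return INVALID_QUERY_RESPONSE, []
    return "normal pipeline result", []


print(answer("   "))  # (INVALID_QUERY_RESPONSE, [])
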
7 changes: 7 additions & 0 deletions utils/globals.py
@@ -0,0 +1,7 @@
+# the threshold below which nodes are excluded from an answer
+RETRIEVER_THRESHOLD = 0.4
+REFERENCE_SCORE_THRESHOLD = 0.5
+INVALID_QUERY_RESPONSE = (
+    "We're unable to process your query. Please refine it and try again."
+)
+QUERY_ERROR_MESSAGE = "Sorry, we're unable to process your question at the moment. Please try again later."
21 changes: 16 additions & 5 deletions utils/query_engine/dual_qdrant_retrieval_engine.py
@@ -11,6 +11,7 @@
 from llama_index.llms.openai import OpenAI
 from schema.type import DataType
 from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
+from utils.globals import RETRIEVER_THRESHOLD
 from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils

 qa_prompt = PromptTemplate(
@@ -176,15 +177,19 @@ def _setup_vector_store_index(

     def _process_basic_query(self, query_str: str) -> Response:
         nodes: list[NodeWithScore] = self.retriever.retrieve(query_str)
-        context_str = "\n\n".join([n.node.get_content() for n in nodes])
+        nodes_filtered = [node for node in nodes if node.score >= RETRIEVER_THRESHOLD]
+        context_str = "\n\n".join([n.node.get_content() for n in nodes_filtered])
         prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(prompt)

         # return final_response
-        return Response(response=str(response), source_nodes=nodes)
+        return Response(response=str(response), source_nodes=nodes_filtered)

     def _process_summary_query(self, query_str: str) -> Response:
         summary_nodes = self.summary_retriever.retrieve(query_str)
+        summary_nodes_filtered = [
+            node for node in summary_nodes if node.score >= RETRIEVER_THRESHOLD
+        ]
         utils = QdrantEngineUtils(
             metadata_date_key=self.metadata_date_key,
             metadata_date_format=self.metadata_date_format,
@@ -193,7 +198,7 @@ def _process_summary_query(self, query_str: str) -> Response:

         dates = [
             node.metadata[self.metadata_date_summary_key]
-            for node in summary_nodes
+            for node in summary_nodes_filtered
             if self.metadata_date_summary_key in node.metadata
         ]

@@ -208,8 +213,14 @@ def _process_summary_query(self, query_str: str) -> Response:
         )
         raw_nodes = retriever.retrieve(query_str)

-        context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
+        raw_nodes_filtered = [
+            node for node in raw_nodes if node.score >= RETRIEVER_THRESHOLD
+        ]
+
+        context_str = utils.combine_nodes_for_prompt(
+            summary_nodes_filtered, raw_nodes_filtered
+        )
         prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(prompt)

-        return Response(response=str(response), source_nodes=raw_nodes)
+        return Response(response=str(response), source_nodes=raw_nodes_filtered)
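
A rough sketch of the summary-path flow after this change: both retrieval passes are thresholded before the prompt is assembled, and only surviving raw nodes are returned as sources. The combine helper is a simplification of QdrantEngineUtils.combine_nodes_for_prompt, and this sketch adds a None-score guard that the PR itself does not:

from llama_index.core.schema import NodeWithScore, TextNode

RETRIEVER_THRESHOLD = 0.4


def threshold(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
    return [n for n in nodes if n.score is not None and n.score >= RETRIEVER_THRESHOLD]


def combine(summary_nodes, raw_nodes) -> str:
    # simplified stand-in: the real helper groups raw nodes under their
    # summaries using the metadata date keys
    return "\n\n".join(n.node.get_content() for n in summary_nodes + raw_nodes)


summary_nodes = threshold([NodeWithScore(node=TextNode(text="daily summary"), score=0.66)])
raw_nodes = threshold(
    [
        NodeWithScore(node=TextNode(text="raw message"), score=0.52),
        NodeWithScore(node=TextNode(text="off-topic chatter"), score=0.09),
    ]
)
context_str = combine(summary_nodes, raw_nodes)  # the 0.09 node is dropped
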
10 changes: 8 additions & 2 deletions utils/query_engine/level_based_platform_query_engine.py
@@ -14,6 +14,7 @@
 from llama_index.core.retrievers import BaseRetriever
 from llama_index.core.schema import NodeWithScore
 from llama_index.llms.openai import OpenAI
+from utils.globals import RETRIEVER_THRESHOLD
 from utils.query_engine.base_pg_engine import BasePGEngine
 from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils

@@ -46,13 +47,18 @@ def custom_query(self, query_str: str):
         similar_nodes = retriever.query_db(
             query=query_str, filters=self._filters, date_interval=self._d
         )
+        similar_nodes_filtered = [
+            node for node in similar_nodes if node.score >= RETRIEVER_THRESHOLD
+        ]

-        context_str = self._prepare_context_str(similar_nodes, summary_nodes=None)
+        context_str = self._prepare_context_str(
+            similar_nodes_filtered, summary_nodes=None
+        )
         fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(fmt_qa_prompt)
         logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")

-        return Response(response=str(response), source_nodes=similar_nodes)
+        return Response(response=str(response), source_nodes=similar_nodes_filtered)

     @classmethod
     def prepare_platform_engine(
7 changes: 5 additions & 2 deletions utils/query_engine/prepare_answer_sources.py
@@ -1,17 +1,20 @@
 import logging

 from llama_index.core.schema import NodeWithScore
+from utils.globals import REFERENCE_SCORE_THRESHOLD


 class PrepareAnswerSources:
-    def __init__(self, threshold: float = 0.5, max_refs_per_source: int = 3) -> None:
+    def __init__(
+        self, threshold: float = REFERENCE_SCORE_THRESHOLD, max_refs_per_source: int = 3
+    ) -> None:
         """
         Initialize the PrepareAnswerSources class.
         Parameters
         ----------
         threshold : float, optional
-            Minimum score threshold for including a node's URL, by default 0.5
+            Minimum score threshold for including a node's URL, by default REFERENCE_SCORE_THRESHOLD (0.5) from utils/globals.py
         max_refs_per_source : int, optional
             Maximum number of references to include per data source, by default 3
         """
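
A brief usage sketch: the signature is unchanged, so existing callers keep working; only the default value now comes from utils/globals.py:

from utils.query_engine.prepare_answer_sources import PrepareAnswerSources

# default threshold is now REFERENCE_SCORE_THRESHOLD (0.5) from utils/globals.py
default_prep = PrepareAnswerSources()

# an explicit override behaves exactly as before
strict_prep = PrepareAnswerSources(threshold=0.8, max_refs_per_source=2)
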
3 changes: 2 additions & 1 deletion worker/tasks.py
@@ -5,6 +5,7 @@
 from llama_index.core.schema import NodeWithScore
 from subquery import query_multiple_source
 from utils.data_source_selector import DataSourceSelector
+from utils.globals import QUERY_ERROR_MESSAGE
 from utils.query_engine.prepare_answer_sources import PrepareAnswerSources
 from utils.traceloop import init_tracing
 from worker.celery import app
@@ -21,7 +22,7 @@ def ask_question_auto_search(
         )
         answer_sources = PrepareAnswerSources().prepare_answer_sources(nodes=references)
     except Exception:
-        response = "Sorry, We cannot process your question at the moment."
+        response = QUERY_ERROR_MESSAGE
         answer_sources = None
         logging.error(
             f"Errors raised while processing the question for community: {community_id}!"
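
A condensed sketch of the error path after the change; run_pipeline is a hypothetical stand-in for the real query and reference-preparation calls inside the Celery task:

import logging

QUERY_ERROR_MESSAGE = (
    "Sorry, we're unable to process your question at the moment. "
    "Please try again later."
)


def answer_or_fallback(run_pipeline, community_id: str) -> tuple[str, list | None]:
    try:
        return run_pipeline()
    except Exception:
        # the user sees one shared, friendly message; details go to the logs
        logging.error(
            f"Errors raised while processing the question for community: {community_id}!"
        )
        return QUERY_ERROR_MESSAGE, None
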
