diff --git a/bot/retrievers/custom_retriever.py b/bot/retrievers/custom_retriever.py
index bb05ec0..46147ac 100644
--- a/bot/retrievers/custom_retriever.py
+++ b/bot/retrievers/custom_retriever.py
@@ -10,6 +10,7 @@
 )
 from llama_index.core.schema import Node, NodeWithScore, ObjectType
 from llama_index.core.vector_stores.types import VectorStoreQueryResult
+from utils.globals import RETRIEVER_THRESHOLD


 class CustomVectorStoreRetriever(VectorIndexRetriever):
@@ -50,10 +51,12 @@ def _build_node_list_from_query_result(
             score: float | None = None
             if query_result.similarities is not None:
                 score = query_result.similarities[ind]
-            # This is the part we updated
-            node_new = Node.from_dict(node.to_dict())
-            node_with_score = NodeWithScore(node=node_new, score=score)
-            node_with_scores.append(node_with_score)
+            if score is not None and score >= RETRIEVER_THRESHOLD:
+                # This is the part we updated
+                node_new = Node.from_dict(node.to_dict())
+                node_with_score = NodeWithScore(node=node_new, score=score)
+
+                node_with_scores.append(node_with_score)

         return node_with_scores

diff --git a/subquery.py b/subquery.py
index f34b8c2..f68ac3d 100644
--- a/subquery.py
+++ b/subquery.py
@@ -5,7 +5,9 @@
 from llama_index.core.tools import QueryEngineTool, ToolMetadata
 from llama_index.llms.openai import OpenAI
 from llama_index.question_gen.guidance import GuidanceQuestionGenerator
+from tc_hivemind_backend.db.utils.preprocess_text import BasePreprocessor
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
+from utils.globals import INVALID_QUERY_RESPONSE
 from utils.qdrant_utils import QDrantUtils
 from utils.query_engine import (
     DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
@@ -198,6 +200,10 @@ def query_multiple_source(
                 metadata=tool_metadata,
             )
         )
+    if not BasePreprocessor().extract_main_content(text=query):
+        response = INVALID_QUERY_RESPONSE
+        source_nodes = []
+        return response, source_nodes

     embed_model = CohereEmbedding()
     llm = OpenAI("gpt-4o-mini")

diff --git a/utils/globals.py b/utils/globals.py
new file mode 100644
index 0000000..aa558f7
--- /dev/null
+++ b/utils/globals.py
@@ -0,0 +1,7 @@
+# the threshold below which nodes are skipped from being included in an answer
+RETRIEVER_THRESHOLD = 0.4
+REFERENCE_SCORE_THRESHOLD = 0.5
+INVALID_QUERY_RESPONSE = (
+    "We're unable to process your query. Please refine it and try again."
+)
+QUERY_ERROR_MESSAGE = "Sorry, we're unable to process your question at the moment. Please try again later."
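To illustrate the retrieval-side change above, here is a minimal sketch of score-threshold filtering in isolation. ScoredNode, filter_nodes, and the sample data are illustrative stand-ins rather than code from this patch; only the RETRIEVER_THRESHOLD value mirrors the new utils/globals.py.

    from dataclasses import dataclass

    RETRIEVER_THRESHOLD = 0.4  # mirrors utils/globals.py above


    @dataclass
    class ScoredNode:
        # Illustrative stand-in for llama_index's NodeWithScore.
        text: str
        score: float | None


    def filter_nodes(nodes: list[ScoredNode]) -> list[ScoredNode]:
        # Drop unscored nodes and nodes below the threshold, mirroring the guard
        # added to CustomVectorStoreRetriever above.
        return [n for n in nodes if n.score is not None and n.score >= RETRIEVER_THRESHOLD]


    nodes = [
        ScoredNode("strong match", 0.72),
        ScoredNode("weak match", 0.18),
        ScoredNode("unscored", None),
    ]
    print([n.text for n in filter_nodes(nodes)])  # ['strong match']
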
diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py
index ebcb250..04ca54b 100644
--- a/utils/query_engine/dual_qdrant_retrieval_engine.py
+++ b/utils/query_engine/dual_qdrant_retrieval_engine.py
@@ -11,6 +11,7 @@
 from llama_index.llms.openai import OpenAI
 from schema.type import DataType
 from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
+from utils.globals import RETRIEVER_THRESHOLD
 from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils

 qa_prompt = PromptTemplate(
@@ -176,15 +177,19 @@ def _setup_vector_store_index(

     def _process_basic_query(self, query_str: str) -> Response:
         nodes: list[NodeWithScore] = self.retriever.retrieve(query_str)
-        context_str = "\n\n".join([n.node.get_content() for n in nodes])
+        nodes_filtered = [node for node in nodes if node.score >= RETRIEVER_THRESHOLD]
+        context_str = "\n\n".join([n.node.get_content() for n in nodes_filtered])
         prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(prompt)

         # return final_response
-        return Response(response=str(response), source_nodes=nodes)
+        return Response(response=str(response), source_nodes=nodes_filtered)

     def _process_summary_query(self, query_str: str) -> Response:
         summary_nodes = self.summary_retriever.retrieve(query_str)
+        summary_nodes_filtered = [
+            node for node in summary_nodes if node.score >= RETRIEVER_THRESHOLD
+        ]
         utils = QdrantEngineUtils(
             metadata_date_key=self.metadata_date_key,
             metadata_date_format=self.metadata_date_format,
@@ -193,7 +198,7 @@ def _process_summary_query(self, query_str: str) -> Response:

         dates = [
             node.metadata[self.metadata_date_summary_key]
-            for node in summary_nodes
+            for node in summary_nodes_filtered
             if self.metadata_date_summary_key in node.metadata
         ]

@@ -208,8 +213,14 @@ def _process_summary_query(self, query_str: str) -> Response:
         )

         raw_nodes = retriever.retrieve(query_str)
-        context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
+        raw_nodes_filtered = [
+            node for node in raw_nodes if node.score >= RETRIEVER_THRESHOLD
+        ]
+
+        context_str = utils.combine_nodes_for_prompt(
+            summary_nodes_filtered, raw_nodes_filtered
+        )
         prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(prompt)

-        return Response(response=str(response), source_nodes=raw_nodes)
+        return Response(response=str(response), source_nodes=raw_nodes_filtered)

diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py
index 837e083..fa7a2b3 100644
--- a/utils/query_engine/level_based_platform_query_engine.py
+++ b/utils/query_engine/level_based_platform_query_engine.py
@@ -14,6 +14,7 @@
 from llama_index.core.retrievers import BaseRetriever
 from llama_index.core.schema import NodeWithScore
 from llama_index.llms.openai import OpenAI
+from utils.globals import RETRIEVER_THRESHOLD
 from utils.query_engine.base_pg_engine import BasePGEngine
 from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils

@@ -46,13 +47,18 @@ def custom_query(self, query_str: str):
         similar_nodes = retriever.query_db(
             query=query_str, filters=self._filters, date_interval=self._d
         )
+        similar_nodes_filtered = [
+            node for node in similar_nodes if node.score >= RETRIEVER_THRESHOLD
+        ]

-        context_str = self._prepare_context_str(similar_nodes, summary_nodes=None)
+        context_str = self._prepare_context_str(
+            similar_nodes_filtered, summary_nodes=None
+        )
         fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(fmt_qa_prompt)
         logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")

-        return Response(response=str(response), source_nodes=similar_nodes)
+        return Response(response=str(response), source_nodes=similar_nodes_filtered)

     @classmethod
     def prepare_platform_engine(

diff --git a/utils/query_engine/prepare_answer_sources.py b/utils/query_engine/prepare_answer_sources.py
index a0091ad..5e443cf 100644
--- a/utils/query_engine/prepare_answer_sources.py
+++ b/utils/query_engine/prepare_answer_sources.py
@@ -1,17 +1,20 @@
 import logging

 from llama_index.core.schema import NodeWithScore
+from utils.globals import REFERENCE_SCORE_THRESHOLD


 class PrepareAnswerSources:
-    def __init__(self, threshold: float = 0.5, max_refs_per_source: int = 3) -> None:
+    def __init__(
+        self, threshold: float = REFERENCE_SCORE_THRESHOLD, max_refs_per_source: int = 3
+    ) -> None:
         """
         Initialize the PrepareAnswerSources class.

         Parameters
         ----------
         threshold : float, optional
-            Minimum score threshold for including a node's URL, by default 0.5
+            Minimum score threshold for including a node's URL, by default 0.5 (set in the globals file)
         max_refs_per_source : int, optional
             Maximum number of references to include per data source, by default 3
         """

diff --git a/worker/tasks.py b/worker/tasks.py
index 346f542..47d5fd2 100644
--- a/worker/tasks.py
+++ b/worker/tasks.py
@@ -5,6 +5,7 @@
 from llama_index.core.schema import NodeWithScore
 from subquery import query_multiple_source
 from utils.data_source_selector import DataSourceSelector
+from utils.globals import QUERY_ERROR_MESSAGE
 from utils.query_engine.prepare_answer_sources import PrepareAnswerSources
 from utils.traceloop import init_tracing
 from worker.celery import app
@@ -21,7 +22,7 @@ def ask_question_auto_search(
         )
         answer_sources = PrepareAnswerSources().prepare_answer_sources(nodes=references)
     except Exception:
-        response = "Sorry, We cannot process your question at the moment."
+        response = QUERY_ERROR_MESSAGE
         answer_sources = None
         logging.error(
             f"Errors raised while processing the question for community: {community_id}!"
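Similarly, a minimal sketch of the early-return guard added to query_multiple_source. looks_meaningful and query_sources are hypothetical stand-ins for BasePreprocessor().extract_main_content(text=query) and the real function; only the INVALID_QUERY_RESPONSE text mirrors utils/globals.py.

    INVALID_QUERY_RESPONSE = (
        "We're unable to process your query. Please refine it and try again."
    )


    def looks_meaningful(query: str) -> str:
        # Hypothetical stand-in: strip whitespace and return whatever content
        # remains, so empty or noise-only queries come back falsy.
        return query.strip()


    def query_sources(query: str) -> tuple[str, list]:
        if not looks_meaningful(query):
            # Short-circuit before building tools, embedding, and calling the LLM.
            return INVALID_QUERY_RESPONSE, []
        # ... otherwise continue with retrieval and answer synthesis as before ...
        return "synthesized answer", []


    print(query_sources("   "))  # returns the invalid-query message and no sources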