From db8e5472e12a3b9a6297caed4f7e5639e5d6a39a Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 21 Jan 2025 11:09:52 +0330 Subject: [PATCH] fix: overriding the llama-index function to improve error handling! --- utils/query_engine/subquestion_engine.py | 37 +++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/utils/query_engine/subquestion_engine.py b/utils/query_engine/subquestion_engine.py index 58416df..cbcfb43 100644 --- a/utils/query_engine/subquestion_engine.py +++ b/utils/query_engine/subquestion_engine.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional, Sequence, cast import llama_index.core.instrumentation as instrument @@ -7,13 +8,14 @@ from llama_index.core.callbacks.schema import CBEventType, EventPayload from llama_index.core.instrumentation.events.query import QueryEndEvent, QueryStartEvent from llama_index.core.query_engine import SubQuestionAnswerPair, SubQuestionQueryEngine -from llama_index.core.question_gen.types import BaseQuestionGenerator +from llama_index.core.question_gen.types import BaseQuestionGenerator, SubQuestion from llama_index.core.response_synthesizers import BaseSynthesizer from llama_index.core.schema import NodeWithScore, QueryBundle from llama_index.core.tools.query_engine import QueryEngineTool from llama_index.core.utils import get_color_mapping, print_text dispatcher = instrument.get_dispatcher(__name__) +logger = logging.getLogger(__name__) class CustomSubQuestionQueryEngine(SubQuestionQueryEngine): @@ -95,3 +97,36 @@ def query( QueryEndEvent(query=str_or_query_bundle, response=query_result) ) return query_result, qa_pairs_all + + def _query_subq( + self, sub_q: SubQuestion, color: Optional[str] = None + ) -> Optional[SubQuestionAnswerPair]: + try: + with self.callback_manager.event( + CBEventType.SUB_QUESTION, + payload={EventPayload.SUB_QUESTION: SubQuestionAnswerPair(sub_q=sub_q)}, + ) as event: + question = sub_q.sub_question + query_engine = self._query_engines[sub_q.tool_name] + + if self._verbose: + print_text(f"[{sub_q.tool_name}] Q: {question}\n", color=color) + + response = query_engine.query(question) + response_text = str(response) + + if self._verbose: + print_text(f"[{sub_q.tool_name}] A: {response_text}\n", color=color) + + qa_pair = SubQuestionAnswerPair( + sub_q=sub_q, answer=response_text, sources=response.source_nodes + ) + + event.on_end(payload={EventPayload.SUB_QUESTION: qa_pair}) + + return qa_pair + except Exception as exp: + logger.warning( + f"[{sub_q.tool_name}] Failed to run {sub_q.sub_question}: {exp}" + ) + return None