From dbf0ddf81d6a7cdadde190147e7289e1f68418f0 Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Tue, 21 Jan 2025 10:39:56 +0330
Subject: [PATCH] fix: avoid duplicate references!

---
Reviewer notes (free text after the three-dash line; not part of the commit):

* Purpose: prepare_answer_sources() now groups the qualifying source nodes
  by tool name first, then keeps one node per URL inside each group before
  sorting by score and slicing to max_refs_per_source.
* De-duplication is keyed on the URL within a single tool_name group only;
  the same URL appearing under two different tool names is still listed
  once under each of them.
* When a URL repeats, the node with the strictly higher score replaces the
  stored one; on equal scores the first node seen is kept.
* The new early return (`if not tool_sources`) also fires when nodes pass
  the score threshold but carry no "url" in their metadata, although its
  log message only mentions the threshold.
* After that early return every group holds at least one node, so the
  remaining `all_sources == "References:\n"` check can only trigger when
  slicing by max_refs_per_source leaves a group empty (for example when
  max_refs_per_source is 0).
* NOTE(review): node.score is compared with `>=` and `>` directly; this
  assumes the score is never None -- confirm against NodeWithScore.

 utils/query_engine/prepare_answer_sources.py | 76 ++++++++++----------
 1 file changed, 39 insertions(+), 37 deletions(-)

diff --git a/utils/query_engine/prepare_answer_sources.py b/utils/query_engine/prepare_answer_sources.py
index 85703b3..74cd670 100644
--- a/utils/query_engine/prepare_answer_sources.py
+++ b/utils/query_engine/prepare_answer_sources.py
@@ -1,4 +1,5 @@
 import logging
+from collections import defaultdict
 
 from llama_index.core.schema import NodeWithScore
 from utils.globals import REFERENCE_SCORE_THRESHOLD
@@ -23,12 +24,8 @@ def __init__(
 
     def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
         """
-        Prepares a formatted string containing source URLs organized by tool name from the provided nodes.
-
-        This method processes a list of nodes, filtering them based on a score threshold and
-        organizing the URLs by their associated tool names. It creates a formatted output with
-        URLs numbered under their respective tool sections, limiting the number of references
-        per data source.
+        Prepares a formatted string containing unique source URLs organized by tool name
+        from the provided nodes, avoiding duplicate URLs.
 
         Parameters
         ----------
@@ -42,7 +39,7 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
 
         Returns
         -------
-        str
+        all_sources : str
             A formatted string containing numbered URLs organized by tool name, with the format:
             References:
             {tool_name}:
@@ -53,18 +50,6 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
             - The input nodes list is empty
             - No nodes meet the score threshold
             - No valid URLs are found in the nodes' metadata
-
-        Notes
-        -----
-        - URLs are only included if their node's score meets or exceeds the threshold
-          (default: 0.5)
-        - Each tool's sources are grouped together and prefixed with the tool name
-        - URLs are numbered sequentially within each tool's section
-        - Maximum number of references per data source is limited by max_refs_per_source
-          (default: 3)
-        - References are selected based on highest scores when limiting
-        - Logs error messages when no nodes are available or when all nodes are below
-          the threshold
         """
         if len(nodes) == 0:
             logging.error("No reference nodes available! returning empty string.")
@@ -72,36 +57,53 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
 
         cleaned_nodes = [n for n in nodes if n is not None]
 
-        # link of places that we got the answer from
-        all_sources: str = "References:\n"
+        # Group nodes by tool name while filtering by score and valid URL
+        tool_sources = defaultdict(list)
         for tool_nodes in cleaned_nodes:
-            # platform name
             tool_name = tool_nodes.sub_q.tool_name
+            for node in tool_nodes.sources:
+                if (
+                    node.score >= self.threshold
+                    and node.metadata.get("url") is not None
+                ):
+                    tool_sources[tool_name].append(node)
+
+        if not tool_sources:
+            logging.error(
+                f"All node scores are below threshold ({self.threshold}). Returning empty string!"
+            )
+            return ""
+
+        all_sources = "References:\n"
 
-            # Filter and sort nodes by score
-            valid_nodes = [
-                node
-                for node in tool_nodes.sources
-                if node.score >= self.threshold and node.metadata.get("url") is not None
-            ]
-            valid_nodes.sort(key=lambda x: x.score, reverse=True)
+        # Process each tool's nodes, remove duplicate URLs, sort and limit references
+        for tool_name, nodes_list in tool_sources.items():
+            unique_nodes = {}
+            for node in nodes_list:
+                url = node.metadata.get("url")
+                # If URL not seen yet or current node has a higher score, update it
+                if url not in unique_nodes or node.score > unique_nodes[url].score:
+                    unique_nodes[url] = node
 
-            # Limit the number of references
-            limited_nodes = valid_nodes[: self.max_refs_per_source]
+            # Sort nodes by score in descending order and limit references
+            sorted_nodes = sorted(
+                unique_nodes.values(), key=lambda x: x.score, reverse=True
+            )
+            limited_nodes = sorted_nodes[: self.max_refs_per_source]
 
             if limited_nodes:
-                urls = [node.metadata.get("url") for node in limited_nodes]
-                sources: list[str] = [
-                    f"[{idx + 1}] {url}" for idx, url in enumerate(urls)
+                sources = [
+                    f"[{idx + 1}] {node.metadata['url']}"
+                    for idx, node in enumerate(limited_nodes)
                 ]
-                sources_combined: str = "\n".join(sources)
+                sources_combined = "\n".join(sources)
                 all_sources += f"{tool_name}:\n{sources_combined}\n\n"
 
         if all_sources == "References:\n":
             logging.error(
-                f"All node scores are below threshold. threshold: {self.threshold}"
-                ". returning empty string!"
+                f"All node scores are below threshold ({self.threshold}). Returning empty string!"
             )
             return ""
 
+        # Remove trailing newlines
         return all_sources.removesuffix("\n\n")