Skip to content

Commit

Permalink
fix: avoid duplicate references!
Browse files Browse the repository at this point in the history
  • Loading branch information
amindadgar committed Jan 21, 2025
1 parent 6521f82 commit dbf0ddf
Showing 1 changed file with 39 additions and 37 deletions.
76 changes: 39 additions & 37 deletions utils/query_engine/prepare_answer_sources.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections import defaultdict

from llama_index.core.schema import NodeWithScore
from utils.globals import REFERENCE_SCORE_THRESHOLD
Expand All @@ -23,12 +24,8 @@ def __init__(

def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
"""
Prepares a formatted string containing source URLs organized by tool name from the provided nodes.
This method processes a list of nodes, filtering them based on a score threshold and
organizing the URLs by their associated tool names. It creates a formatted output with
URLs numbered under their respective tool sections, limiting the number of references
per data source.
Prepares a formatted string of source URLs from the provided nodes, organized by tool name.
Duplicate URLs within a tool are collapsed, keeping the highest-scoring node for each URL.
Parameters
----------
Expand All @@ -42,7 +39,7 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
Returns
-------
str
all_sources : str
A formatted string containing numbered URLs organized by tool name, with the format:
References:
{tool_name}:
Expand All @@ -53,55 +50,60 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
- The input nodes list is empty
- No nodes meet the score threshold
- No valid URLs are found in the nodes' metadata
Notes
-----
- URLs are only included if their node's score meets or exceeds the threshold
(default: 0.5)
- Each tool's sources are grouped together and prefixed with the tool name
- URLs are numbered sequentially within each tool's section
- Maximum number of references per data source is limited by max_refs_per_source
(default: 3)
- References are selected based on highest scores when limiting
- Logs error messages when no nodes are available or when all nodes are below
the threshold
"""
if len(nodes) == 0:
logging.error("No reference nodes available! returning empty string.")
return ""

cleaned_nodes = [n for n in nodes if n is not None]

# link of places that we got the answer from
all_sources: str = "References:\n"
# Group nodes by tool name while filtering by score and valid URL
tool_sources = defaultdict(list)
for tool_nodes in cleaned_nodes:
# platform name
tool_name = tool_nodes.sub_q.tool_name
for node in tool_nodes.sources:
if (
node.score >= self.threshold
and node.metadata.get("url") is not None
):
tool_sources[tool_name].append(node)

if not tool_sources:
logging.error(
f"All node scores are below threshold ({self.threshold}). Returning empty string!"
)
return ""

all_sources = "References:\n"

# Filter and sort nodes by score
valid_nodes = [
node
for node in tool_nodes.sources
if node.score >= self.threshold and node.metadata.get("url") is not None
]
valid_nodes.sort(key=lambda x: x.score, reverse=True)
# Process each tool's nodes, remove duplicate URLs, sort and limit references
for tool_name, nodes_list in tool_sources.items():
unique_nodes = {}
for node in nodes_list:
url = node.metadata.get("url")
# If URL not seen yet or current node has a higher score, update it
if url not in unique_nodes or node.score > unique_nodes[url].score:
unique_nodes[url] = node

# Limit the number of references
limited_nodes = valid_nodes[: self.max_refs_per_source]
# Sort nodes by score in descending order and limit references
sorted_nodes = sorted(
unique_nodes.values(), key=lambda x: x.score, reverse=True
)
limited_nodes = sorted_nodes[: self.max_refs_per_source]

if limited_nodes:
urls = [node.metadata.get("url") for node in limited_nodes]
sources: list[str] = [
f"[{idx + 1}] {url}" for idx, url in enumerate(urls)
sources = [
f"[{idx + 1}] {node.metadata['url']}"
for idx, node in enumerate(limited_nodes)
]
sources_combined: str = "\n".join(sources)
sources_combined = "\n".join(sources)
all_sources += f"{tool_name}:\n{sources_combined}\n\n"

if all_sources == "References:\n":
logging.error(
f"All node scores are below threshold. threshold: {self.threshold}"
". returning empty string!"
f"All node scores are below threshold ({self.threshold}). Returning empty string!"
)
return ""

# Remove trailing newlines
return all_sources.removesuffix("\n\n")

0 comments on commit dbf0ddf

Please sign in to comment.