Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid duplicate references! #122

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 39 additions & 37 deletions utils/query_engine/prepare_answer_sources.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections import defaultdict

from llama_index.core.schema import NodeWithScore
from utils.globals import REFERENCE_SCORE_THRESHOLD
Expand All @@ -23,12 +24,8 @@ def __init__(

def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
"""
Prepares a formatted string containing source URLs organized by tool name from the provided nodes.
This method processes a list of nodes, filtering them based on a score threshold and
organizing the URLs by their associated tool names. It creates a formatted output with
URLs numbered under their respective tool sections, limiting the number of references
per data source.
Prepares a formatted string containing unique source URLs organized by tool name
from the provided nodes, avoiding duplicate URLs.
Parameters
----------
Expand All @@ -42,7 +39,7 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
Returns
-------
str
all_sources : str
A formatted string containing numbered URLs organized by tool name, with the format:
References:
{tool_name}:
Expand All @@ -53,55 +50,60 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore | None]) -> str:
- The input nodes list is empty
- No nodes meet the score threshold
- No valid URLs are found in the nodes' metadata
Notes
-----
- URLs are only included if their node's score meets or exceeds the threshold
(default: 0.5)
- Each tool's sources are grouped together and prefixed with the tool name
- URLs are numbered sequentially within each tool's section
- Maximum number of references per data source is limited by max_refs_per_source
(default: 3)
- References are selected based on highest scores when limiting
- Logs error messages when no nodes are available or when all nodes are below
the threshold
"""
if len(nodes) == 0:
logging.error("No reference nodes available! returning empty string.")
return ""

cleaned_nodes = [n for n in nodes if n is not None]

# link of places that we got the answer from
all_sources: str = "References:\n"
# Group nodes by tool name while filtering by score and valid URL
tool_sources = defaultdict(list)
for tool_nodes in cleaned_nodes:
# platform name
tool_name = tool_nodes.sub_q.tool_name
for node in tool_nodes.sources:
if (
node.score >= self.threshold
and node.metadata.get("url") is not None
):
tool_sources[tool_name].append(node)

if not tool_sources:
logging.error(
f"All node scores are below threshold ({self.threshold}). Returning empty string!"
)
return ""

all_sources = "References:\n"

# Filter and sort nodes by score
valid_nodes = [
node
for node in tool_nodes.sources
if node.score >= self.threshold and node.metadata.get("url") is not None
]
valid_nodes.sort(key=lambda x: x.score, reverse=True)
# Process each tool's nodes, remove duplicate URLs, sort and limit references
for tool_name, nodes_list in tool_sources.items():
unique_nodes = {}
for node in nodes_list:
url = node.metadata.get("url")
# If URL not seen yet or current node has a higher score, update it
if url not in unique_nodes or node.score > unique_nodes[url].score:
unique_nodes[url] = node

# Limit the number of references
limited_nodes = valid_nodes[: self.max_refs_per_source]
# Sort nodes by score in descending order and limit references
sorted_nodes = sorted(
unique_nodes.values(), key=lambda x: x.score, reverse=True
)
limited_nodes = sorted_nodes[: self.max_refs_per_source]

if limited_nodes:
urls = [node.metadata.get("url") for node in limited_nodes]
sources: list[str] = [
f"[{idx + 1}] {url}" for idx, url in enumerate(urls)
sources = [
f"[{idx + 1}] {node.metadata['url']}"
for idx, node in enumerate(limited_nodes)
]
sources_combined: str = "\n".join(sources)
sources_combined = "\n".join(sources)
all_sources += f"{tool_name}:\n{sources_combined}\n\n"

if all_sources == "References:\n":
logging.error(
f"All node scores are below threshold. threshold: {self.threshold}"
". returning empty string!"
f"All node scores are below threshold ({self.threshold}). Returning empty string!"
)
return ""

# Remove trailing newlines
return all_sources.removesuffix("\n\n")
Loading