Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/prompt update #32

Merged
merged 13 commits into from
Feb 12, 2024
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,5 @@ cython_debug/
hivemind-bot-env/*
main.ipynb
.DS_Store

temp_test_run_data.json
38 changes: 34 additions & 4 deletions bot/retrievers/forum_summary_retriever.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from bot.retrievers.summary_retriever_base import BaseSummarySearch
from llama_index.embeddings import BaseEmbedding
from llama_index.schema import NodeWithScore
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding


Expand Down Expand Up @@ -53,15 +54,44 @@ def retreive_filtering(
"""
nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k)

filters = self.define_filters(
nodes=nodes,
metadata_group1_key=metadata_group1_key,
metadata_group2_key=metadata_group2_key,
metadata_date_key=metadata_date_key,
)

return filters

def define_filters(
self,
nodes: list[NodeWithScore],
metadata_group1_key: str,
metadata_group2_key: str,
metadata_date_key: str,
) -> list[dict[str, str]]:
"""
define dictionary filters based on metadata of retrieved nodes

Parameters
----------
nodes : list[llama_index.schema.NodeWithScore]
a list of retrieved similar nodes on which the filters are based

Returns
---------
filters : list[dict[str, str]]
a list of filters combined with an `or` condition;
within each dictionary, an `and` operation is applied
between the keys and values of the json metadata_
"""
filters: list[dict[str, str]] = []

for node in nodes:
# the filter made by given node
filter: dict[str, str] = {}
if node.metadata[metadata_group1_key]:
filter[metadata_group1_key] = node.metadata[metadata_group1_key]
if node.metadata[metadata_group2_key]:
filter[metadata_group2_key] = node.metadata[metadata_group2_key]
filter[metadata_group1_key] = node.metadata[metadata_group1_key]
filter[metadata_group2_key] = node.metadata[metadata_group2_key]
# date filter
filter[metadata_date_key] = node.metadata[metadata_date_key]

Expand Down
4 changes: 3 additions & 1 deletion bot/retrievers/process_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def process_dates(dates: list[str], d: int) -> list[str]:
Returns
----------
dates_modified : list[str]
the input dates with the extra days added,
sorted in ascending order (the first element
is the earliest date and the last is the latest)
"""
dates_modified: list[str] = []
if dates != []:
Expand Down
86 changes: 71 additions & 15 deletions bot/retrievers/retrieve_similar_nodes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from datetime import datetime, timedelta

from dateutil import parser
from llama_index.embeddings import BaseEmbedding
from llama_index.schema import NodeWithScore
from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult
from llama_index.vector_stores.postgres import DBEmbeddingRow
from sqlalchemy import Date, and_, cast, or_, select, text
from sqlalchemy import Date, and_, cast, null, or_, select, text
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding


Expand All @@ -12,7 +15,7 @@ class RetrieveSimilarNodes:
def __init__(
self,
vector_store: PGVectorStore,
similarity_top_k: int,
similarity_top_k: int | None,
embed_model: BaseEmbedding = CohereEmbedding(),
) -> None:
"""Init params."""
Expand All @@ -21,7 +24,11 @@ def __init__(
self._similarity_top_k = similarity_top_k

def query_db(
self, query: str, filters: list[dict[str, str]] | None = None
self,
query: str,
filters: list[dict[str, str | dict | None]] | None = None,
date_interval: int = 0,
**kwargs
) -> list[NodeWithScore]:
"""
query database with given filters (similarity search is also done)
Expand All @@ -30,48 +37,97 @@ def query_db(
-------------
query : str
the user question
filters : list[dict[str, str]] | None
filters : list[dict[str, str | dict | None]] | None
a list of filters to apply with `or` condition
the dictionary would be applying `and`
operation between keys and values of json metadata_
if `None` then no filtering would be applied
the value can be a dictionary with one key of "ne" and a value
which means to do a not equal operator `!=`
if `None` then no filtering would be applied.
date_interval : int
the number of back and forth days of date
default is set to 0 meaning no days back or forward.
**kwargs
ignore_sort : bool
whether to skip sorting by vector similarity.
Note: this completely disables the similarity search and
just returns the results with no ordering.
default is `False`. If `True`, the query is ignored and
no embedding is computed for it.
"""
ignore_sort = kwargs.get("ignore_sort", False)
self._vector_store._initialize()
embedding = self._embed_model.get_text_embedding(text=query)

if not ignore_sort:
embedding = self._embed_model.get_text_embedding(text=query)
else:
embedding = None

stmt = select( # type: ignore
self._vector_store._table_class.id,
self._vector_store._table_class.node_id,
self._vector_store._table_class.text,
self._vector_store._table_class.metadata_,
self._vector_store._table_class.embedding.cosine_distance(embedding).label(
"distance"
),
).order_by(text("distance asc"))
(
self._vector_store._table_class.embedding.cosine_distance(embedding)
if not ignore_sort
else null()
).label("distance"),
)

if not ignore_sort:
stmt = stmt.order_by(text("distance asc"))

if filters is not None and filters != []:
conditions = []
for condition in filters:
filters_and = []
for key, value in condition.items():
if key == "date":
date: datetime
if isinstance(value, str):
date = parser.parse(value)
else:
raise ValueError(
"the values for filtering dates must be string!"
)
date_back = (date - timedelta(days=date_interval)).strftime(
"%Y-%m-%d"
)
date_forward = (date + timedelta(days=date_interval)).strftime(
"%Y-%m-%d"
)

# Apply ::date cast when the key is 'date'
filter_condition = cast(
filter_condition_back = cast(
self._vector_store._table_class.metadata_.op("->>")(key),
Date,
) >= cast(date_back, Date)

filter_condition_forward = cast(
self._vector_store._table_class.metadata_.op("->>")(key),
Date,
) == cast(value, Date)
) <= cast(date_forward, Date)

filters_and.append(filter_condition_back)
filters_and.append(filter_condition_forward)
else:
filter_condition = (
self._vector_store._table_class.metadata_.op("->>")(key)
== value
if not isinstance(value, dict)
else self._vector_store._table_class.metadata_.op("->>")(
key
)
!= value["ne"]
)

filters_and.append(filter_condition)
filters_and.append(filter_condition)

conditions.append(and_(*filters_and))

stmt = stmt.where(or_(*conditions))

stmt = stmt.limit(self._similarity_top_k)
if self._similarity_top_k is not None:
stmt = stmt.limit(self._similarity_top_k)

with self._vector_store._session() as session, session.begin():
res = session.execute(stmt)
Expand Down
64 changes: 56 additions & 8 deletions tests/unit/test_level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from unittest.mock import patch

from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever
from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
from llama_index.schema import NodeWithScore, TextNode
from sqlalchemy.exc import OperationalError
from utils.query_engine.level_based_platform_query_engine import (
LevelBasedPlatformQueryEngine,
)
Expand All @@ -26,9 +29,9 @@ def test_prepare_platform_engine(self):
"""
# the output should always have a `date` key for each dictionary
filters = [
{"channel": "general", "date": "2023-01-02"},
{"thread": "discussion", "date": "2024-01-03"},
{"date": "2022-01-01"},
{"channel": "general", "thread": "some_thread", "date": "2023-01-02"},
{"channel": "general", "thread": "discussion", "date": "2024-01-03"},
{"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"},
]

engine = LevelBasedPlatformQueryEngine.prepare_platform_engine(
Expand All @@ -39,26 +42,71 @@ def test_prepare_platform_engine(self):
)
self.assertIsNotNone(engine)

def test_prepare_engine_auto_filter(self):
def test_prepare_engine_auto_filter_raise_error(self):
"""
Test prepare_engine_auto_filter method with sample data
when an error was raised
"""
with patch.object(
ForumBasedSummaryRetriever, "retreive_filtering"
ForumBasedSummaryRetriever, "define_filters"
) as mock_retriever:
# the output should always have a `date` key for each dictionary
mock_retriever.return_value = [
{"channel": "general", "date": "2023-01-02"},
{"thread": "discussion", "date": "2024-01-03"},
{"date": "2022-01-01"},
{"channel": "general", "thread": "some_thread", "date": "2023-01-02"},
{"channel": "general", "thread": "discussion", "date": "2024-01-03"},
{"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"},
]

with self.assertRaises(OperationalError):
# no database with name of `test_community` is available
_ = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter(
community_id=self.community_id,
query="test query",
platform_table_name=self.platform_table_name,
level1_key=self.level1_key,
level2_key=self.level2_key,
date_key=self.date_key,
)

def test_prepare_engine_auto_filter(self):
"""
Test prepare_engine_auto_filter method with sample data in normal condition
"""
with patch.object(RetrieveSimilarNodes, "query_db") as mock_query:
# the output should always have a `date` key for each dictionary
mock_query.return_value = [
NodeWithScore(
node=TextNode(
text="some summaries #1",
metadata={
"thread": "thread#1",
"channel": "channel#1",
"date": "2022-01-01",
},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="some summaries #2",
metadata={
"thread": "thread#3",
"channel": "channel#2",
"date": "2022-01-02",
},
),
score=0,
),
]

# no database with name of `test_community` is available
engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter(
community_id=self.community_id,
query="test query",
platform_table_name=self.platform_table_name,
level1_key=self.level1_key,
level2_key=self.level2_key,
date_key=self.date_key,
include_summary_context=True,
)
self.assertIsNotNone(engine)
Loading
Loading