Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Dec 28, 2023
1 parent 1f5507d commit 8ea1b2f
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 101 deletions.
12 changes: 8 additions & 4 deletions app/configuration/default.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Retriever parameters
API_SEMANTIC_RETRIEVER_PATH = "../models/user_guide_semantic_retrieval.joblib"
API_SEMANTIC_RETRIEVER_PATH = "../models/api_semantic_retrieval.joblib"
API_SEMANTIC_TOP_K = 5
API_LEXICAL_RETRIEVER_PATH = "../models/user_guide_lexical_retrieval.joblib"
API_LEXICAL_RETRIEVER_PATH = "../models/api_lexical_retrieval.joblib"
API_LEXICAL_TOP_K = 5
USER_GUIDE_SEMANTIC_RETRIEVER_PATH = "../models/user_guide_semantic_retrieval.joblib"
USER_GUIDE_SEMANTIC_TOP_K = 5
USER_GUIDE_LEXICAL_RETRIEVER_PATH = "../models/user_guide_lexical_retrieval.joblib"
USER_GUIDE_LEXICAL_TOP_K = 5
CROSS_ENCODER_PATH = "cross-encoder/ms-marco-MiniLM-L-6-v2"
CROSS_ENCODER_THRESHOLD = 2.0
CROSS_ENCODER_MIN_TOP_K = 5
CROSS_ENCODER_MAX_TOP_K = 10
CROSS_ENCODER_MIN_TOP_K = 3
CROSS_ENCODER_MAX_TOP_K = 20

# Device parameters
DEVICE = "mps"
Expand Down
24 changes: 15 additions & 9 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import resources as res
from schemas import WSMessage

from ragger_duck.prompt import APIPromptingStrategy
from ragger_duck.prompt import CombinePromptingStrategy
from ragger_duck.retrieval import RetrieverReranker

DEFAULT_PORT = 8123
Expand Down Expand Up @@ -48,15 +48,21 @@ async def startup_event():

api_semantic_retriever = joblib.load(conf.API_SEMANTIC_RETRIEVER_PATH)
api_lexical_retriever = joblib.load(conf.API_LEXICAL_RETRIEVER_PATH)
# Load the user-guide retrievers from their own artifacts; loading them from
# the API_* paths would only duplicate the API retrievers loaded just above.
user_guide_semantic_retriever = joblib.load(
conf.USER_GUIDE_SEMANTIC_RETRIEVER_PATH
)
user_guide_lexical_retriever = joblib.load(conf.USER_GUIDE_LEXICAL_RETRIEVER_PATH)
cross_encoder = CrossEncoder(model_name=conf.CROSS_ENCODER_PATH, device=conf.DEVICE)
api_retriever = RetrieverReranker(
retriever = RetrieverReranker(
retrievers=[
api_semantic_retriever.set_params(top_k=conf.API_SEMANTIC_TOP_K),
api_lexical_retriever.set_params(top_k=conf.API_LEXICAL_TOP_K),
user_guide_semantic_retriever.set_params(
top_k=conf.USER_GUIDE_SEMANTIC_TOP_K
),
user_guide_lexical_retriever.set_params(
top_k=conf.USER_GUIDE_LEXICAL_TOP_K
),
],
cross_encoder=cross_encoder,
semantic_retriever=api_semantic_retriever.set_params(
top_k=conf.API_SEMANTIC_TOP_K
),
lexical_retriever=api_lexical_retriever.set_params(
top_k=conf.API_LEXICAL_TOP_K
),
threshold=conf.CROSS_ENCODER_THRESHOLD,
min_top_k=conf.CROSS_ENCODER_MIN_TOP_K,
max_top_k=conf.CROSS_ENCODER_MAX_TOP_K,
Expand All @@ -69,7 +75,7 @@ async def startup_event():
n_threads=conf.N_THREADS,
n_ctx=conf.CONTEXT_TOKENS,
)
agent = APIPromptingStrategy(llm=llm, api_retriever=api_retriever)
agent = CombinePromptingStrategy(llm=llm, retriever=retriever)
logging.info("Server started")


Expand Down
3 changes: 2 additions & 1 deletion ragger_duck/prompt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._api import APIPromptingStrategy
from ._merge import CombinePromptingStrategy

__all__ = ["APIPromptingStrategy"]
__all__ = ["APIPromptingStrategy", "CombinePromptingStrategy"]
62 changes: 62 additions & 0 deletions ragger_duck/prompt/_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import logging

from sklearn.base import BaseEstimator
from sklearn.utils._param_validation import HasMethods

logger = logging.getLogger(__name__)


class CombinePromptingStrategy(BaseEstimator):
    """Prompting strategy answering a query from retrieved context.

    The strategy is the following:

    - The full user query is forwarded as-is to the retriever.
    - Each retrieved item is formatted as a ``source``/``content`` pair and
      the pairs are appended to an instruction prompt.
    - The prompt is sent to the language model, which is asked to answer the
      query using the provided context.

    Parameters
    ----------
    llm : llm instance
        The language model to use for the prompting. We expect the model to
        implement a `__call__` method that takes a prompt and returns a
        response. It should be an "Instruct"-based model since the prompt is
        wrapped in `[INST]` tags.

    retriever : retriever instance
        The retriever to use for retrieving the context. We expect the
        retriever to implement a `query` method that accepts a `query` keyword
        argument and returns an iterable of mappings exposing the `"source"`
        and `"text"` keys.
    """

    # Declarative constraints on the constructor parameters. Nothing in this
    # class triggers the validation itself.
    _parameter_constraints = {
        "llm": [HasMethods(["__call__"])],
        "retriever": [HasMethods(["query"])],
    }

    def __init__(self, *, llm, retriever):
        self.llm = llm
        self.retriever = retriever

    def __call__(self, query, **prompt_kwargs):
        """Answer `query` with the language model using retrieved context.

        Parameters
        ----------
        query : str
            The user query. It is used both to retrieve the context and inside
            the prompt.

        **prompt_kwargs : dict
            Additional keyword arguments forwarded to `llm` when it is called.

        Returns
        -------
        response : object
            Whatever `llm` returns when called with the prompt.

        sources : set
            The unique `"source"` values of the retrieved context items.
        """
        logger.info(f"Query: {query}")
        context = self.retriever.query(query=query)
        # Several retrieved chunks can come from the same page: deduplicate.
        sources = {info["source"] for info in context}
        context_query = "\n".join(
            f"source: {info['source']}\ncontent: {info['text']}\n" for info in context
        )

        prompt = (
            "[INST] You are a scikit-learn expert that should be able to answer "
            "machine-learning question.\n\n"
            "Answer to the query below using the additional provided content. "
            "The additional content is composed of the HTML link to the source and the "
            "extracted contextual information.\n\n"
            "Be succinct.\n\n"
            f"query: {query}\n\n"
            f"{context_query} [/INST]."
        )
        logger.info(f"The final prompt is:\n{prompt}")
        return self.llm(prompt, **prompt_kwargs), sources
54 changes: 13 additions & 41 deletions ragger_duck/retrieval/_reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
class RetrieverReranker(BaseEstimator):
"""Hybrid retriever (lexical and semantic) followed by a cross-encoder reranker.
We can accept several retrievers in case you want to rerank the results of
several retrievers.
Parameters
----------
semantic_retriever : semantic retriever or None
Semantic retriever used to retrieve the context.
lexical_retriever : lexical retriever or None
Lexical retriever used to retrieve the context.
retrievers : list of retriever instances
The retrievers to use for retrieving the context. We expect the retrievers to
implement a `query` method.
cross_encoder : :obj:`sentence_transformers.CrossEncoder`
Cross-encoder used to rerank the results of the hybrid retriever.
Expand All @@ -36,8 +37,7 @@ class RetrieverReranker(BaseEstimator):
"""

_parameter_constraints = {
"semantic_retriever": [HasMethods(["fit", "query"]), None],
"lexical_retriever": [HasMethods(["fit", "query"]), None],
"retrievers": [list],
"cross_encoder": [HasMethods(["predict"])],
"min_top_k": [Interval(Integral, left=0, right=None, closed="left"), None],
"max_top_k": [Interval(Integral, left=0, right=None, closed="left"), None],
Expand All @@ -48,16 +48,14 @@ class RetrieverReranker(BaseEstimator):
def __init__(
self,
*,
semantic_retriever,
lexical_retriever,
retrievers,
cross_encoder,
min_top_k=None,
max_top_k=None,
threshold=None,
drop_duplicates=True,
):
self.semantic_retriever = semantic_retriever
self.lexical_retriever = lexical_retriever
self.retrievers = retrievers
self.cross_encoder = cross_encoder
self.min_top_k = min_top_k
self.max_top_k = max_top_k
Expand Down Expand Up @@ -89,49 +87,23 @@ def _get_context(search):
return search["text"]
return search

def query(
self,
query,
*,
lexical_query=None,
semantic_query=None,
):
def query(self, query):
"""Retrieve the most relevant documents for the query.
Parameters
----------
query : str
The user query.
lexical_query : str, default=None
A specific query to retrieve the context of the lexical search. If None,
`query` is used.
semantic_query : str, default=None
A specific query to retrieve the context of the semantic search. If None,
`query` is used.
Returns
-------
list of str or dict
The list of the most relevant document from the training set.
"""
if lexical_query is None:
lexical_query = query
if semantic_query is None:
semantic_query = query

if self.lexical_retriever is not None:
lexical_search = self.lexical_retriever.query(lexical_query)
else:
lexical_search = []

if self.semantic_retriever is not None:
semantic_search = self.semantic_retriever.query(semantic_query)
else:
semantic_search = []
unranked_search = []
for retriever in self.retrievers:
unranked_search += retriever.query(query)

unranked_search = lexical_search + semantic_search
if not unranked_search:
return []

Expand Down
48 changes: 2 additions & 46 deletions ragger_duck/retrieval/tests/test_reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ def test_retriever_reranker(input_texts, params, n_documents):
model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(model_name=model_name)
retriever_reranker = RetrieverReranker(
retrievers=[bm25, faiss],
cross_encoder=cross_encoder,
semantic_retriever=faiss,
lexical_retriever=bm25,
drop_duplicates=False,
**params,
)
Expand All @@ -62,48 +61,6 @@ def test_retriever_reranker(input_texts, params, n_documents):
assert len(context) == n_documents


@pytest.mark.parametrize(
"search_strategy", ["lexical_retriever", "semantic_retriever", None]
)
def test_retriever_reranker_single_search(search_strategy):
"""Check that we can use a single search strategy."""
input_texts = ["xxx", "yyy", "zzz", "aaa"]

if search_strategy == "lexical_retriever":
bm25 = BM25Retriever(top_k=10).fit(input_texts)
faiss = None
elif search_strategy == "semantic_retriever":
bm25 = None
cache_folder_path = (
Path(__file__).parent.parent.parent / "embedding" / "tests" / "data"
)
model_name_or_path = "sentence-transformers/paraphrase-albert-small-v2"
embedder = SentenceTransformer(
model_name_or_path=model_name_or_path,
cache_folder=str(cache_folder_path),
show_progress_bar=False,
)
faiss = SemanticRetriever(embedding=embedder, top_k=10).fit(input_texts)
else:
bm25 = None
faiss = None

model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(model_name=model_name)
retriever_reranker = RetrieverReranker(
cross_encoder=cross_encoder,
semantic_retriever=faiss,
lexical_retriever=bm25,
max_top_k=2,
)
retriever_reranker.fit()
context = retriever_reranker.query("xxx")
if search_strategy is None:
assert not context
else:
assert len(context) == 2


@pytest.mark.parametrize(
"input_texts",
[
Expand Down Expand Up @@ -138,9 +95,8 @@ def test_retriever_reranker_drop_duplicate(
model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(model_name=model_name)
retriever_reranker = RetrieverReranker(
retrievers=[bm25, faiss],
cross_encoder=cross_encoder,
semantic_retriever=faiss,
lexical_retriever=bm25,
max_top_k=None, # we don't limit the number of retrieved documents
drop_duplicates=drop_duplicates,
)
Expand Down

0 comments on commit 8ea1b2f

Please sign in to comment.