From a107d6286e06b5b98db9605c2d2b304881dbea33 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 22 Dec 2023 23:00:22 +0100 Subject: [PATCH] Update retriever parameters and prompting strategy --- app/configuration/default.py | 6 ++++-- app/main.py | 4 ++-- ragger_duck/prompt/_api.py | 9 +++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/app/configuration/default.py b/app/configuration/default.py index 7d13e6c..f9c8a72 100644 --- a/app/configuration/default.py +++ b/app/configuration/default.py @@ -1,10 +1,12 @@ # Retriever parameters SEMANTIC_RETRIEVER_PATH = "../models/api_semantic_retrieval.joblib" +SEMANTIC_TOP_K = 5 LEXICAL_RETRIEVER_PATH = "../models/api_lexical_retrieval.joblib" +LEXICAL_TOP_K = 5 CROSS_ENCODER_PATH = "cross-encoder/ms-marco-MiniLM-L-6-v2" CROSS_ENCODER_THRESHOLD = 2.0 -CROSS_ENCODER_MIN_TOP_K = 1 -CROSS_ENCODER_MAX_TOP_K = 5 +CROSS_ENCODER_MIN_TOP_K = 5 +CROSS_ENCODER_MAX_TOP_K = 10 # Device parameters DEVICE = "mps" diff --git a/app/main.py b/app/main.py index e07cc7d..f03ddbb 100644 --- a/app/main.py +++ b/app/main.py @@ -51,8 +51,8 @@ async def startup_event(): cross_encoder = CrossEncoder(model_name=conf.CROSS_ENCODER_PATH, device=conf.DEVICE) api_retriever = RetrieverReranker( cross_encoder=cross_encoder, - semantic_retriever=api_semantic_retriever, - lexical_retriever=api_lexical_retriever, + semantic_retriever=api_semantic_retriever.set_params(top_k=conf.SEMANTIC_TOP_K), + lexical_retriever=api_lexical_retriever.set_params(top_k=conf.LEXICAL_TOP_K), threshold=conf.CROSS_ENCODER_THRESHOLD, min_top_k=conf.CROSS_ENCODER_MIN_TOP_K, max_top_k=conf.CROSS_ENCODER_MAX_TOP_K, diff --git a/ragger_duck/prompt/_api.py b/ragger_duck/prompt/_api.py index d75626b..33ea74c 100644 --- a/ragger_duck/prompt/_api.py +++ b/ragger_duck/prompt/_api.py @@ -46,12 +46,13 @@ def __call__(self, query, **prompt_kwargs): signature_retriever = inspect.signature(self.api_retriever.query) if "lexical_query" in signature_retriever.parameters: logger.info( - f"Retriever {self.api_retriever.__class__.__name} supports lexical" + f"Retriever {self.api_retriever.__class__.__name__} supports lexical" " queries" ) prompt = ( - "[INST] Summarize the query provided by extracting keywords from it. " - "Only list the keywords only separated by a comma. \n" + "[INST] Rephrase the query to have correct wording in a context of " + "machine-learning. Make sure to have the right spelling. Finally, only " + "provide a list of keywords separated by commas.\n\n" f"query: {query}[/INST]" ) @@ -81,7 +82,7 @@ def __call__(self, query, **prompt_kwargs): "machine-learning question.\n\n" "Answer to the query below using the additional provided content. " "The additional content is composed of the HTML link to the source and the " - "extracted text to be used.\n\n" + "extracted contextual information.\n\n" "Be succinct.\n\n" f"query: {query}\n\n" f"{context_query} [/INST]."