Skip to content

Commit

Permalink
Update retriever parameters and prompting strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Dec 22, 2023
1 parent 47b846e commit a107d62
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
6 changes: 4 additions & 2 deletions app/configuration/default.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Retriever parameters
SEMANTIC_RETRIEVER_PATH = "../models/api_semantic_retrieval.joblib"
SEMANTIC_TOP_K = 5
LEXICAL_RETRIEVER_PATH = "../models/api_lexical_retrieval.joblib"
LEXICAL_TOP_K = 5
CROSS_ENCODER_PATH = "cross-encoder/ms-marco-MiniLM-L-6-v2"
CROSS_ENCODER_THRESHOLD = 2.0
CROSS_ENCODER_MIN_TOP_K = 1
CROSS_ENCODER_MAX_TOP_K = 5
CROSS_ENCODER_MIN_TOP_K = 5
CROSS_ENCODER_MAX_TOP_K = 10

# Device parameters
DEVICE = "mps"
Expand Down
4 changes: 2 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ async def startup_event():
cross_encoder = CrossEncoder(model_name=conf.CROSS_ENCODER_PATH, device=conf.DEVICE)
api_retriever = RetrieverReranker(
cross_encoder=cross_encoder,
semantic_retriever=api_semantic_retriever,
lexical_retriever=api_lexical_retriever,
semantic_retriever=api_semantic_retriever.set_params(top_k=conf.SEMANTIC_TOP_K),
lexical_retriever=api_lexical_retriever.set_params(top_k=conf.LEXICAL_TOP_K),
threshold=conf.CROSS_ENCODER_THRESHOLD,
min_top_k=conf.CROSS_ENCODER_MIN_TOP_K,
max_top_k=conf.CROSS_ENCODER_MAX_TOP_K,
Expand Down
9 changes: 5 additions & 4 deletions ragger_duck/prompt/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ def __call__(self, query, **prompt_kwargs):
signature_retriever = inspect.signature(self.api_retriever.query)
if "lexical_query" in signature_retriever.parameters:
logger.info(
f"Retriever {self.api_retriever.__class__.__name} supports lexical"
f"Retriever {self.api_retriever.__class__.__name__} supports lexical"
" queries"
)
prompt = (
"[INST] Summarize the query provided by extracting keywords from it. "
"Only list the keywords only separated by a comma. \n"
"[INST] Rephrase the query to have correct wording in a context of "
"machine-learning. Make sure to have the right spelling. Finally, only "
"provide a list of keywords separated by commas.\n\n"
f"query: {query}[/INST]"
)

Expand Down Expand Up @@ -81,7 +82,7 @@ def __call__(self, query, **prompt_kwargs):
"machine-learning question.\n\n"
"Answer to the query below using the additional provided content. "
"The additional content is composed of the HTML link to the source and the "
"extracted text to be used.\n\n"
"extracted contextual information.\n\n"
"Be succinct.\n\n"
f"query: {query}\n\n"
f"{context_query} [/INST]."
Expand Down

0 comments on commit a107d62

Please sign in to comment.