From ef0fcfece934e3d0e29a1ebc182783409524880f Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Fri, 15 Dec 2023 21:38:43 +0100
Subject: [PATCH] add config for training

---
 scripts/configuration.py   |  7 +++++++
 scripts/train_retriever.py | 12 +++++-------
 2 files changed, 12 insertions(+), 7 deletions(-)
 create mode 100644 scripts/configuration.py

diff --git a/scripts/configuration.py b/scripts/configuration.py
new file mode 100644
index 0000000..5ca1e09
--- /dev/null
+++ b/scripts/configuration.py
@@ -0,0 +1,7 @@
+DEVICE = "cuda"
+# DEVICE = "mps"
+
+API_DOC_PATH = (
+    "/Users/glemaitre/Documents/packages/scikit-learn/doc/_build/html/stable/"
+    "modules/generated"
+)
diff --git a/scripts/train_retriever.py b/scripts/train_retriever.py
index 4bb6fe5..fe9c30b 100644
--- a/scripts/train_retriever.py
+++ b/scripts/train_retriever.py
@@ -10,13 +10,11 @@
 import sys
 from pathlib import Path
 
+import configuration as config
 import joblib
 
 sys.path.append(str(Path(__file__).parent.parent))
 
-API_DOC = Path(
-    "/Users/glemaitre/Documents/packages/scikit-learn/doc/_build/html/stable/"
-    "modules/generated"
-)
+API_DOC = Path(config.API_DOC_PATH)
 # %% [markdown]
 # Define the training pipeline that extract the text chunks from the API documentation
@@ -29,7 +27,9 @@
 from rag_based_llm.retrieval import SemanticRetriever
 from rag_based_llm.scraping import APIDocExtractor
 
-embedding = SentenceTransformer(model_name_or_path="thenlper/gte-large", device="mps")
+embedding = SentenceTransformer(
+    model_name_or_path="thenlper/gte-large", device=config.DEVICE
+)
 pipeline = Pipeline(
     steps=[
         ("extractor", APIDocExtractor(chunk_size=700, chunk_overlap=50, n_jobs=-1)),
@@ -51,10 +51,8 @@
 
 # %%
 from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.pipeline import Pipeline
 
 from rag_based_llm.retrieval import BM25Retriever
-from rag_based_llm.scraping import APIDocExtractor
 
 count_vectorizer = CountVectorizer(ngram_range=(1, 5))
 pipeline = Pipeline(
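
Note: the new scripts/configuration.py hard-codes DEVICE = "cuda", with "mps" left as a comment to toggle by hand. A possible variant, not part of this patch, would be to detect the available backend at import time so the same file works on CUDA, Apple Silicon, or CPU-only machines. A minimal sketch, assuming torch is already installed (it is pulled in by sentence-transformers), with _detect_device being a hypothetical helper name:

# Hypothetical variant of scripts/configuration.py, for illustration only;
# it is not the code added by this patch.
import torch

def _detect_device() -> str:
    # Prefer CUDA, fall back to Apple's MPS backend, and finally to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

DEVICE = _detect_device()

API_DOC_PATH = (
    "/Users/glemaitre/Documents/packages/scikit-learn/doc/_build/html/stable/"
    "modules/generated"
)

With such a variant, train_retriever.py would keep importing config.DEVICE unchanged and simply pick up whichever backend is present on the machine.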