Commit ef0fcfe

add config for training
glemaitre committed Dec 15, 2023
1 parent 3fc3cde commit ef0fcfe
Showing 2 changed files with 12 additions and 7 deletions.
7 changes: 7 additions & 0 deletions scripts/configuration.py
@@ -0,0 +1,7 @@
DEVICE = "cuda"
# DEVICE = "mps"

API_DOC_PATH = (
"/Users/glemaitre/Documents/packages/scikit-learn/doc/_build/html/stable/"
"modules/generated"
)
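
For context, scripts/train_retriever.py consumes these settings as shown in the diff below; a minimal usage sketch, assuming only what that diff shows:

import configuration as config
from pathlib import Path

# Directory of the locally built scikit-learn API documentation (HTML pages).
API_DOC = Path(config.API_DOC_PATH)
# DEVICE selects where the embedding model runs ("cuda" by default, "mps" as the
# commented-out alternative), e.g. SentenceTransformer(..., device=config.DEVICE).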
12 changes: 5 additions & 7 deletions scripts/train_retriever.py
@@ -10,13 +10,11 @@
import sys
from pathlib import Path

import configuration as config
import joblib

sys.path.append(str(Path(__file__).parent.parent))
API_DOC = Path(
"/Users/glemaitre/Documents/packages/scikit-learn/doc/_build/html/stable/"
"modules/generated"
)
API_DOC = Path(config.API_DOC_PATH)

# %% [markdown]
# Define the training pipeline that extracts the text chunks from the API documentation
@@ -29,7 +27,9 @@
from rag_based_llm.retrieval import SemanticRetriever
from rag_based_llm.scraping import APIDocExtractor

embedding = SentenceTransformer(model_name_or_path="thenlper/gte-large", device="mps")
embedding = SentenceTransformer(
    model_name_or_path="thenlper/gte-large", device=config.DEVICE
)
pipeline = Pipeline(
    steps=[
        ("extractor", APIDocExtractor(chunk_size=700, chunk_overlap=50, n_jobs=-1)),
@@ -51,10 +51,8 @@

# %%
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

from rag_based_llm.retrieval import BM25Retriever
from rag_based_llm.scraping import APIDocExtractor

count_vectorizer = CountVectorizer(ngram_range=(1, 5))
pipeline = Pipeline(
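
The remaining lines of scripts/train_retriever.py are collapsed in this view. The sketch below is purely hypothetical and shows how the two retriever pipelines might be completed and persisted; the keyword arguments passed to SemanticRetriever and BM25Retriever, the argument passed to fit, and the joblib file names are assumptions, not taken from the commit:

import configuration as config
import joblib
from pathlib import Path

from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

from rag_based_llm.retrieval import BM25Retriever, SemanticRetriever
from rag_based_llm.scraping import APIDocExtractor

API_DOC = Path(config.API_DOC_PATH)

# Semantic retriever: embed the extracted documentation chunks with a
# sentence-transformer model running on the configured device.
embedding = SentenceTransformer(
    model_name_or_path="thenlper/gte-large", device=config.DEVICE
)
semantic_pipeline = Pipeline(
    steps=[
        ("extractor", APIDocExtractor(chunk_size=700, chunk_overlap=50, n_jobs=-1)),
        ("retriever", SemanticRetriever(embedding=embedding)),  # assumed parameter name
    ]
).fit(API_DOC)  # assumed fit input
joblib.dump(semantic_pipeline, "api_semantic_retrieval.joblib")  # assumed file name

# Lexical retriever: BM25 over 1- to 5-gram counts of the same chunks.
count_vectorizer = CountVectorizer(ngram_range=(1, 5))
lexical_pipeline = Pipeline(
    steps=[
        ("extractor", APIDocExtractor(chunk_size=700, chunk_overlap=50, n_jobs=-1)),
        ("retriever", BM25Retriever(count_vectorizer=count_vectorizer)),  # assumed parameter name
    ]
).fit(API_DOC)  # assumed fit input
joblib.dump(lexical_pipeline, "api_bm25_retrieval.joblib")  # assumed file name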
