From d888f7d33909db0a4c1a2c71a31f7c891840373b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 28 Dec 2023 11:25:51 +0100 Subject: [PATCH] iter --- app/configuration/default.py | 4 +- doc/references/scraping.rst | 2 +- doc/user_guide.rst | 18 ++- ragger_duck/prompt/_api.py | 7 +- ragger_duck/scraping/__init__.py | 14 +- ragger_duck/scraping/_api_doc.py | 169 --------------------- ragger_duck/scraping/_shared.py | 43 ------ ragger_duck/scraping/_user_guide.py | 57 +++---- ragger_duck/scraping/tests/test_api_doc.py | 84 +--------- ragger_duck/scraping/tests/test_shared.py | 46 ++---- scripts/train_retrievers.py | 3 +- 11 files changed, 56 insertions(+), 391 deletions(-) diff --git a/app/configuration/default.py b/app/configuration/default.py index f9c8a72..ae6e100 100644 --- a/app/configuration/default.py +++ b/app/configuration/default.py @@ -1,7 +1,7 @@ # Retriever parameters -SEMANTIC_RETRIEVER_PATH = "../models/api_semantic_retrieval.joblib" +SEMANTIC_RETRIEVER_PATH = "../models/user_guide_semantic_retrieval.joblib" SEMANTIC_TOP_K = 5 -LEXICAL_RETRIEVER_PATH = "../models/api_lexical_retrieval.joblib" +LEXICAL_RETRIEVER_PATH = "../models/user_guide_lexical_retrieval.joblib" LEXICAL_TOP_K = 5 CROSS_ENCODER_PATH = "cross-encoder/ms-marco-MiniLM-L-6-v2" CROSS_ENCODER_THRESHOLD = 2.0 diff --git a/doc/references/scraping.rst b/doc/references/scraping.rst index c396832..abb9147 100644 --- a/doc/references/scraping.rst +++ b/doc/references/scraping.rst @@ -13,5 +13,5 @@ Scraping the documentation :toctree: generated/ :template: class.rst - APIDocExtractor APINumPyDocExtractor + UserGuideDocExtractor diff --git a/doc/user_guide.rst b/doc/user_guide.rst index c495456..3d1b647 100644 --- a/doc/user_guide.rst +++ b/doc/user_guide.rst @@ -12,10 +12,8 @@ Scraping The scraping module provides some simple estimator that extract meaningful documentation from the documentation website. -:class:`~ragger_duck.scraping.APIDocExtractor` is a scraper that loads the -HTML pages and extract the documentation from the main API section. One can -provide a `chunk_size` and `chunk_overlap` to further split the documentation -sections into smaller chunks. +API documentation +----------------- :class:`~ragger_duck.scraping.APINumPyDocExtractor` is a more advanced scraper that uses `numpydoc` and it scraper to extract the documentation. Indeed, the @@ -24,10 +22,14 @@ chunks of documentation from the parsed sections. While, we don't control for the chunk size, the chunks are build such that they contain information only of a specific parameter and always refer to the class or function. We hope that scraping in such way can remove ambiguity that could exist when building chunks -without any control. Since we rely on `numpydoc` for the parsing that expect -a specific formatting, then -:class:`~ragger_duck.scraping.APINumPyDocExtractor` is much faster than -:class:`~ragger_duck.scraping.APIDocExtractor`. +without any control. + +User Guide documentation +------------------------ + +:class:`~ragger_duck.scraping.UserGuideExtractor` is a scraper that extract +documentation from the user guide. It is a simple scraper that extract +text information from the webpage. Additionally, this text can be chunked. Retriever ========= diff --git a/ragger_duck/prompt/_api.py b/ragger_duck/prompt/_api.py index 33ea74c..be06247 100644 --- a/ragger_duck/prompt/_api.py +++ b/ragger_duck/prompt/_api.py @@ -50,16 +50,15 @@ def __call__(self, query, **prompt_kwargs): " queries" ) prompt = ( - "[INST] Rephrase the query to have correct wording in a context of " - "machine-learning. Make sure to have the right spelling. Finally, only " - "provide a list of keywords separated by commas.\n\n" + "[INST] Extract a list of keywords from the query below for a context" + " of machine-learning using scikit-learn.\n\n" f"query: {query}[/INST]" ) # do not create a stream generator local_prompt_kwargs = prompt_kwargs.copy() local_prompt_kwargs["stream"] = False - logger.info("Prompting to get keywords from the query") + logger.info(f"Prompting to get keywords from the query:\n{prompt}") response = self.llm(prompt, **local_prompt_kwargs) keywords = response["choices"][0]["text"].strip() logger.info(f"Keywords: {keywords}") diff --git a/ragger_duck/scraping/__init__.py b/ragger_duck/scraping/__init__.py index 0383f42..a94c444 100644 --- a/ragger_duck/scraping/__init__.py +++ b/ragger_duck/scraping/__init__.py @@ -3,20 +3,10 @@ website of scikit-learn. """ -from ._api_doc import ( - APIDocExtractor, - APINumPyDocExtractor, - extract_api_doc, - extract_api_doc_from_single_file, -) -from ._user_guide import ( - UserGuideDocExtractor, -) +from ._api_doc import APINumPyDocExtractor +from ._user_guide import UserGuideDocExtractor __all__ = [ - "extract_api_doc", - "extract_api_doc_from_single_file", - "APIDocExtractor", "APINumPyDocExtractor", "UserGuideDocExtractor", ] diff --git a/ragger_duck/scraping/_api_doc.py b/ragger_duck/scraping/_api_doc.py index e386a6f..d68c027 100644 --- a/ragger_duck/scraping/_api_doc.py +++ b/ragger_duck/scraping/_api_doc.py @@ -4,18 +4,9 @@ import inspect import re import warnings -from itertools import chain -from numbers import Integral -from pathlib import Path -from bs4 import BeautifulSoup -from joblib import Parallel, delayed -from langchain.text_splitter import RecursiveCharacterTextSplitter from numpydoc.docscrape import NumpyDocString from sklearn.base import BaseEstimator, TransformerMixin -from sklearn.utils._param_validation import Interval - -from ._shared import _chunk_document, _extract_text_from_section SKLEARN_API_URL = "https://scikit-learn.org/stable/modules/generated/" @@ -36,68 +27,6 @@ def _api_path_to_api_url(path): return SKLEARN_API_URL + path.name -def extract_api_doc_from_single_file(api_html_file): - """Extract the text from the API documentation. - - This function can process classes and functions. - - Parameters - ---------- - api_html_file : :class:`pathlib.Path` - The path to the HTML API documentation. - - Returns - ------- - str - The text extracted from the API documentation. - """ - if not isinstance(api_html_file, Path): - raise ValueError( - f"The API HTML file should be a pathlib.Path object. Got {api_html_file!r}." - ) - if api_html_file.suffix != ".html": - raise ValueError( - f"The file {api_html_file} is not an HTML file. Please provide an HTML " - "file." - ) - with open(api_html_file, "r") as file: - soup = BeautifulSoup(file, "html.parser") - api_section = soup.section - return { - "source": _api_path_to_api_url(api_html_file), - "text": _extract_text_from_section(api_section), - } - - -def extract_api_doc(api_doc_folder, *, n_jobs=None): - """Extract text from each HTML API file from a folder - - Parameters - ---------- - api_doc_folder : :class:`pathlib.Path` - The path to the API documentation folder. - - n_jobs : int, default=None - The number of jobs to run in parallel. If None, then the number of jobs is set - to the number of CPU cores. - - Returns - ------- - list - A list of dictionaries containing the source and text of the API - documentation. - """ - if not isinstance(api_doc_folder, Path): - raise ValueError( - "The API documentation folder should be a pathlib.Path object. Got " - f"{api_doc_folder!r}." - ) - return Parallel(n_jobs=n_jobs)( - delayed(extract_api_doc_from_single_file)(api_html_file) - for api_html_file in api_doc_folder.glob("*.html") - ) - - def _extract_function_doc_numpydoc(function, import_name, html_source): """Extract documentation from a function using `numpydoc`. @@ -240,104 +169,6 @@ def _extract_function_doc_numpydoc(function, import_name, html_source): return extracted_doc -class APIDocExtractor(BaseEstimator, TransformerMixin): - """Extract text from the API documentation. - - This function can process classes and functions. - - Parameters - ---------- - chunk_size : int or None, default=300 - The size of the chunks to split the text into. If None, the text is not chunked. - - chunk_overlap : int, default=50 - The overlap between two consecutive chunks. - - n_jobs : int, default=None - The number of jobs to run in parallel. If None, then the number of jobs is set - to the number of CPU cores. - - Attributes - ---------- - text_splitter_ : :class:`langchain.text_splitter.RecursiveCharacterTextSplitter` - The text splitter to use to chunk the document. If `chunk_size` is None, this - attribute is None. - """ - - _parameter_constraints = { - "chunk_size": [Interval(Integral, left=1, right=None, closed="left"), None], - "chunk_overlap": [Interval(Integral, left=0, right=None, closed="left")], - "n_jobs": [Integral, None], - } - - def __init__(self, *, chunk_size=300, chunk_overlap=50, n_jobs=None): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - self.n_jobs = n_jobs - - def fit(self, X=None, y=None): - """No-op operation, only validate parameters. - - Parameters - ---------- - X : None - This parameter is ignored. - - y : None - This parameter is ignored. - - Returns - ------- - self - The fitted estimator. - """ - self._validate_params() - if self.chunk_size is not None: - self.text_splitter_ = RecursiveCharacterTextSplitter( - separators=["\n\n", "\n", " ", ""], - chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, - length_function=len, - ) - else: - self.text_splitter_ = None - return self - - def transform(self, X): - """Extract text from the API documentation. - - Parameters - ---------- - X : :class:`pathlib.Path` - The path to the API documentation folder. - - Returns - ------- - output : list - A list of dictionaries containing the source and text of the API - documentation. - """ - if self.chunk_size is None: - output = extract_api_doc(X, n_jobs=self.n_jobs) - else: - output = list( - chain.from_iterable( - Parallel(n_jobs=self.n_jobs, return_as="generator")( - delayed(_chunk_document)(self.text_splitter_, document) - for document in extract_api_doc(X, n_jobs=self.n_jobs) - ) - ) - ) - if not output: - raise ValueError( - "No API documentation was extracted. Please check the input folder." - ) - return output - - def _more_tags(self): - return {"X_types": ["string"], "stateless": True} - - class APINumPyDocExtractor(BaseEstimator, TransformerMixin): """Extract text from the API documentation using `numpydoc`. diff --git a/ragger_duck/scraping/_shared.py b/ragger_duck/scraping/_shared.py index 63f30b6..f84725f 100644 --- a/ragger_duck/scraping/_shared.py +++ b/ragger_duck/scraping/_shared.py @@ -1,46 +1,3 @@ -import re - -from bs4 import NavigableString - - -def _extract_text_from_section(section): - """Extract the text from an HTML section. - - Parameters - ---------- - section : :class:`bs4.element.Tag` - The HTML section from which to extract the text. - - Returns - ------- - str or None - The text extracted from the section. Return None if the section is a - :class:`bs4.NavigableString`. - - Notes - ----- - This function was copied from: - https://github.com/ray-project/llm-applications/blob/main/rag/data.py - (under CC BY 4.0 license) - """ - if isinstance(section, NavigableString): - return None - texts = [] - for elem in section.children: - if isinstance(elem, NavigableString): - text = elem.strip() - else: - text = elem.get_text(" ") - # Remove line breaks within a paragraph - newline = re.compile(r"\n+") - text = newline.sub(" ", text) - # Remove the duplicated spaces on the fly - multiple_spaces = re.compile(r"\s+") - text = multiple_spaces.sub(" ", text) - texts.append(text) - return " ".join(texts).replace("ΒΆ", "\n") - - def _chunk_document(text_splitter, document): """Chunk a document into smaller pieces. diff --git a/ragger_duck/scraping/_user_guide.py b/ragger_duck/scraping/_user_guide.py index a02eeca..3a652cb 100644 --- a/ragger_duck/scraping/_user_guide.py +++ b/ragger_duck/scraping/_user_guide.py @@ -1,5 +1,6 @@ """Utilities to scrape User Guide documentation.""" import logging +import re from itertools import chain from numbers import Integral from pathlib import Path @@ -10,7 +11,7 @@ from sklearn.base import BaseEstimator, TransformerMixin from sklearn.utils._param_validation import Interval -from ._shared import _chunk_document, _extract_text_from_section +from ._shared import _chunk_document SKLEARN_USER_GUIDE_URL = "https://scikit-learn.org/stable/modules/" loogger = logging.getLogger(__name__) @@ -44,10 +45,8 @@ def extract_user_guide_doc_from_single_file(html_file): Returns ------- - list of dict - Extract all sections from the HTML file and store it in a list of - dictionaries containing the source and text of the User Guide. If there - is no section, an empty list is returned. + dict + A dictionary containing the source and text of the User Guide documentation. """ if not isinstance(html_file, Path): raise ValueError( @@ -61,19 +60,21 @@ def extract_user_guide_doc_from_single_file(html_file): with open(html_file, "r") as file: soup = BeautifulSoup(file, "html.parser") - all_sections = soup.find_all("section") - if all_sections is None: - return [] - return [ - { - "source": _user_guide_path_to_user_guide_url(html_file), - "text": _extract_text_from_section(section), - } - for section in all_sections - ] + text = soup.get_text("") + # Remove line breaks within a paragraph + newline = re.compile(r"\n\s*") + text = newline.sub(r"\n", text) + # Remove the duplicated spaces on the fly + multiple_spaces = re.compile(" +") + text = multiple_spaces.sub(" ", text) + + return { + "source": _user_guide_path_to_user_guide_url(html_file), + "text": text, + } -def _extract_user_guide_doc(user_guide_doc_folder, *, n_jobs=None): +def _extract_user_guide_doc(user_guide_doc_folder): """Extract text from each HTML User Guide files from a folder Parameters @@ -81,13 +82,9 @@ def _extract_user_guide_doc(user_guide_doc_folder, *, n_jobs=None): user_guide_doc_folder : :class:`pathlib.Path` The path to the User Guide documentation folder. - n_jobs : int, default=None - The number of jobs to run in parallel. If None, then the number of jobs is set - to the number of CPU cores. - Returns ------- - list + list of dict A list of dictionaries containing the source and text of the API documentation. """ @@ -96,16 +93,10 @@ def _extract_user_guide_doc(user_guide_doc_folder, *, n_jobs=None): "The User Guide documentation folder should be a pathlib.Path object. Got " f"{user_guide_doc_folder!r}." ) - output = [] - for html_file in user_guide_doc_folder.glob("*.html"): - texts = extract_user_guide_doc_from_single_file(html_file) - if texts: - loogger.info(f"Extracted {len(texts)} sections from {html_file.name}.") - for text in texts: - if text["text"] is None or text["text"] == "": - continue - output.append(text) - return output + return [ + extract_user_guide_doc_from_single_file(html_file) + for html_file in user_guide_doc_folder.glob("*.html") + ] class UserGuideDocExtractor(BaseEstimator, TransformerMixin): @@ -186,13 +177,13 @@ def transform(self, X): documentation. """ if self.chunk_size is None: - output = _extract_user_guide_doc(X, n_jobs=self.n_jobs) + output = _extract_user_guide_doc(X) else: output = list( chain.from_iterable( Parallel(n_jobs=self.n_jobs, return_as="generator")( delayed(_chunk_document)(self.text_splitter_, document) - for document in _extract_user_guide_doc(X, n_jobs=self.n_jobs) + for document in _extract_user_guide_doc(X) ) ) ) diff --git a/ragger_duck/scraping/tests/test_api_doc.py b/ragger_duck/scraping/tests/test_api_doc.py index 2e653a7..a1e10ea 100644 --- a/ragger_duck/scraping/tests/test_api_doc.py +++ b/ragger_duck/scraping/tests/test_api_doc.py @@ -3,14 +3,7 @@ import importlib from pathlib import Path -import pytest - -from ragger_duck.scraping import ( - APIDocExtractor, - APINumPyDocExtractor, - extract_api_doc, - extract_api_doc_from_single_file, -) +from ragger_duck.scraping import APINumPyDocExtractor from ragger_duck.scraping._api_doc import _extract_function_doc_numpydoc API_TEST_FOLDER = Path(__file__).parent / "data" / "api_doc" @@ -19,81 +12,6 @@ SKLEARN_API_URL = "https://scikit-learn.org/stable/modules/generated/" -def test_extract_api_doc_from_single_file_not_html(): - """Check that we raise an error if the provided file is not an HTML file.""" - path_file = Path(__file__) - err_msg = "is not an HTML file" - with pytest.raises(ValueError, match=err_msg): - extract_api_doc_from_single_file(path_file) - - -@pytest.mark.parametrize( - "extract_function", [extract_api_doc_from_single_file, extract_api_doc] -) -def test_input_not_path_from_pathlib(extract_function): - """Check that we raise an error if the input is not a pathlib.Path.""" - err_msg = "should be a pathlib.Path" - with pytest.raises(ValueError, match=err_msg): - extract_function("not a pathlib.Path") - - -@pytest.mark.parametrize("html_file", HTML_TEST_FILES) -def test_extract_api_doc_from_single_file(html_file): - """Check that we have some meaningful scraping results when parsing - a single HTML file. - """ - path_file = API_TEST_FOLDER / html_file - text = extract_api_doc_from_single_file(path_file) - assert isinstance(text, dict) - assert set(text.keys()) == {"source", "text"} - assert text["source"] == SKLEARN_API_URL + html_file - expected_strings = ["Parameters :", "Returns :"] - for string in expected_strings: - assert string in text["text"] - - -@pytest.mark.parametrize("n_jobs", [None, 1, 2]) -def test_extract_api_doc(n_jobs): - """Checking the the behaviour of the `extract_api_doc` function.""" - output = extract_api_doc(API_TEST_FOLDER, n_jobs=n_jobs) - assert isinstance(output, list) - - assert len(output) == 2 - assert all([isinstance(elt, dict) for elt in output]) - sources = sorted([file["source"] for file in output]) - expected_sources = sorted( - [SKLEARN_API_URL + html_file for html_file in HTML_TEST_FILES] - ) - assert sources == expected_sources - - -@pytest.mark.parametrize("n_jobs", [None, 1, 2]) -@pytest.mark.parametrize("chunk_size", [20, None]) -def test_api_doc_extractor(n_jobs, chunk_size): - """Check the APIDocExtractor class.""" - extractor = APIDocExtractor(chunk_size=chunk_size, chunk_overlap=0, n_jobs=n_jobs) - output_extractor = extractor.fit_transform(API_TEST_FOLDER) - possible_source = [SKLEARN_API_URL + html_file for html_file in HTML_TEST_FILES] - for output in output_extractor: - assert isinstance(output, dict) - assert set(output.keys()) == {"source", "text"} - assert isinstance(output["source"], str) - assert isinstance(output["text"], str) - if chunk_size is not None: - assert len(output["text"]) <= chunk_size - assert output["source"] in possible_source - - assert extractor._get_tags()["stateless"] - - -def test_api_doc_extractor_error_empty(): - """Check that we raise an error if the folder does not contain any HTML file.""" - path_folder = Path(__file__).parent - err_msg = "No API documentation was extracted. Please check the input folder." - with pytest.raises(ValueError, match=err_msg): - APIDocExtractor().fit_transform(path_folder) - - def test_api_numpydoc_extractor(): """Check the APINumPyDocExtractor class.""" extractor = APINumPyDocExtractor() diff --git a/ragger_duck/scraping/tests/test_shared.py b/ragger_duck/scraping/tests/test_shared.py index 78644f6..81ce34a 100644 --- a/ragger_duck/scraping/tests/test_shared.py +++ b/ragger_duck/scraping/tests/test_shared.py @@ -1,50 +1,26 @@ from pathlib import Path -from bs4 import BeautifulSoup from langchain.text_splitter import RecursiveCharacterTextSplitter -from ragger_duck.scraping._shared import ( - _chunk_document, - _extract_text_from_section, -) +from ragger_duck.scraping import UserGuideDocExtractor +from ragger_duck.scraping._shared import _chunk_document -TEST_HMTL_FILE = Path(__file__).parent / "data" / "user_guide_doc" / "calibration.html" - - -def test_extract_text_from_section(): - """Check the behavior of the `_extract_text_from_section` function.""" - with open(TEST_HMTL_FILE, "r") as file: - soup = BeautifulSoup(file, "html.parser") - sections = soup.section - for section in sections: - text = _extract_text_from_section(section) - assert text is None or isinstance(text, str) - - # FIXME: write more tests to check the exact behavior depending on tags. +TEST_DATA_PATH = Path(__file__).parent / "data" / "user_guide_doc" def test_chunk_document(): """Check the behavior of the `_chunk_document` function.""" + chunk_size = 100 text_splitter = RecursiveCharacterTextSplitter( separators=["\n\n", "\n", " "], - chunk_size=40, - chunk_overlap=30, + chunk_size=chunk_size, + chunk_overlap=0, length_function=len, ) - texts = [] - with open(TEST_HMTL_FILE, "r") as file: - soup = BeautifulSoup(file, "html.parser") - for section in soup.section: - text = _extract_text_from_section(section) - if text is not None and len(text) > 20: - texts.append( - { - "text": text, - "source": "https://some_source.com", - } - ) - chunks = _chunk_document(text_splitter, texts[1]) + extractor = UserGuideDocExtractor(chunk_size=None) + document_unchunked = extractor.fit_transform(TEST_DATA_PATH)[0] + chunks = _chunk_document(text_splitter, document_unchunked) for chunk in chunks: assert isinstance(chunk, dict) - assert len(chunk["text"]) <= 40 - assert chunk["source"] == "https://some_source.com" + assert len(chunk["text"]) <= chunk_size + assert chunk["source"] in document_unchunked["source"] diff --git a/scripts/train_retrievers.py b/scripts/train_retrievers.py index 939265a..f62c991 100644 --- a/scripts/train_retrievers.py +++ b/scripts/train_retrievers.py @@ -91,7 +91,7 @@ cache_folder=config.CACHE_PATH, device=config.DEVICE, ) -user_guide_scraper = UserGuideDocExtractor(chunk_size=500, chunk_overlap=100, n_jobs=-1) +user_guide_scraper = UserGuideDocExtractor(chunk_size=700, chunk_overlap=10, n_jobs=-1) pipeline = Pipeline( steps=[ ("extractor", user_guide_scraper), @@ -113,6 +113,7 @@ from ragger_duck.retrieval import BM25Retriever count_vectorizer = CountVectorizer(ngram_range=(1, 5)) +user_guide_scraper = UserGuideDocExtractor(chunk_size=700, chunk_overlap=10, n_jobs=-1) pipeline = Pipeline( steps=[ ("extractor", user_guide_scraper),