This project implements a simplified version of SPLADE (Sparse Lexical and Expansion Model), which uses masked language models (MLMs) such as xlm-roberta-base to produce sparse representations for efficient retrieval. The project covers training, sparse vectorization, and retrieval.
- Training: Train the SPLADE model using masked language models with a ranking loss.
- Sparse Vectorization: Convert text inputs into sparse representations (word-value pairs).
- Multilingual Support: Supports multilingual text with pre-trained models like xlm-roberta-base.
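To make the "word-value pairs" idea concrete, here is a minimal sketch of how SPLADE-style term weights are commonly derived from MLM logits: apply log(1 + ReLU(x)) to each logit, then pool over token positions. The function name `splade_weights`, the toy vocabulary, and the choice of max pooling are assumptions for illustration; the repository's own model may pool differently (for example, by summing).

```python
import math


def splade_weights(token_logits, vocab):
    """Pool per-token MLM logits into one sparse {term: weight} mapping.

    token_logits: one list of vocabulary logits per input token.
    vocab: the term string for each vocabulary index.
    """
    weights = {}
    for logits in token_logits:
        for idx, logit in enumerate(logits):
            # log(1 + ReLU(x)) zeroes out negative logits and dampens large ones.
            w = math.log1p(max(logit, 0.0))
            if w > 0.0:
                # Assumption: max pooling over token positions.
                term = vocab[idx]
                weights[term] = max(weights.get(term, 0.0), w)
    return weights


# Toy example: 2 input tokens, 3-term vocabulary (hypothetical values).
vocab = ["shoe", "running", "laptop"]
token_logits = [
    [2.0, -1.0, -3.0],
    [0.5, 1.5, -0.2],
]
sparse = splade_weights(token_logits, vocab)
print(sparse)
```

Terms whose logits are never positive (here "laptop") do not appear in the output at all, which is what keeps the representation sparse.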
- Python 3.10+
- Poetry
- PyTorch
- Transformers (Hugging Face)
- Pandas
- Clone the repository:
    git clone https://github.com/marevol/simple-splade.git
    cd simple-splade
- Install dependencies using Poetry:
    poetry install
  This creates a virtual environment and installs all the dependencies listed in pyproject.toml.
- Activate the virtual environment created by Poetry:
    poetry shell
This project relies on the Amazon ESCI dataset for training the model. You need to download the dataset and place it in the correct directory.
- Download the dataset:
  - Download the shopping_queries_dataset_products.parquet and shopping_queries_dataset_examples.parquet files from the Amazon ESCI dataset.
- Place the downloaded files in the downloads directory within your project folder:
    ./downloads/shopping_queries_dataset_products.parquet
    ./downloads/shopping_queries_dataset_examples.parquet
- The main.py script loads the dataset from the downloads directory by default. If you place the files elsewhere, modify the paths in the script accordingly.
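Before running the sample, it can help to confirm that both files are where the script expects them. The following is a small sketch using only the standard library; the helper name `missing_dataset_files` is hypothetical and not part of this repository.

```python
from pathlib import Path

# File names taken from the dataset setup steps above.
EXPECTED_FILES = (
    "shopping_queries_dataset_products.parquet",
    "shopping_queries_dataset_examples.parquet",
)


def missing_dataset_files(download_dir="downloads"):
    """Return the expected ESCI files that are not present in download_dir."""
    base = Path(download_dir)
    return [name for name in EXPECTED_FILES if not (base / name).is_file()]


missing = missing_dataset_files()
if missing:
    print("Missing dataset files:", ", ".join(missing))
else:
    print("All dataset files found.")
```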
The main.py script demonstrates how to use the Amazon ESCI dataset to train the SPLADE model, save it, and then use the trained model to convert text into sparse vectors for retrieval.
To run the sample execution with the Amazon ESCI dataset:
poetry run python main.py
This script performs the following steps:
- Training: It loads the product titles from the Amazon ESCI dataset, trains the SPLADE model on the titles, and saves the trained model.
- Sparse Vectorization: After training, the model is used to convert a sample text into a sparse vector representation.
- Retrieval: It demonstrates retrieval using the generated sparse vectors from dummy data.
You can modify the script or dataset paths as needed.
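As a rough illustration of the retrieval step, sparse vectors stored as term-to-weight mappings can be scored with a dot product over their shared terms. This is a sketch under that assumption, with made-up documents and weights; `sparse_dot` and `rank` are hypothetical names, and the script's actual retrieval code may differ.

```python
def sparse_dot(query_vec, doc_vec):
    """Dot product of two sparse vectors stored as {term: weight} dicts."""
    # Iterate over the smaller dict so the cost scales with the shorter vector.
    if len(query_vec) > len(doc_vec):
        query_vec, doc_vec = doc_vec, query_vec
    return sum(w * doc_vec[t] for t, w in query_vec.items() if t in doc_vec)


def rank(query_vec, docs):
    """Return (doc_id, score) pairs sorted by descending score."""
    scored = [(doc_id, sparse_dot(query_vec, vec)) for doc_id, vec in docs.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Dummy documents and query (illustrative weights only).
docs = {
    "d1": {"running": 1.2, "shoe": 0.9},
    "d2": {"laptop": 1.5, "bag": 0.4},
    "d3": {"shoe": 0.3, "bag": 0.8},
}
query = {"shoe": 1.0, "running": 0.5}

for doc_id, score in rank(query, docs):
    print(doc_id, round(score, 3))
```

Because only overlapping terms contribute to the score, documents that share no terms with the query score zero, which is what makes inverted-index lookups practical for this kind of representation.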
- main.py: The main entry point for running the sample execution using the Amazon ESCI dataset.
- simple_splade/vectorization.py: Contains the SPLADESparseVectorizer class for converting text into sparse vectors.
- simple_splade/model.py: Defines the SimpleSPLADE model architecture.
- simple_splade/train.py: Handles the training process for the SPLADE model.
- simple_splade/evaluate.py: Contains functions for evaluating the model using ranking loss.
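This README does not say which ranking loss is used. As one common possibility, a pairwise margin ranking loss penalizes cases where a relevant item does not outscore an irrelevant one by at least a fixed margin. The sketch below assumes that formulation; the function name and the margin value are illustrative and not taken from the repository.

```python
def margin_ranking_loss(pos_scores, neg_scores, margin=1.0):
    """Mean hinge loss pushing each positive score above its paired negative."""
    if not pos_scores or len(pos_scores) != len(neg_scores):
        raise ValueError("score lists must be non-empty and of equal length")
    # Each pair contributes max(0, margin - (positive - negative)).
    losses = [max(0.0, margin - (p - n)) for p, n in zip(pos_scores, neg_scores)]
    return sum(losses) / len(losses)


# First pair already satisfies the margin; second pair is ranked the wrong way.
loss = margin_ranking_loss([2.0, 0.5], [0.5, 1.0])
print(loss)
```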
Once the script completes, the following will happen:
- A trained model will be saved in the splade_model directory.
- Sparse vector representations for the example text will be printed in the console.
- Retrieval results will be shown for a dummy query against indexed dummy documents.
This project is licensed under the Apache License, Version 2.0. See the LICENSE file for more details.