TextBox is built on Python and PyTorch for reproducing and developing text generation algorithms in a unified, comprehensive and efficient framework for research purposes. Our library includes 21 text generation algorithms, covering two major tasks:
- Unconditional (input-free) Generation
- Conditional (Seq2Seq) Generation, including Machine Translation, Text Summarization, Attribute-to-Text, and Dialogue Systems
We provide support for 9 benchmark text generation datasets. A user can apply our library to process the original copy of a dataset, or simply download the datasets processed by our team.
View the TextBox homepage for more information: https://github.com/RUCAIBox/TextBox.
View the TextBox documentation for more implementation details: https://textbox.readthedocs.io/en/latest/.
- Unified and modularized framework. TextBox is built upon PyTorch and designed to be highly modularized, by decoupling diverse models into a set of highly reusable modules.
- Comprehensive models, benchmark datasets and standardized evaluations. TextBox also contains a wide range of text generation models, covering the categories of VAE-, GAN-, RNN- and Transformer-based models, as well as pre-trained language models (PLMs).
- Extensible and flexible framework. TextBox provides convenient interfaces for various common functions and modules in text generation models, such as the RNN encoder-decoder, the Transformer encoder-decoder and pre-trained language models.
- Easy and convenient to get started. TextBox provides flexible configuration files, which allow beginners to run experiments without modifying the source code, and allow researchers to conduct qualitative analysis by modifying a few configurations.
TextBox requires:
- Python >= 3.6.2
- torch >= 1.6.0. Please follow the official instructions to install the appropriate version according to your CUDA version and NVIDIA driver version.
- GCC >= 5.1.0
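If you are unsure whether your environment satisfies these requirements, a quick check such as the following can help (a minimal sketch using standard Python and PyTorch calls):

```python
import sys
import torch

# TextBox needs Python >= 3.6.2 and torch >= 1.6.0
print(sys.version)                 # Python version
print(torch.__version__)           # PyTorch version
print(torch.cuda.is_available())   # True if PyTorch can see a usable GPU
```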
To install TextBox from pip:
pip install textbox
Or, to install from source:
git clone https://github.com/RUCAIBox/TextBox.git && cd TextBox
pip install -e . --verbose
With the source code, you can use the provided script for initial usage of our library:
python run_textbox.py
This script will run the RNN model on the COCO dataset to conduct unconditional generation.
If you want to change the parameters, such as rnn_type or max_vocab_size, just set the additional command parameters as you need:
python run_textbox.py --rnn_type=lstm --max_vocab_size=4000
We also support modifying the YAML configuration files in the corresponding dataset and model properties folders and including them in the command line.
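For example, a configuration file overriding the parameters above might look like the following sketch (the file name and the --config_files option are assumptions; check the documentation for the exact file locations and option name):

```yaml
# my_rnn_config.yaml (hypothetical file name)
rnn_type: lstm
max_vocab_size: 4000
```

```bash
# --config_files is an assumed option name
python run_textbox.py --model=RNN --dataset=COCO --config_files=my_rnn_config.yaml
```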
If you want to change the model, the dataset or the task type, just run the script with the corresponding command parameters modified:
python run_textbox.py --model=[model_name] --dataset=[dataset_name]
model_name is the model to be run, such as RNN or BART, and dataset_name is the dataset to be used, such as COCO.
If TextBox is installed from pip, you can create a new Python file, download the dataset, and write and run the following code:
from textbox.quick_start import run_textbox
run_textbox(config_dict={'model': 'RNN',
'dataset': 'COCO',
'data_path': './dataset'})
This will perform training and testing of the RNN model on the COCO dataset.
If you want to run different models, parameters or datasets, the operations are the same as in Start from source.
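For example, the parameter overrides shown in the command-line example above can be passed directly through config_dict (a sketch reusing the parameters mentioned earlier):

```python
from textbox.quick_start import run_textbox

# RNN on COCO with rnn_type and max_vocab_size overridden through config_dict
run_textbox(config_dict={'model': 'RNN',
                         'dataset': 'COCO',
                         'data_path': './dataset',
                         'rnn_type': 'lstm',
                         'max_vocab_size': 4000})
```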
TextBox supports applying some pre-trained language models (PLMs) to conduct text generation. Taking GPT-2 as an example, we will show how to fine-tune with PLMs.
- Download the GPT-2 model provided by Hugging Face (https://huggingface.co/gpt2/tree/main), including config.json, merges.txt, pytorch_model.bin, tokenizer.json and vocab.json. Then put them in a folder at the same level as textbox, such as pretrained_model/gpt2.
- After downloading, you just need to run the command:
python run_textbox.py --model=GPT2 --dataset=COCO \
--pretrained_model_path=pretrained_model/gpt2
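If TextBox was installed from pip, the equivalent fine-tuning run can presumably be launched through the quick-start API as well (a sketch; passing pretrained_model_path through config_dict is an assumption based on the command-line usage above):

```python
from textbox.quick_start import run_textbox

# Fine-tune GPT-2 on COCO; pretrained_model_path points at the downloaded
# checkpoint folder (assumed to be accepted as an ordinary configuration key)
run_textbox(config_dict={'model': 'GPT2',
                         'dataset': 'COCO',
                         'pretrained_model_path': 'pretrained_model/gpt2'})
```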
TextBox supports training models with multiple GPUs conveniently. You don't need to modify the model; just run the following command:
python -m torch.distributed.launch --nproc_per_node=[gpu_num] \
run_textbox.py --model=[model_name] \
--dataset=[dataset_name] --gpu_id=[gpu_ids] --DDP=True
gpu_num is the number of GPUs you want to train with (such as 4), and gpu_ids is the list of usable GPU ids (such as 0,1,2,3).
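For example, to train the RNN model on the COCO dataset with 4 GPUs:

```bash
python -m torch.distributed.launch --nproc_per_node=4 \
    run_textbox.py --model=RNN --dataset=COCO --gpu_id=0,1,2,3 --DDP=True
```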
Note that we currently only support DDP for end-to-end models. We will add support for non-end-to-end models, such as GANs, in the future.
TextBox is developed and maintained by AI Box.
TextBox uses MIT License.