forked from skypilot-org/skypilot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path huggingface_glue_imdb_app.py
41 lines (35 loc) · 1.3 KB
/
huggingface_glue_imdb_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Hugging Face: fine-tune BERT on IMDB with SkyPilot.

Fine-tunes a pretrained BERT model (``bert-base-cased``) on the IMDB dataset
using the Hugging Face text-classification example script ``run_glue.py``:
https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification

The dataset is downloaded automatically by Hugging Face and saved to
~/.cache/huggingface.

Running this module builds a one-task DAG (setup + training on a V100 GPU)
and launches it with ``sky.launch``.
"""
import sky

with sky.Dag() as dag:
    # The setup command. Will be run under the working directory.
    # https://github.com/huggingface/transformers/tree/main/examples#important-note
    # Per the note above, the examples are run against a source install of
    # transformers, so the repo is cloned and pip-installed. `--depth 1` skips
    # the (large) git history, which nothing here uses. `|| true` lets setup
    # re-run when `transformers/` already exists (git clone refuses a
    # non-empty destination); note it also masks other clone failures, which
    # then surface at `cd transformers`.
    setup = ' && '.join([
        '(git clone --depth 1 https://github.com/huggingface/transformers/'
        ' || true)',
        'cd transformers',
        'pip3 install .',
        'cd examples/pytorch/text-classification',
        'pip3 install -r requirements.txt',
    ])

    # The command to run. Will be run under the working directory.
    # https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification
    # Training is capped at 50 steps (--max_steps), so this is a short run
    # rather than a full fine-tune; --fp16 enables mixed precision on the GPU.
    # NOTE(review): run_glue.py may refuse to train into a non-empty
    # --output_dir on a re-run unless --overwrite_output_dir is passed —
    # confirm against the installed transformers version.
    run = ' '.join([
        'cd transformers/examples/pytorch/text-classification &&',
        'python3 run_glue.py',
        '--model_name_or_path bert-base-cased',
        '--dataset_name imdb',
        '--do_train',
        '--max_seq_length 128',
        '--per_device_train_batch_size 32',
        '--learning_rate 2e-5',
        '--max_steps 50',
        '--output_dir /tmp/imdb/',
        '--fp16',
    ])

    train = sky.Task(
        'train',
        setup=setup,
        run=run,
    )
    train.set_resources({sky.Resources(accelerators='V100')})

# Launch only after the `with` block has exited, i.e. once the DAG is built.
sky.launch(dag)