-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_wikibio_lstm_rl.cfg
89 lines (72 loc) · 2.61 KB
/
train_wikibio_lstm_rl.cfg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Model/Embeddings
# Word-embedding settings for the source (src) and target (tgt) sides.
word_vec_size: 300 # Word embedding size for src and tgt
share_embeddings: True # Share embeddings from src and tgt
# Model/Embedding Features
# How feature (attribute) embeddings are sized and merged in.
feat_vec_size: -1 # Attribute embedding size. -1 means <nb_features>**0.7
feat_merge: mlp # Merge action for incorporating feature embeddings [concat|sum|mlp]
# Model Structure
# Text-to-text model: bidirectional RNN (brnn) encoder, RNN decoder, LSTM cells.
model_type: text # Type of source model to use [text|img|audio]
# fp32 = 32-bit floating point.
model_dtype: fp32
encoder_type: brnn # Type of encoder [rnn|brnn|transformer|cnn]
decoder_type: rnn # Type of decoder [rnn|transformer|cnn]
# NOTE(review): looks redundant with `encoder_type: brnn` above; confirm the
# training script actually reads this key.
bidirectional_encoder: true
param_init: 0.1 # Uniform distribution with support (-param_init, +param_init)
# Presumably the number of stacked RNN layers, applied to both encoder and
# decoder -- TODO confirm.
layers: 2
# Presumably the RNN hidden-state size; note it matches word_vec_size (300).
rnn_size: 300
# Presumably 1 enables input feeding in the decoder -- TODO confirm.
input_feed: 1
# Presumably adds a layer between the encoder's final state and the decoder's
# initial state -- TODO confirm.
bridge: True
rnn_type: LSTM
# Model/Attention
# Attention scoring type and the normalizers used for attention and generation.
global_attention: general # Type of attn to use [dot|general|mlp|none]
global_attention_function: softmax # [softmax|sparsemax]
# Normalizer for the generator (output) layer; presumably accepts the same
# choices as global_attention_function -- TODO confirm.
generator_function: softmax
# Model/Copy
# Copy-attention settings; all three switches are turned on.
copy_attn: True
reuse_copy_attn: True # Reuse standard attention for copy
copy_attn_force: True # When available, train to copy
# Files and logs
# Input data is read from the `pretraining-lstm` experiment directory, while
# checkpoints and the training log are written under `lstm-rl`.
data: experiments/wikibio/pretraining-lstm/data/data # path to datafile from preprocess.py
save_model: experiments/wikibio/lstm-rl/models/model # path to store checkpoints
# File the training log is written to.
log_file: experiments/wikibio/lstm-rl/train-log.txt
# RL parameters
# Reinforcement-learning fine-tuning, started from a pretrained checkpoint.
train_with_rl: true
# Checkpoint to start from: the `pretraining-lstm` run at step 30000
# (per the file name).
train_from: experiments/wikibio/pretraining-lstm/models/model_step_30000.pt
# NOTE(review): role of this coefficient (loss mixing weight vs. discount
# factor) is not visible from this file -- confirm against the training code.
rl_gamma_loss: .99
# Metric used for RL; `parent` presumably refers to the PARENT table-to-text
# metric -- TODO confirm.
rl_metric: parent
max_generator_batches: 0 # It was easier to code if we had this parameter at 0. Fundamentally doesn't change anything
# Precomputed stats for RL training
# JSON files and the reference text, all located under data/wikibio/.
# NOTE(review): these three keys are upper-case, unlike every other key in
# this file; keep the exact spelling in case the consumer is case-sensitive.
TABLE_VALUES: data/wikibio/TABLE_VALUES.json
REF_NGRAM_COUNTS: data/wikibio/REF_NGRAM_COUNTS.json
REF_NGRAM_WEIGHTS: data/wikibio/REF_NGRAM_WEIGHTS.json
references: data/wikibio/train_output.txt
# Logging and checkpoint frequency (not RL-specific)
report_every: 10 # log current loss every X steps
save_checkpoint_steps: 200 # save a cp every X steps
# GPU related:
# Single-process setup: one GPU (id 0) and world_size 1.
gpu_ranks: [0] # ids of gpus to use
world_size: 1 # total number of distributed processes
gpu_backend: nccl # type of torch distributed backend
gpu_verbose_level: 0
# Presumably the address and port of the master process for distributed
# training -- TODO confirm.
master_ip: localhost
master_port: 10000
# Random seed (not GPU-specific).
seed: 123
# Optimization & training
# Batches are counted in sentences (batch_type: sents). With batch_size 8 and
# accum_count 64, each weight update covers up to 8 * 64 = 512 sentences.
batch_size: 8
batch_type: sents
normalization: sents
accum_count: [64] # Update weights every X batches
accum_steps: [0] # steps at which accum counts value changes
# NOTE(review): valid_steps equals train_steps, so validation runs at most
# once, at the final step -- confirm this is intended.
valid_steps: 100000 # run models on validation set every X steps
train_steps: 100000
optim: adam
# Presumably the gradient-norm clipping threshold -- TODO confirm.
max_grad_norm: 5
dropout: .3
adam_beta1: 0.9
adam_beta2: 0.999
# 0.0 = no label smoothing.
label_smoothing: 0.0
# Presumably parameter-averaging settings, with 0 meaning disabled
# -- TODO confirm.
average_decay: 0
average_every: 1
# Learning rate
# Initial rate and its multiplicative decay schedule.
learning_rate: 0.001
learning_rate_decay: 0.5 # lr *= lr_decay
# Presumably decay starts at step 5000 and is then re-applied every 10000
# steps -- TODO confirm against the optimizer code.
start_decay_step: 5000
decay_steps: 10000