Commit

update the finetune codes for tensorflow
ewrfcas committed Dec 9, 2019
1 parent 68e99b2 commit d5b5c44
Showing 2 changed files with 29 additions and 28 deletions.
28 changes: 14 additions & 14 deletions DRCD_finetune_tf.py
@@ -19,8 +19,8 @@
import random
from tqdm import tqdm
import collections
-from tokenizations.offical_tokenization import BertTokenizer
-from preprocess.cmrc2018_preprocess import json2features
+from tokenizations.official_tokenization import BertTokenizer
+from preprocess.DRCD_preprocess import json2features


def print_rank0(*args):
@@ -54,12 +54,12 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):
parser = argparse.ArgumentParser()
tf.logging.set_verbosity(tf.logging.ERROR)

-parser.add_argument('--gpu_ids', type=str, default='4,5,6,7')
+parser.add_argument('--gpu_ids', type=str, default='1')

# training parameter
parser.add_argument('--train_epochs', type=int, default=2)
parser.add_argument('--n_batch', type=int, default=32)
-parser.add_argument('--lr', type=float, default=2.5e-5)
+parser.add_argument('--lr', type=float, default=3e-5)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--clip_norm', type=float, default=1.0)
parser.add_argument('--loss_scale', type=float, default=2.0 ** 15)
@@ -77,19 +77,19 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):

# data dir
parser.add_argument('--vocab_file', type=str,
-                    default='check_points/pretrain_models/roberta_wwm_ext_large/vocab.txt')
+                    default='check_points/pretrain_models/google_bert_base/vocab.txt')

parser.add_argument('--train_dir', type=str, default='dataset/DRCD/train_features_roberta512.json')
parser.add_argument('--dev_dir1', type=str, default='dataset/DRCD/dev_examples_roberta512.json')
parser.add_argument('--dev_dir2', type=str, default='dataset/DRCD/dev_features_roberta512.json')
parser.add_argument('--train_file', type=str, default='origin_data/DRCD/DRCD_training.json')
parser.add_argument('--dev_file', type=str, default='origin_data/DRCD/DRCD_dev.json')
parser.add_argument('--bert_config_file', type=str,
-                    default='check_points/pretrain_models/roberta_wwm_ext_large/bert_config.json')
+                    default='check_points/pretrain_models/google_bert_base/bert_config.json')
parser.add_argument('--init_restore_dir', type=str,
-                    default='check_points/pretrain_models/roberta_wwm_ext_large/bert_model.ckpt')
+                    default='check_points/pretrain_models/google_bert_base/bert_model.ckpt')
parser.add_argument('--checkpoint_dir', type=str,
-                    default='check_points/DRCD/roberta_wwm_ext_large/')
+                    default='check_points/DRCD/google_bert_base/')
parser.add_argument('--setting_file', type=str, default='setting.txt')
parser.add_argument('--log_file', type=str, default='log.txt')

@@ -119,7 +119,7 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):

if mpi_rank == 0:
tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
-    assert args.vocab_size == len(tokenizer.vocab)
+    # assert args.vocab_size == len(tokenizer.vocab)
if not os.path.exists(args.train_dir):
json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'),
args.train_dir], tokenizer, is_training=True)
@@ -199,6 +199,11 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):
clip_norm=args.clip_norm,
init_loss_scale=args.loss_scale)

+if mpi_rank == 0:
+    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
+else:
+    saver = None
+
for seed_ in args.seed:
best_f1, best_em = 0, 0
if mpi_rank == 0:
@@ -225,11 +230,6 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])

-if mpi_rank == 0:
-    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
-else:
-    saver = None
-
with tf.train.MonitoredTrainingSession(checkpoint_dir=None,
hooks=training_hooks,
config=config) as sess:
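For context, this is how the renamed modules are used in the updated DRCD script, per the hunks above: the corrected official_tokenization import builds the tokenizer, and the DRCD-specific json2features converts the raw training file into the cached example/feature JSONs. A minimal sketch, assuming the repository's tokenizations and preprocess packages are importable and the vocab/data files exist at the default paths:

    from tokenizations.official_tokenization import BertTokenizer
    from preprocess.DRCD_preprocess import json2features

    # Tokenizer built from the new default google_bert_base vocab (see --vocab_file above).
    tokenizer = BertTokenizer(vocab_file='check_points/pretrain_models/google_bert_base/vocab.txt',
                              do_lower_case=True)

    # Convert DRCD_training.json into the example/feature caches pointed to by --train_dir.
    json2features('origin_data/DRCD/DRCD_training.json',
                  ['dataset/DRCD/train_examples_roberta512.json',
                   'dataset/DRCD/train_features_roberta512.json'],
                  tokenizer, is_training=True)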
29 changes: 15 additions & 14 deletions cmrc2018_finetune_tf.py
@@ -1,7 +1,8 @@
import argparse
+import os
+
import numpy as np
import tensorflow as tf
-import os

try:
# horovod must be import before optimizer!
@@ -19,7 +20,7 @@
import random
from tqdm import tqdm
import collections
-from tokenizations.offical_tokenization import BertTokenizer
+from tokenizations.official_tokenization import BertTokenizer
from preprocess.cmrc2018_preprocess import json2features


@@ -56,12 +57,12 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):
parser = argparse.ArgumentParser()
tf.logging.set_verbosity(tf.logging.ERROR)

-parser.add_argument('--gpu_ids', type=str, default='0,1,2,3')
+parser.add_argument('--gpu_ids', type=str, default='2')

# training parameter
parser.add_argument('--train_epochs', type=int, default=2)
parser.add_argument('--n_batch', type=int, default=32)
-parser.add_argument('--lr', type=float, default=2e-5)
+parser.add_argument('--lr', type=float, default=3e-5)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--clip_norm', type=float, default=1.0)
parser.add_argument('--loss_scale', type=float, default=2.0 ** 15)
@@ -79,19 +80,19 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):

# data dir
parser.add_argument('--vocab_file', type=str,
-                    default='check_points/pretrain_models/roberta_wwm_ext_large/vocab.txt')
+                    default='check_points/pretrain_models/google_bert_base/vocab.txt')

parser.add_argument('--train_dir', type=str, default='dataset/cmrc2018/train_features_roberta512.json')
parser.add_argument('--dev_dir1', type=str, default='dataset/cmrc2018/dev_examples_roberta512.json')
parser.add_argument('--dev_dir2', type=str, default='dataset/cmrc2018/dev_features_roberta512.json')
parser.add_argument('--train_file', type=str, default='origin_data/cmrc2018/cmrc2018_train.json')
parser.add_argument('--dev_file', type=str, default='origin_data/cmrc2018/cmrc2018_dev.json')
parser.add_argument('--bert_config_file', type=str,
-                    default='check_points/pretrain_models/roberta_wwm_ext_large/bert_config.json')
+                    default='check_points/pretrain_models/google_bert_base/bert_config.json')
parser.add_argument('--init_restore_dir', type=str,
-                    default='check_points/pretrain_models/roberta_wwm_ext_large/bert_model.ckpt')
+                    default='check_points/pretrain_models/google_bert_base/bert_model.ckpt')
parser.add_argument('--checkpoint_dir', type=str,
-                    default='check_points/cmrc2018/roberta_wwm_ext_large/')
+                    default='check_points/cmrc2018/google_bert_base/')
parser.add_argument('--setting_file', type=str, default='setting.txt')
parser.add_argument('--log_file', type=str, default='log.txt')

@@ -121,7 +122,7 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):

if mpi_rank == 0:
tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
-    assert args.vocab_size == len(tokenizer.vocab)
+    # assert args.vocab_size == len(tokenizer.vocab)
if not os.path.exists(args.train_dir):
json2features(args.train_file, [args.train_dir.replace('_features_', '_examples_'),
args.train_dir], tokenizer, is_training=True)
@@ -201,6 +202,11 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):
clip_norm=args.clip_norm,
init_loss_scale=args.loss_scale)

+if mpi_rank == 0:
+    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
+else:
+    saver = None
+
for seed_ in args.seed:
best_f1, best_em = 0, 0
if mpi_rank == 0:
@@ -227,11 +233,6 @@ def data_generator(data, n_batch, shuffle=False, drop_last=False):
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])

-if mpi_rank == 0:
-    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
-else:
-    saver = None
-
with tf.train.MonitoredTrainingSession(checkpoint_dir=None,
hooks=training_hooks,
config=config) as sess:
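Both scripts make the same structural move: the tf.train.Saver is now created once on MPI rank 0, before the per-seed training loop, instead of being rebuilt inside it. A minimal TF 1.x sketch of that ordering, using a stand-in variable and a hypothetical seed list in place of the real BERT graph and args:

    import tensorflow as tf

    mpi_rank = 0        # assumption: single-process run, rank 0 owns checkpointing
    seeds = [123, 456]  # hypothetical stand-in for args.seed

    logits = tf.get_variable('logits', shape=[2, 2])  # stand-in for the model's trainable variables

    # Built once, outside the seed loop, mirroring the relocated block in the diff.
    if mpi_rank == 0:
        saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=1)
    else:
        saver = None

    for seed_ in seeds:
        tf.set_random_seed(seed_)
        # ... per-seed training and evaluation would run here ...
        # On rank 0, saver.save(sess, checkpoint_path) would be called when dev F1/EM improves.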
