-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_cotrain.py
82 lines (73 loc) · 4.55 KB
/
run_cotrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import logging
import argparse
from co_training import CoTraining
# Root logging configuration for the script: timestamped INFO-level records
# that also carry the level name and the emitting logger's name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'

logging.basicConfig(
    level=logging.INFO,
    datefmt=_LOG_DATE_FORMAT,
    format=_LOG_FORMAT,
)

# Module-level logger used by main().
logger = logging.getLogger(__name__)
def _build_arg_parser():
    """Build the command-line parser for the co-training driver script."""
    parser = argparse.ArgumentParser()

    ## Required parameters
    # These options are required, so argparse would never consult a default;
    # typical values are shown in the example invocations in main().
    parser.add_argument("--ext_output_dir",
                        type=str,
                        required=True,
                        help="The dir that you save the extended L set.")
    parser.add_argument("--modelA_dir",
                        type=str,
                        required=True,
                        help="The dir of pre-trained model that will be used in the cotraining algorithm on the X1 feature set, e.g. German.")
    parser.add_argument("--modelB_dir",
                        type=str,
                        required=True,
                        help="The dir of another pre-trained model can be specified to be used on the X2 feature set, e.g. English.")
    parser.add_argument("--de_unlabel_dir",
                        type=str,
                        required=True,
                        help="The dir of unlabeled sentences in German.")
    parser.add_argument("--en_unlabel_dir",
                        type=str,
                        required=True,
                        help="The dir of unlabeled sentences in English.")

    ## Optional flags and hyper-parameters
    parser.add_argument("--save_preds",
                        action='store_true',
                        help="Whether to save the confident predictions.")
    parser.add_argument("--save_agree",
                        action='store_true',
                        help="Whether to save the agree predictions, aka. the predictions that will be added to L set.")
    parser.add_argument("--top_n",
                        default=5,
                        type=int,
                        help="The number of the most confident examples that will be 'labeled' by each classifier during each iteration")
    parser.add_argument("--k",
                        default=30,
                        type=int,
                        help="The number of iterations. The default is 30")
    parser.add_argument("--u",
                        default=75,
                        type=int,
                        help="The size of the pool of unlabeled samples from which the classifier can choose. Default - 75")
    return parser


def _prepare_output_dir(output_dir):
    """Ensure ``output_dir`` exists and is empty, creating it when missing.

    Raises:
        ValueError: if ``output_dir`` is an existing, non-empty directory.
        OSError: propagated from ``os.makedirs`` if the directory cannot be
            created (e.g. the path already exists as a regular file).
    """
    # Refuse to write into a directory that already holds files.
    if os.path.isdir(output_dir) and os.listdir(output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(output_dir))
    # exist_ok avoids a race between the existence check and the creation.
    os.makedirs(output_dir, exist_ok=True)


def main():
    """Parse command-line arguments, run co-training and log the result size.

    The run configuration is logged before the ``CoTraining`` object is built
    and fitted, so it is visible while the run is in progress and is still
    recorded if the run fails part-way.

    Raises:
        ValueError: if ``--ext_output_dir`` is an existing, non-empty directory.
    """
    # Example invocations:
    # python run_cotrain.py --ext_output_dir ext_data --modelA_dir baseline_model --modelB_dir onto_model --de_unlabel_dir machine_translation/2017_de_sents.txt --en_unlabel_dir machine_translation/2017_en_sents.txt --k 10 --u 10 --top_n 3 --save_preds --save_agree
    # python run_cotrain.py --ext_output_dir ext_data_1000 --modelA_dir baseline_model --modelB_dir onto_model --de_unlabel_dir machine_translation/2017_de_sents.txt --en_unlabel_dir machine_translation/2017_en_sents.txt --k 1000 --u 100 --top_n 10 --save_preds --save_agree
    # python run_ner.py --data_dir data/full-isw-release.tsv --bert_model bert-base-german-cased --output_dir baseline_model/ --max_seq_length 128 --do_train --extend_L --ext_data_dir ext_data_1000 --ext_output_dir ext_isw_model
    args = _build_arg_parser().parse_args()

    _prepare_output_dir(args.ext_output_dir)

    logger.info(" ***** Running Co-Training ***** ")
    logger.info(" Model A = %s", args.modelA_dir)
    logger.info(" Model B = %s", args.modelB_dir)
    logger.info("Top_n: %s, iteration_k: %s, sample_pool_u: %s", args.top_n, args.k, args.u)

    # Initialize co-training class
    co_train = CoTraining(modelA_dir=args.modelA_dir,
                          modelB_dir=args.modelB_dir,
                          save_preds=args.save_preds,
                          top_n=args.top_n,
                          k=args.k,
                          u=args.u)
    compare_agree_list = co_train.fit(ext_output_dir=args.ext_output_dir,
                                      de_unlabel_dir=args.de_unlabel_dir,
                                      en_unlabel_dir=args.en_unlabel_dir,
                                      save_agree=args.save_agree,
                                      save_preds=args.save_preds)

    logger.info(" ***** Loading Agree Set ***** ")
    logger.info(" Num of agree samples: %s", len(compare_agree_list))
# Run the co-training driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()