-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun.py
129 lines (87 loc) · 4.75 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import sys
import os
import logging
import argparse
import datetime
from config.config import ParamManager
from dataloader.base import DataManager
from model.manager import DRTManager
from model.base import TextRepresentation
from utils.functions import save_results
# from manager import ADBManager
# Set KMP_DUPLICATE_LIB_OK for this process and any child processes.
# Presumably a workaround for the OpenMP "duplicate runtime library" abort
# raised when more than one copy of the OpenMP runtime is loaded — confirm.
# NOTE(review): this assignment runs after the project imports above; if those
# imports already load the OpenMP-linked libraries, it may need to move before
# them to take effect — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
def parse_arguments(argv=None):
    """Define the command-line options of the script and parse them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches the
            behaviour of calling this function without arguments.

    Returns:
        argparse.Namespace: The parsed options, one attribute per flag
        declared below.
    """
    parser = argparse.ArgumentParser()

    # Logging.
    parser.add_argument('--logger_name', type=str, default='Detection',
                        help="Logger name for open intent detection.")
    parser.add_argument('--log_dir', type=str, default='logs', help="Logger directory.")

    # Dataset selection and splitting.
    parser.add_argument("--dataset", default='banking', type=str,
                        help="The name of the dataset to train selected")
    parser.add_argument("--epsilon", default=1.0, type=float)
    parser.add_argument("--known_cls_ratio", default=0.75, type=float,
                        help="The ratio of known classes")
    parser.add_argument("--labeled_ratio", default=1.0, type=float,
                        help="The ratio of labeled samples in the training set")

    # Run mode switches.
    parser.add_argument("--train", action="store_true", help="Whether to train the model")
    parser.add_argument("--pretrain", action="store_true", help="Whether to pre-train the model")
    parser.add_argument("--save_model", action="store_true",
                        help="save trained-model for open intent detection")

    # Reproducibility and device.
    parser.add_argument('--seed', type=int, default=0, help="random seed for initialization")
    parser.add_argument("--gpu_id", type=str, default='0', help="Select the GPU id")

    # Directories. The data directory defaults to "<sys.path[0]>/data"; the
    # components are joined with os.path.join so a path separator is always
    # inserted between them (plain string concatenation with './data'
    # produced paths such as "/proj./data").
    parser.add_argument("--data_dir", default=os.path.join(sys.path[0], 'data'), type=str,
                        help="The input data dir. Should contain the .csv files "
                             "(or other data files) for the task.")
    parser.add_argument("--output_dir", default='./saved_models', type=str,
                        help="The output directory where all train data will be written.")
    parser.add_argument("--pretrain_model_dir", default='./', type=str,
                        help="The pretrain model directory.")
    parser.add_argument("--model_dir", default='models', type=str,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")
    # NOTE(review): this help text is identical to --model_dir's and looks
    # copy-pasted; confirm the intended meaning of this option and reword.
    parser.add_argument("--load_pretrained_method", default=None, type=str,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")

    # Result persistence.
    parser.add_argument("--result_dir", type=str, default='results',
                        help="The path to save results")
    parser.add_argument("--results_file_name", type=str, default='results.csv',
                        help="The file name of all the results.")
    parser.add_argument("--save_results", action="store_true",
                        help="save final results for open intent detection")

    # Optimisation hyper-parameters.
    parser.add_argument("--loss_fct", default="CrossEntropyLoss",
                        help="The loss function for training.")
    parser.add_argument("--dataset_neg", default="SQUAD", help="")
    parser.add_argument("--lr", default=2e-5, type=float)
    parser.add_argument("--num_train_epochs", default=50, type=int)
    parser.add_argument("--margin", default=1.0, type=float,
                        help="TripletLoss hyper parameter")

    return parser.parse_args(argv)
def set_logger(args):
    """Configure and return the logger named ``args.logger_name``.

    Two handlers are attached: a file handler that records DEBUG and above to
    a timestamped file inside ``args.log_dir``, and a console handler that
    emits INFO and above.

    Args:
        args: Object exposing ``log_dir``, ``dataset``, ``known_cls_ratio``,
            ``labeled_ratio`` and ``logger_name`` attributes.

    Returns:
        logging.Logger: The configured logger.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.log_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    file_name = f"{args.dataset}_{args.known_cls_ratio}_{args.labeled_ratio}_{timestamp}.log"

    logger = logging.getLogger(args.logger_name)
    logger.setLevel(logging.DEBUG)

    # File handler: full detail, with timestamps.
    fh = logging.FileHandler(os.path.join(args.log_dir, file_name))
    fh_formatter = logging.Formatter('%(asctime)s - %(name)s - %(message)s')
    fh.setFormatter(fh_formatter)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    # Console handler: INFO and above, without timestamps.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch_formatter = logging.Formatter('%(name)s - %(message)s')
    ch.setFormatter(ch_formatter)
    logger.addHandler(ch)

    # NOTE(review): every call adds a fresh pair of handlers to the same named
    # logger, so calling this more than once per process duplicates output.
    return logger
def run(args, data, text_encoder):
    """Build a DRTManager, optionally train it, test it and save the results.

    Args:
        args: Parsed options; this function reads ``logger_name``, ``train``,
            ``save_results``, ``result_dir`` and ``results_file_name``.
        data: Data object forwarded to the manager's constructor and to its
            ``train`` / ``test`` methods.
        text_encoder: Text representation model forwarded to the manager's
            constructor.

    Returns:
        None.
    """
    model = DRTManager(args, data, text_encoder, logger_name=args.logger_name)

    # Training is opt-in via --train; testing always runs.
    if args.train:
        print('training begin...')
        model.train(args, data)

    print('testing begin...')
    outputs = model.test(args, data)

    if args.save_results:
        save_results(args, outputs)
        # Report only after saving returned, and interpolate the path:
        # print() does not apply %-style formatting to its arguments.
        results_path = os.path.join(args.result_dir, args.results_file_name)
        print(f'Results saved in {results_path}')
if __name__ == '__main__':
    # Script entry point: make the working directory importable, resolve the
    # run configuration, build the data and encoder objects, then run.
    sys.path.append('.')

    # ParamManager receives the parsed CLI options and exposes the args
    # object that the rest of the pipeline consumes.
    cli_args = parse_arguments()
    param = ParamManager(cli_args)
    args = param.args
    # print(args)

    banner = '=' * 60
    print(banner)

    print('data loading')
    data = DataManager(args, logger_name=args.logger_name)

    print('text representation model loading')
    text_encoder = TextRepresentation(args, data, logger_name=args.logger_name)

    run(args, data, text_encoder)