-
Notifications
You must be signed in to change notification settings - Fork 2
/
prepare.py
42 lines (31 loc) · 1.26 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from utils import format_labels
from utils import io
import sys
import pdb
def prepare(data_id, cfg_path='./config.yml'):
    """Build the preprocessed dataset for ``data_id`` and pickle it.

    The configuration returned by ``io.load_yml(cfg_path, data_id)`` is
    read for these keys: ``data_file``, ``text_col``, ``label_col``,
    ``sep``, ``pkl_file`` and, optionally, ``add_col`` (a list of extra
    column names to keep).

    Steps:
      1. Load the CSV at ``cfg['data_file']`` and rename the text column
         to ``'text'``.
      2. Keep only the extra columns (if any), ``'text'`` and the label
         column; drop rows with a missing text/label and duplicate rows.
      3. Add the derived columns ``seq_length``, ``label``, ``str_label``
         and ``one_hot_labels``.
      4. Write the result to ``cfg['pkl_file']`` via ``io.to_pickle``.

    Args:
        data_id: Identifier passed to ``io.load_yml`` together with
            ``cfg_path`` (presumably selects one dataset section of the
            config file -- confirm against ``utils.io``).
        cfg_path: Path of the YAML configuration file.

    Returns:
        None. The result is only written to disk.

    Note:
        ``io`` here is the project's ``utils.io`` module, not the standard
        library ``io``. The loaded data is assumed to be a pandas
        DataFrame (it is used through the DataFrame API throughout).
    """
    print('start preparing')
    cfg = io.load_yml(cfg_path, data_id)

    data = io.load_csv(cfg['data_file'])
    data = data.rename(columns={cfg['text_col']: 'text'})
    label_col = cfg['label_col']

    columns = ['text', label_col]
    if 'add_col' in cfg:
        columns = cfg['add_col'] + columns
    # A name listed twice (e.g. 'text' also present in add_col) would give
    # duplicate columns, making data[label_col] / data['text'] return a
    # DataFrame instead of a Series; keep the first occurrence only.
    columns = list(dict.fromkeys(columns))

    # Non-inplace chain ending in an explicit copy, so the column
    # assignments below act on an independent frame rather than on a
    # possible view of the loaded data.
    data = (
        data[columns]
        .dropna(subset=['text', label_col])
        .drop_duplicates()
        .copy()
    )

    # Number of whitespace-separated tokens in each text.
    data['seq_length'] = data['text'].map(str.split).apply(len)

    # Each raw label value goes through format_labels.sort with the
    # configured separator (presumably split on `sep` and sorted -- see
    # utils.format_labels).
    data['label'] = data[label_col].apply(format_labels.sort,
                                          args=(cfg['sep'],))
    data['str_label'] = data['label'].apply(format_labels.join)

    # The label vocabulary is built over the whole dataset, then each
    # row's label is encoded against it.
    unique_labels = format_labels.get_unique(data['label'].tolist())
    data['one_hot_labels'] = data['label'].apply(format_labels.encode_onehot,
                                                 args=(unique_labels,))

    io.to_pickle(data, cfg['pkl_file'])
if __name__ == '__main__':
    # Command line: prepare.py DATA_ID [CFG_PATH]
    if len(sys.argv) < 2:
        # Exit with a usage hint instead of an IndexError traceback when
        # the mandatory data id is missing.
        sys.exit('usage: {} DATA_ID [CFG_PATH]'.format(sys.argv[0]))

    DATA_ID = sys.argv[1]
    if len(sys.argv) > 2:
        cfg_path = sys.argv[2]
        prepare(DATA_ID, cfg_path)
    else:
        # No config path given: rely on prepare()'s default.
        prepare(DATA_ID)