-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconvert_line_aligned_to_jsonl.py
90 lines (72 loc) · 3 KB
/
convert_line_aligned_to_jsonl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import argparse
from pathlib import Path
import pandas as pd
import json
from tqdm import tqdm
import random
def set_args():
ap = argparse.ArgumentParser()
ap.add_argument('--data_dir', type=Path)
ap.add_argument('--out_dir', type=Path, default=None)
ap.add_argument('--src_suffix', type=str, default='complex')
ap.add_argument('--tgt_suffix', type=str, default='simple')
ap.add_argument('--splits', type=str, nargs='*', default=['train', 'test', 'valid'])
ap.add_argument('--levels', type=str, nargs='*', default=[1, 2, 3, 4])
ap.add_argument('--label_src', action="store_true")
ap.add_argument('--dataset', type=str, choices=['muss', 'newsela_manual', 'newsela_auto'])
return ap.parse_args()
def write_to_json(src_lines, tgt_lines, outfile):
with open(outfile, 'w', encoding='utf8') as outf:
for s, t in zip(src_lines, tgt_lines):
outf.write(json.dumps({'complex': s, 'simple': t}, ensure_ascii=False) + '\n')
def read_lines(file):
with open(file, 'r', encoding='utf8') as inf:
lines = []
for line in tqdm(inf):
lines.append(line.strip())
return lines
def convert_paraphrase_data_to_jsonl(args):
for split in args.splits:
src_lines = read_lines(args.data_dir / f'{split}.{args.src_suffix}')
tgt_lines = read_lines(args.data_dir / f'{split}.{args.tgt_suffix}')
if args.out_dir is not None:
write_to_json(src_lines, tgt_lines, args.out_dir / f'{split}.json')
else:
write_to_json(src_lines, tgt_lines, args.data_dir / f'{split}.json')
print(f'Wrote {len(src_lines)} lines to output file')
return
def convert_newsela_data_to_jsonl(args):
for split in args.splits:
src_lines = []
tgt_lines = []
for level in args.levels:
infile = args.data_dir / f'{args.dataset}_v0_v{level}_{split}.tsv'
with open(infile, 'r', encoding='utf8') as inf:
for line in inf:
src, tgt = line.strip().split('\t')
if args.label_src:
src_lines.append(f'<l{level}> {src}')
else:
src_lines.append(src)
tgt_lines.append(tgt)
assert len(src_lines) == len(tgt_lines)
if split == 'train': # shuffle
print(len(src_lines))
c = set(zip(src_lines, tgt_lines))
random.seed(4)
c = random.sample(c, len(c))
src_lines, tgt_lines = zip(*c)
print(len(src_lines))
if args.out_dir is not None:
write_to_json(src_lines, tgt_lines, args.out_dir / f'{split}.json')
return
if __name__ == "__main__":
args = set_args()
Path(args.out_dir).mkdir(parents=True, exist_ok=True)
if args.dataset in ['newsela_manual', 'newsela_auto']:
convert_newsela_data_to_jsonl(args)
elif args.dataset == 'muss':
convert_paraphrase_data_to_jsonl(args)