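"""preprocess.py

Preprocess the PLOS / eLife scientific lay-summarisation data: split each
article into sentences, score every sentence against the reference summary
with ROUGE-Lsum, tokenize with a BART tokenizer, and export each split as
JSON lines to data/<dataset>/<split>.json.
"""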
import argparse
import json
import warnings
from functools import partial
from typing import Any

import datasets
from nltk import sent_tokenize
from rouge_score import rouge_scorer
from transformers import AutoTokenizer, PreTrainedTokenizerBase

warnings.filterwarnings('ignore')

def convert_to_features(
    examples: Any,
    tokenizer: PreTrainedTokenizerBase,
    padding: str,
    max_source_length: int,
    max_target_length: int,
    src_text_column_name: str,
    tgt_text_column_name: str,
):
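    """Batched `datasets.map` callback that builds the model inputs.

    Source sentences are joined with the global `bosent_token` and, after
    tokenization, the leading bos id is swapped for `bosent_token_id`, so every
    sentence starts with a <mask> token. Besides the usual `input_ids`,
    `attention_mask` and `labels`, the returned dict carries per-token fields:
    `info_distribution` (ROUGE-Lsum F1 of the token's sentence vs. the summary,
    -1 for eos/padding), `sent_id` (1-based sentence id, 0 for eos/padding) and
    `sentence_bos_index` (index of the first token of the token's sentence).
    Relies on the globals `bosent_token`, `bosent_token_id` and `rouge_scorer`
    set in the __main__ block.
    """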
    inputs, targets = [], []
    all_sent_rouge_scores = []
    for i in range(len(examples[src_text_column_name])):
        if examples[src_text_column_name][i] is not None and examples[tgt_text_column_name][i] is not None:
            input_sentences = sent_tokenize(examples[src_text_column_name][i])
            target_sentences = examples[tgt_text_column_name][i].strip()
            rouge_scores = []
            for sent in input_sentences:
                rouge_scores.append(rouge_scorer.score(target_sentences, sent)['rougeLsum'].fmeasure)
            # TODO: adding the sentence-begin token via a plain string join is unsafe
            inputs.append(bosent_token.join(input_sentences))
            targets.append(target_sentences.replace('\n', ' ').replace('  ', ' '))
            all_sent_rouge_scores.append(rouge_scores)
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
    # Replace bos_token_id at the beginning of each document with bosent_token_id
    for i in range(len(model_inputs['input_ids'])):
        model_inputs['input_ids'][i][0] = bosent_token_id
    # Assign each token the (0-based) id of the sentence it belongs to;
    # eos and padding tokens get -1.
    all_token_sent_id = []
    for sent_tokens in model_inputs['input_ids']:
        sid = -1
        token_sent_id = []
        for tid in sent_tokens:
            if tid == bosent_token_id:
                sid += 1
            if tid == tokenizer.eos_token_id or tid == tokenizer.pad_token_id:
                sid = -1
            token_sent_id.append(sid)
        all_token_sent_id.append(token_sent_id)
    # Expand the per-sentence ROUGE scores to per-token scores and record,
    # for each token, the index of the first token of its sentence.
    all_token_info_dist = []
    all_sent_bos_idx = []
    for token_sent_id, sent_rouge_scores in zip(all_token_sent_id, all_sent_rouge_scores):
        sent_rouge_scores = sent_rouge_scores[:max(token_sent_id) + 1]  # drop sentences lost to truncation
        sent_bos_idx = []
        token_info_dist = []
        bos_idx = 0
        for sid in range(max(token_sent_id) + 1):
            tnum = token_sent_id.count(sid)
            sent_score = sent_rouge_scores[sid]
            token_info_dist.extend([sent_score for _ in range(tnum)])
            sent_bos_idx.extend([bos_idx for _ in range(tnum)])
            bos_idx += tnum
        token_info_dist.extend([-1 for _ in range(token_sent_id.count(-1))])
        all_token_info_dist.append(token_info_dist)
        sent_bos_idx.extend([0 for _ in range(token_sent_id.count(-1))])
        all_sent_bos_idx.append(sent_bos_idx)
    # Shift sentence ids so real sentences start at 1 and eos/padding tokens map to 0.
    for i in range(len(all_token_sent_id)):
        for j in range(len(all_token_sent_id[i])):
            all_token_sent_id[i][j] += 1

    model_inputs['info_distribution'] = all_token_info_dist
    model_inputs['sentence_bos_index'] = all_sent_bos_idx
    model_inputs['sent_id'] = all_token_sent_id

    # Set up the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def get_args():
    parser = argparse.ArgumentParser(description='Preprocess a scientific lay-summarisation dataset (PLOS or eLife)')
    parser.add_argument('--dataset', '-d', type=str, default='plos', help='dataset name')
    parser.add_argument('--tokenizer', '-t', type=str, default='facebook/bart-large', help='tokenizer name')
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
    # if args.dataset == 'plos':
    #     dataset = datasets.load_dataset("parquet", data_files={'train': ['data/plos/train/0000.parquet', 'data/plos/train/0001.parquet'],
    #                                                            'validation': 'data/plos/validation/0000.parquet',
    #                                                            'test': 'data/plos/test/0000.parquet'})
    # elif args.dataset == 'elife':
    #     dataset = datasets.load_dataset("parquet", data_files={'train': 'data/elife/train/0000.parquet',
    #                                                            'validation': 'data/elife/validation/0000.parquet',
    #                                                            'test': 'data/elife/test/0000.parquet'})
    # else:
    #     raise SystemExit("Dataset unavailable")
    dataset = datasets.load_dataset("tomasg25/scientific_lay_summarisation", f"{args.dataset}")
    src_text_column_name, tgt_text_column_name = "article", "summary"
    max_source_length, max_target_length = 1024, 128
    n_proc = 16
    # Since bos_token is used as the beginning of the target sequence,
    # we use mask_token to represent the beginning of each source sentence.
    bosent_token = "<mask>"
    bosent_token_id = 50264  # id of <mask> in the facebook/bart-large vocabulary (tokenizer.mask_token_id)
    rouge_scorer = rouge_scorer.RougeScorer(['rougeLsum'], use_stemmer=True)
    tokenizer = AutoTokenizer.from_pretrained(f"{args.tokenizer}", use_fast=False)
    convert_to_features = partial(
        convert_to_features,
        tokenizer=tokenizer,
        padding='max_length',
        max_source_length=max_source_length,
        max_target_length=max_target_length,
        src_text_column_name=src_text_column_name,
        tgt_text_column_name=tgt_text_column_name,
    )
    dataset = dataset.map(
        convert_to_features,
        batched=True,
        num_proc=n_proc,
    )
    cols_to_keep = ["input_ids", "attention_mask", "labels", "info_distribution", "sentence_bos_index", "sent_id"]
    dataset.set_format(columns=cols_to_keep)
    # Write each split as JSON lines to data/<dataset>/<split>.json
    for split in ['train', 'validation', 'test']:
        with open(f'data/{args.dataset}/{split}.json', 'w') as outfile:
            for example in dataset[split]:
                json_string = json.dumps(example)
                outfile.write(json_string + '\n')
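
# Example invocation (a sketch: assumes the NLTK 'punkt' sentence-tokenizer data
# is installed and that data/<dataset>/ already exists for the JSONL output):
#   python preprocess.py --dataset plos --tokenizer facebook/bart-large
#   python preprocess.py --dataset elife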