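"""multi_language_dataset_creator.py

Build parallel train/validation datasets from Wiki40B for a set of languages:
load the per-language Wiki40B splits, optionally keep only articles whose
Wikidata IDs appear in every requested language (--common_articles), balance
the per-language token counts with the GPT-2 tokenizer from tiktoken, and
write the results to JSON files.
"""
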
import argparse
import json
import os
from datetime import datetime

import tensorflow_datasets as tfds
import tiktoken

def load_wiki40b(langs):
    """
    Load the wiki40b datasets for the specified languages.

    Args:
        langs (list): List of language codes.

    Returns:
        dict: Dictionary of datasets for each language.
    """
    print("Step 1/7: Loading Wiki40B datasets...")
    datasets = {}
    for lang in langs:
        datasets[lang] = {
            'train': tfds.load(f'wiki40b/{lang}', split='train'),
            'validation': tfds.load(f'wiki40b/{lang}', split='validation')
        }
    print("Step 1/7: Completed loading Wiki40B datasets.")
    return datasets

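# Illustrative usage (assumes the Wiki40B data is available to TFDS locally or can be
# downloaded): load_wiki40b(['en', 'fr'])['en']['train'] is a tf.data.Dataset over the
# English training articles.
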
def extract_wikidata_ids(datasets):
    """
    Extract Wikidata IDs for both training and validation sets for each language.

    Args:
        datasets (dict): Dictionary of datasets.

    Returns:
        dict: Dictionary of Wikidata IDs for each language and split.
    """
    print("Step 2/7: Extracting Wikidata IDs...")
    wikidata_ids = {lang: {'train': set(), 'validation': set()} for lang in datasets}
    for lang, splits in datasets.items():
        for split in splits:
            for example in tfds.as_numpy(datasets[lang][split]):
                wikidata_ids[lang][split].add(example['wikidata_id'])
    print("Step 2/7: Completed extracting Wikidata IDs.")
    return wikidata_ids

def find_common_ids(wikidata_ids, langs):
    """
    Identify common Wikidata IDs across all languages for both training and validation sets.

    Args:
        wikidata_ids (dict): Dictionary of Wikidata IDs for each language and split.
        langs (list): List of language codes.

    Returns:
        dict: Dictionary of common IDs for training and validation sets.
    """
    print("Step 3/7: Finding common Wikidata IDs...")
    common_ids = {split: set.intersection(*[wikidata_ids[lang][split] for lang in langs])
                  for split in ['train', 'validation']}
    print("Step 3/7: Completed finding common Wikidata IDs.")
    return common_ids

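# Sketch with made-up IDs: if the English train IDs are {'Q1', 'Q2'} and the French
# train IDs are {'Q2', 'Q3'}, then common_ids['train'] is {'Q2'}; the same intersection
# is taken independently for the validation split.
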
def build_final_datasets(datasets, common_ids, langs):
    """
    Retrieve the corresponding articles for each common ID in both training and validation sets.

    Args:
        datasets (dict): Dictionary of datasets for each language.
        common_ids (dict): Dictionary of common IDs for training and validation sets.
        langs (list): List of language codes.

    Returns:
        dict: Dictionary of final datasets for training and validation sets.
    """
    print("Step 4/7: Building final datasets...")
    final_datasets = {'train': [], 'validation': []}
    dataset_hashmaps = {lang: {split: {} for split in ['train', 'validation']} for lang in langs}
    for lang in langs:
        for split in ['train', 'validation']:
            for example in tfds.as_numpy(datasets[lang][split]):
                dataset_hashmaps[lang][split][example['wikidata_id']] = example
    for split in common_ids:
        for common_id in common_ids[split]:
            article = {lang: dataset_hashmaps[lang][split][common_id] for lang in langs}
            final_datasets[split].append(article)
    print("Step 4/7: Completed building final datasets.")
    return final_datasets

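# Each entry of final_datasets['train'] has the shape {lang: example}, schematically
# {'en': {'wikidata_id': b'...', 'text': b'...', ...}, 'fr': {...}}, so the same article
# can be looked up in every requested language.
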
def decode_bytes(data):
    """
    Decode byte strings in the dataset to regular strings.

    Args:
        data: Data to decode.

    Returns:
        Decoded data.
    """
    if isinstance(data, bytes):
        return data.decode('utf-8')
    if isinstance(data, dict):
        return {k: decode_bytes(v) for k, v in data.items()}
    if isinstance(data, list):
        return [decode_bytes(v) for v in data]
    return data

def deep_decode(obj):
    """
    Recursively decode all byte objects within a given data structure.

    Args:
        obj: The data structure to decode.

    Returns:
        The decoded data structure.
    """
    if isinstance(obj, dict):
        return {k: deep_decode(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [deep_decode(elem) for elem in obj]
    elif isinstance(obj, bytes):
        return obj.decode('utf-8')
    else:
        return obj

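# Example (hypothetical input): deep_decode({'text': b'Bonjour', 'meta': [b'fr']})
# returns {'text': 'Bonjour', 'meta': ['fr']}.
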
def export_data(final_datasets, output_dir, langs):
    """
    Export both training and validation datasets to JSON files.

    Args:
        final_datasets (dict): Dictionary of final datasets for training and validation sets.
        output_dir (str): Output directory to save datasets.
        langs (list): List of language codes.
    """
    print("Step 5/7: Exporting data to JSON files...")
    final_datasets_decoded = deep_decode(final_datasets)
    langs_str = '-'.join(langs)
    for split in final_datasets_decoded:
        file_name = f'wiki40b_{langs_str}_{split}.json'
        with open(os.path.join(output_dir, file_name), 'w', encoding='utf-8') as file:
            json.dump(final_datasets_decoded[split], file, ensure_ascii=False, indent=4)
    print("Step 5/7: Completed exporting data to JSON files.")

def calculate_token_counts(dataset):
    """
    Calculate the total token count for a dataset.

    Args:
        dataset (list): List of articles.

    Returns:
        int: Total token count.
    """
    encoder = tiktoken.get_encoding("gpt2")
    return sum(len(encoder.encode(article['text'])) for article in dataset)

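# Rough illustration (exact counts depend on the GPT-2 BPE vocabulary):
# calculate_token_counts([{'text': 'Hello world'}]) equals
# len(tiktoken.get_encoding("gpt2").encode('Hello world')).
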
def enforce_token_limit(datasets, max_diff_percent):
    """
    Enforce a token limit so that all datasets have a similar size in terms of tokens within a specified range.

    Args:
        datasets (dict): Dictionary of datasets for each split.
        max_diff_percent (float): Maximum allowed percentage difference in token counts.

    Returns:
        dict: Adjusted datasets.
    """
    print("Step 6/7: Enforcing token limits...")
    encoder = tiktoken.get_encoding("gpt2")

    # Helper function to decode text if it's in bytes
    def decode_text(article):
        for lang in article:
            if isinstance(article[lang]['text'], bytes):
                article[lang]['text'] = article[lang]['text'].decode('utf-8')
        return article

    # Calculate token counts for each language in the 'train' split
    token_counts = {lang: calculate_token_counts([decode_text(article)[lang] for article in datasets['train']])
                    for lang in datasets['train'][0].keys()}
    min_tokens = min(token_counts.values())
    max_tokens = min_tokens * (1 + max_diff_percent / 100)

    # Adjust training datasets to enforce the token limit
    adjusted_datasets = {'train': [], 'validation': datasets['validation']}
    token_counts = {lang: 0 for lang in datasets['train'][0].keys()}
    for article in datasets['train']:
        article = decode_text(article)
        article_tokens = {lang: len(encoder.encode(article[lang]['text'])) for lang in article.keys()}
        # Check if adding this article would exceed the max token limit for any language
        if all(token_counts[lang] + article_tokens[lang] <= max_tokens for lang in article.keys()):
            adjusted_datasets['train'].append(article)
            for lang in article.keys():
                token_counts[lang] += article_tokens[lang]
    print("Step 6/7: Completed enforcing token limits.")
    return adjusted_datasets

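# Note: the balancing is greedy. Articles are visited in iteration order, and an article
# is kept only if it would not push any language past max_tokens, i.e. the smallest
# language's full training total scaled by (1 + max_diff_percent / 100).
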
def split_and_save_data(datasets, output_dir, langs):
    """
    Save the training and validation datasets to JSON files.

    Args:
        datasets (dict): Dictionary of datasets for each split.
        output_dir (str): Output directory to save datasets.
        langs (list): List of language codes.

    Returns:
        tuple: New training and validation data.
    """
    print("Step 7/7: Splitting and saving data...")
    # Create the output directory with the specified format
    date_str = datetime.now().strftime("%d.%m.%Y")
    langs_str = '-'.join(langs)
    output_dir = os.path.join(output_dir, f'datasets_{langs_str}_{date_str}')
    os.makedirs(output_dir, exist_ok=True)
    # Deep decode to ensure all text fields are decoded
    datasets = deep_decode(datasets)
    for split in datasets:
        file_name = f'wiki40b_{langs_str}_{split}.json'
        with open(os.path.join(output_dir, file_name), 'w', encoding='utf-8') as file:
            json.dump(datasets[split], file, ensure_ascii=False, indent=4)
    print("Step 7/7: Completed splitting and saving data.")
    return datasets['train'], datasets['validation']

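# Resulting layout (example): <output_dir>/datasets_en-fr_<DD.MM.YYYY>/wiki40b_en-fr_train.json
# and wiki40b_en-fr_validation.json, where the language tag follows --langs.
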
def main(args):
    print("Starting the dataset creation process...")
    langs = args.langs.split(',')
    output_dir = args.output_dir
    max_diff_percent = args.max_diff_percent

    datasets = load_wiki40b(langs)
    wikidata_ids = extract_wikidata_ids(datasets)
    common_ids = find_common_ids(wikidata_ids, langs)

    if args.common_articles:
        final_datasets = build_final_datasets(datasets, common_ids, langs)
    else:
        # Without --common_articles the per-language article lists are independent, so pair
        # them positionally (zip truncates to the shortest list) to give enforce_token_limit
        # the {'train': [{lang: example}, ...], 'validation': [...]} structure it expects.
        per_lang = {lang: {split: [decode_bytes(article) for article in tfds.as_numpy(datasets[lang][split])]
                           for split in ['train', 'validation']} for lang in langs}
        final_datasets = {split: [dict(zip(langs, articles))
                                  for articles in zip(*(per_lang[lang][split] for lang in langs))]
                          for split in ['train', 'validation']}

    adjusted_datasets = enforce_token_limit(final_datasets, max_diff_percent)
    adjusted_datasets = deep_decode(adjusted_datasets)  # Ensure all text is decoded after adjusting
    split_and_save_data(adjusted_datasets, output_dir, langs)
    print("Dataset creation process completed successfully.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Create a multi-language dataset.')
    parser.add_argument('--langs', type=str, default='en,fr,de',
                        help='Comma-separated list of languages. (e.g. en,fr,de)')
    parser.add_argument('--common_articles', action='store_true',
                        help='Force samples from different languages to be from the same articles.')
    parser.add_argument('--output_dir', type=str, default='.',
                        help='Output directory to save datasets.')
    parser.add_argument('--max_diff_percent', type=float, default=5.0,
                        help='Maximum allowed percentage difference in token counts.')
    args = parser.parse_args()
    main(args)
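
# Example invocations (illustrative; the listed languages must exist as wiki40b/<lang>
# configurations in TFDS):
#   python multi_language_dataset_creator.py --langs en,fr,de --common_articles --output_dir ./data
#   python multi_language_dataset_creator.py --langs en,fr --max_diff_percent 2.5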