-
Notifications
You must be signed in to change notification settings - Fork 499
/
convert_dataset.py
69 lines (58 loc) · 2.61 KB
/
convert_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Convert alpaca dataset into sharegpt format.
Usage: python convert_dataset.py --in_file alpaca_data.json --out_file alpaca_data_sharegpt.jsonl
"""
import argparse
from datasets import load_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in_file", type=str, required=True)
parser.add_argument("--out_file", type=str, required=True)
parser.add_argument("--data_type", type=str, default='alpaca', help="alpaca, qa, or sharegpt")
parser.add_argument("--file_type", type=str, default='json', help='input file type, json or csv')
args = parser.parse_args()
print(args)
data_files = {"train": args.in_file}
if args.file_type == 'csv':
if args.data_type in ['qa']:
column_names = ['input', 'output']
else:
column_names = ['instruction', 'input', 'output']
raw_datasets = load_dataset('csv', data_files=data_files, column_names=column_names, delimiter='\t')
elif args.file_type in ['json', 'jsonl']:
raw_datasets = load_dataset('json', data_files=data_files)
else:
raise ValueError("File type not supported")
ds = raw_datasets['train']
def process_qa(examples):
convs = []
for q, a in zip(examples['input'], examples['output']):
convs.append([
{"from": "human", "value": q},
{"from": "gpt", "value": a}
])
return {"conversations": convs}
def process_alpaca(examples):
convs = []
for instruction, inp, output in zip(examples['instruction'], examples['input'], examples['output']):
if inp and len(inp.strip()) > 0:
instruction = instruction + '\n\n' + inp
q = instruction
a = output
convs.append([
{"from": "human", "value": q},
{"from": "gpt", "value": a}
])
return {"conversations": convs}
if args.data_type in ['alpaca']:
ds = ds.map(process_alpaca, batched=True, remove_columns=ds.column_names, desc="Running process")
elif args.data_type in ['qa']:
ds = ds.map(process_qa, batched=True, remove_columns=ds.column_names, desc="Running process")
else:
# Other sharegpt dataset, need rename to conversations and remove unused columns
if "items" in ds.column_names:
ds = ds.rename(columns={"items": "conversations"})
columns_to_remove = ds.column_names.copy()
columns_to_remove.remove('conversations')
ds = ds.remove_columns(columns_to_remove)
ds.to_json(f"{args.out_file}", lines=True, force_ascii=False)