# Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
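"""Dataset and collation utilities for SALMONN: load annotated audio clips,
extract Whisper log-mel spectrograms, and batch variable-length waveforms."""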
import json
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import soundfile as sf
import numpy as np
from transformers import WhisperFeatureExtractor


class SALMONNDataset(Dataset):
    def __init__(self, ann_path, whisper_path):
        super().__init__()
        # The annotation file is a JSON object whose "annotation" key holds a
        # list of sample dicts.
        with open(ann_path, "r") as f:
            self.annotation = json.load(f)["annotation"]
        self.wav_processor = WhisperFeatureExtractor.from_pretrained(whisper_path)
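
    # Illustrative annotation entry (fields inferred from __getitem__;
    # "task", "Q", and "expand_wav" are optional):
    #   {"path": "audio/utt1.wav", "text": "hello world", "task": "asr",
    #    "Q": "", "expand_wav": ["audio/utt2.wav"]}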

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        # Stack the fixed-size Whisper spectrograms into a (B, n_mels, T) tensor.
        samples_spectrogram = [s["spectrogram"] for s in samples]
        cat_spectrogram = torch.stack(samples_spectrogram, dim=0)

        # Right-pad the variable-length waveforms and build a boolean mask in
        # which True marks padded positions.
        raw_wav = [torch.from_numpy(s["raw_wav"]) for s in samples]
        raw_wav_length = torch.tensor([len(s["raw_wav"]) for s in samples])
        raw_wav = pad_sequence(raw_wav, batch_first=True, padding_value=0)
        padding_mask = torch.arange(raw_wav.size(1)).unsqueeze(0) >= raw_wav_length.unsqueeze(1)

        text = [s["text"] for s in samples]
        task = [s["task"] for s in samples]
        Q = [s["Q"] for s in samples]
        ids = [s["id"] for s in samples]

        return {
            "spectrogram": cat_spectrogram,
            "raw_wav": raw_wav,
            "padding_mask": padding_mask,
            "text": text,
            "task": task,
            "Q": Q,
            "id": ids,
        }

    def __getitem__(self, index):
        ann = self.annotation[index]
        audio, sr = sf.read(ann["path"])
        if len(audio.shape) == 2:  # stereo to mono: keep the first channel
            audio = audio[:, 0]
        if "expand_wav" in ann:
            # Append any extra clips, separated by 1600 samples of silence
            # (0.1 s at a 16 kHz sample rate).
            for p in ann["expand_wav"]:
                expand_audio, _ = sf.read(p)
                if len(expand_audio.shape) == 2:
                    expand_audio = expand_audio[:, 0]
                sil = np.zeros(1600, dtype=float)
                audio = np.concatenate((audio, sil, expand_audio), axis=0)
        if len(audio) < sr:  # pad audio to at least 1 s
            sil = np.zeros(sr - len(audio), dtype=float)
            audio = np.concatenate((audio, sil), axis=0)
        audio = audio[: sr * 30]  # truncate audio to at most 30 s
        spectrogram = self.wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"].squeeze()
        text = ann["text"]
        task = ann.get("task", "asr")
        Q = ann.get("Q", "")
        return {
            "spectrogram": spectrogram,
            "raw_wav": audio,
            "text": text,
            "task": task,
            "Q": Q,
            "id": ann["path"],
        }
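

# Minimal usage sketch: the annotation path and Whisper checkpoint name below
# are placeholders, and batching relies on the custom collater defined above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = SALMONNDataset("data/annotation.json", "openai/whisper-large-v2")
    loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collater)

    batch = next(iter(loader))
    # spectrogram: (B, n_mels, T); raw_wav and padding_mask: (B, L_max)
    print(batch["spectrogram"].shape, batch["raw_wav"].shape, batch["padding_mask"].shape)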