-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.py
57 lines (40 loc) · 1.4 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
from typing import Optional, Tuple, Iterable
import pandas as pd
import constants
from midi import MIDI
# File name of the metadata CSV; Dataset joins it onto constants.maestro_path
# and reads it with pandas.
CSV_FILE = 'maestro-v3.0.0.csv'
# Shape of one dataset split as returned by Dataset: (titles, MIDI file paths).
TSplit = Tuple[Iterable[str], Iterable[str]]
class Dataset:
    """Metadata index of the MAESTRO CSV, optionally narrowed to one composer.

    The CSV named by ``CSV_FILE`` is read from ``constants.maestro_path``.
    Each split (``train``, ``test``, ``validation``) is exposed as a
    ``(titles, midi_paths)`` pair of pandas Series.
    """

    def __init__(self,
                 composer: Optional[str] = None,
                 ):
        """Read the metadata CSV and optionally keep one composer's rows.

        Args:
            composer: Pattern matched case-insensitively against the
                ``canonical_composer`` column using ``Series.str.contains``
                (so regular-expression syntax is honoured). ``None`` keeps
                every row.
        """
        self.index = pd.read_csv(os.path.join(constants.maestro_path, CSV_FILE))
        if composer is not None:
            # na=False: a row with a missing composer counts as "no match".
            # Without it the mask can contain NaN, and boolean indexing with
            # such a mask raises instead of filtering.
            matches = self.index['canonical_composer'].str.contains(
                composer, case=False, na=False)
            self.index = self.index[matches]

    def _iter_split(self, split: str) -> TSplit:
        """Return ``(titles, midi_paths)`` for the rows whose split is *split*.

        Args:
            split: Value compared for equality against the ``split`` column.

        Returns:
            The ``canonical_title`` column of the selected rows, and their
            ``midi_filename`` values joined onto ``constants.maestro_path``.
        """
        dataset = self.index[self.index['split'] == split]
        midi_paths = dataset['midi_filename'].map(
            lambda path: os.path.join(constants.maestro_path, path))
        return dataset['canonical_title'], midi_paths

    @property
    def train(self) -> TSplit:
        """Titles and MIDI paths of the rows marked ``'train'``."""
        return self._iter_split('train')

    @property
    def test(self) -> TSplit:
        """Titles and MIDI paths of the rows marked ``'test'``."""
        return self._iter_split('test')

    @property
    def validation(self) -> TSplit:
        """Titles and MIDI paths of the rows marked ``'validation'``."""
        return self._iter_split('validation')
def main():
    """Print per-split file counts for 'Bach', then print each training MIDI.

    Every path of the training split is passed to ``MIDI().from_midi`` and
    the returned value is printed.
    """
    ds = Dataset(composer='Bach')

    # (printed label, Dataset property providing that split); the validation
    # split is reported under the shorter label 'val'.
    reported = (('train', 'train'), ('test', 'test'), ('val', 'validation'))
    for label, attr in reported:
        _, split_paths = getattr(ds, attr)
        print(label, len(list(split_paths)))

    # Load files
    train_paths = ds.train[1]
    # One MIDI instance is created and its bound method reused for every path.
    load = MIDI().from_midi
    for path in train_paths:
        print(load(path))


# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()