-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmelody_pipeline.py
executable file
·86 lines (76 loc) · 3.35 KB
/
melody_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from melody_lib import MelodySequence, extract_melodies, extract_melodies_for_info
from magenta.protobuf import music_pb2
from magenta.pipelines import pipeline
from magenta.pipelines import statistics
from melody_lib import MAX_EVENT_LENGTH
class MelodyExtractor(pipeline.Pipeline):
    """Pipeline stage that turns a NoteSequence into a list of melodies.

    Declared input type is ``music_pb2.NoteSequence`` and declared output type
    is ``MelodySequence``. Each ``transform`` call delegates to
    ``melody_lib.extract_melodies`` with the thresholds given at construction.
    If extraction raises, the sequence is skipped (empty result) rather than
    aborting the run.
    """

    def __init__(self,
                 min_unique_pitches=5,
                 max_melody_events=MAX_EVENT_LENGTH,
                 min_melody_events=9,
                 filter_drums=True,
                 name=None):
        """Store the parameters that ``transform`` forwards to extract_melodies.

        Args:
            min_unique_pitches: Forwarded as ``min_unique_pitches``; presumably
                the minimum number of distinct pitches a kept melody needs.
            max_melody_events: Forwarded as ``max_melody_events``
                (defaults to ``melody_lib.MAX_EVENT_LENGTH``).
            min_melody_events: Forwarded as ``min_melody_events``.
            filter_drums: Forwarded as ``filter_drums``.
            name: Optional pipeline name passed to ``pipeline.Pipeline``.
        """
        super(MelodyExtractor, self).__init__(
            input_type=music_pb2.NoteSequence,
            output_type=MelodySequence,
            name=name)
        self._min_unique_pitches = min_unique_pitches
        self._max_melody_events = max_melody_events
        self._min_melody_events = min_melody_events
        self._filter_drums = filter_drums

    def transform(self, quantized_sequence):
        """Extract melodies from a single sequence.

        Args:
            quantized_sequence: The NoteSequence to process (assumed to be
                quantized already, per the parameter name -- confirm upstream).

        Returns:
            The list of melodies returned by ``extract_melodies``, or an empty
            list if extraction raised any exception.
        """
        try:
            melodies, stats = extract_melodies(
                quantized_sequence,
                # Previously this forwarded self._min_melody_events, which
                # silently ignored the min_unique_pitches constructor argument.
                min_unique_pitches=self._min_unique_pitches,
                max_melody_events=self._max_melody_events,
                min_melody_events=self._min_melody_events,
                filter_drums=self._filter_drums)
        except Exception as e:
            # Best-effort boundary: a failure on one sequence should not stop
            # the whole pipeline, so report it and count it instead.
            print('Skipped sequence:', str(e))
            melodies = []
            # NOTE: the counter name is kept verbatim (including the spelling)
            # because downstream stats consumers may key on it.
            stats = [statistics.Counter('unknow_error', 1)]
        self._set_stats(stats)
        return melodies
class MelodyExtractorInfo(pipeline.Pipeline):
    """Pipeline stage used to gather information over the whole dataset.

    Same configuration as ``MelodyExtractor``, but ``transform`` delegates to
    ``melody_lib.extract_melodies_for_info`` and does NOT catch exceptions:
    any error raised during extraction propagates to the caller.
    """

    def __init__(self,
                 min_unique_pitches=5,
                 max_melody_events=MAX_EVENT_LENGTH,
                 min_melody_events=9,
                 filter_drums=True,
                 name=None):
        """Store the parameters forwarded to extract_melodies_for_info.

        Args:
            min_unique_pitches: Forwarded as ``min_unique_pitches``; presumably
                the minimum number of distinct pitches a kept melody needs.
            max_melody_events: Forwarded as ``max_melody_events``
                (defaults to ``melody_lib.MAX_EVENT_LENGTH``).
            min_melody_events: Forwarded as ``min_melody_events``.
            filter_drums: Forwarded as ``filter_drums``.
            name: Optional pipeline name passed to ``pipeline.Pipeline``.
        """
        super(MelodyExtractorInfo, self).__init__(
            input_type=music_pb2.NoteSequence,
            output_type=MelodySequence,
            name=name)
        self._min_unique_pitches = min_unique_pitches
        self._max_melody_events = max_melody_events
        self._min_melody_events = min_melody_events
        self._filter_drums = filter_drums

    def transform(self, quantized_sequence):
        """Extract melodies and info statistics from a single sequence.

        Unlike ``MelodyExtractor.transform`` there is no exception handling
        here, so extraction errors surface instead of being counted.

        Args:
            quantized_sequence: The NoteSequence to process (assumed to be
                quantized already, per the parameter name -- confirm upstream).

        Returns:
            The list of melodies returned by ``extract_melodies_for_info``.
        """
        melodies, stats = extract_melodies_for_info(
            quantized_sequence,
            # Previously this forwarded self._min_melody_events, which
            # silently ignored the min_unique_pitches constructor argument.
            min_unique_pitches=self._min_unique_pitches,
            max_melody_events=self._max_melody_events,
            min_melody_events=self._min_melody_events,
            filter_drums=self._filter_drums)
        self._set_stats(stats)
        return melodies