From da74e2dfb0ae565d6d4e3a6f5a80d5308a4407eb Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 17:15:16 +0000 Subject: [PATCH] Port MIDI utils and tests (#1) * add skeleton * port midi.py * update path for maestro metadata json * add tests and ci --- .github/workflows/python-ci.yml | 35 + README.md | 7 +- ariautils/ __init__.py | 0 ariautils/config/config.json | 242 +++++++ ariautils/midi.py | 1150 +++++++++++++++++++++++++++++++ ariautils/utils/__init__.py | 39 ++ ariautils/utils/config.py | 17 + pyproject.toml | 68 ++ tests/__init__.py | 0 tests/assets/data/arabesque.mid | Bin 0 -> 16975 bytes tests/assets/results/.gitkeep | 0 tests/test_midi.py | 85 +++ 12 files changed, 1641 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/python-ci.yml create mode 100644 ariautils/ __init__.py create mode 100644 ariautils/config/config.json create mode 100644 ariautils/midi.py create mode 100644 ariautils/utils/__init__.py create mode 100644 ariautils/utils/config.py create mode 100644 pyproject.toml create mode 100644 tests/__init__.py create mode 100644 tests/assets/data/arabesque.mid create mode 100644 tests/assets/results/.gitkeep create mode 100644 tests/test_midi.py diff --git a/.github/workflows/python-ci.yml b/.github/workflows/python-ci.yml new file mode 100644 index 0000000..e66bd07 --- /dev/null +++ b/.github/workflows/python-ci.yml @@ -0,0 +1,35 @@ +name: Python CI + +on: + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: install + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: install + run: | + python -m pip install --upgrade pip + pip install .[dev] + + - name: black + run: | + black --check . 
+ + - name: mypy + run: | + mypy ariautils + mypy tests + + - name: Run tests with pytest + run: | + pytest diff --git a/README.md b/README.md index 048cea7..45a7e1f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,5 @@ -# ariautils -MIDI tokenizers and pre-processing utils. +# aria-utils + +An extremely lightweight and simple library for pre-processing and tokenizing MIDI files. + + diff --git a/ariautils/ __init__.py b/ariautils/ __init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ariautils/config/config.json b/ariautils/config/config.json new file mode 100644 index 0000000..465a324 --- /dev/null +++ b/ariautils/config/config.json @@ -0,0 +1,242 @@ +{ + "data": { + "tests": { + "max_programs":{ + "run": false, + "args": { + "max": 12 + } + }, + "max_instruments":{ + "run": false, + "args": { + "max": 7 + } + }, + "total_note_frequency":{ + "run": false, + "args": { + "min_per_second": 1.5, + "max_per_second": 30 + } + }, + "note_frequency_per_instrument":{ + "run": false, + "args": { + "min_per_second": 1.0, + "max_per_second": 25 + } + }, + "min_length":{ + "run": false, + "args": { + "min_seconds": 30 + } + } + }, + "pre_processing": { + "remove_instruments": { + "run": true, + "args": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + } + } + }, + "metadata": { + "functions": { + "composer_filename": { + "run": false, + "args": { + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] + } + }, + "composer_metamsg": { + "run": false, + "args": { + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] 
+ } + }, + "form_filename": { + "run": false, + "args": { + "form_names": ["sonata", "prelude", "nocturne", "etude", "waltz", "mazurka", "impromptu", "fugue"] + } + }, + "maestro_json": { + "run": false, + "args": { + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"] + } + }, + "listening_model": { + "run": false, + "args": { + "tag_names": ["happy", "sad"] + } + }, + "abs_path": { + "run": true, + "args": {} + } + }, + "manual": { + "genre": ["classical", "jazz"], + "form": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "composer": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] + } + }, + "finetuning": { + "min_noisy_interval_ms": 5000, + "max_noisy_interval_ms": 60000, + "min_clean_interval_ms": 60000, + "max_clean_interval_ms": 200000, + "noising": { + "activation_prob": 0.95, + "remove_notes": { + "activation_prob": 0.75, + "min_ratio": 0.1, + "max_ratio": 0.4 + }, + "adjust_velocity": { + "activation_prob": 0.3, + "min_adjust": 1, + "max_adjust": 30, + "max_ratio": 0.1, + "min_ratio": 0.30 + }, + "adjust_onsets": { + "activation_prob": 0.5, + "min_adjust_s": 0.03, + "max_adjust_s": 0.07, + "max_ratio": 0.15, + "min_ratio": 0.5 + }, + "quantize_onsets": { + "activation_prob": 0.15, + "min_quant_s": 0.05, + "max_quant_s": 0.15, + "max_vel_delta": 45 + } + } + } + }, + + "tokenizer": { + "rel": { + "ignore_instruments": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + }, + 
"instrument_programs": { + "piano": 0, + "chromatic": 13, + "organ": 16, + "guitar": 24, + "bass": 32, + "strings": 40, + "ensemble": 48, + "brass": 56, + "reed": 64, + "pipe": 73, + "synth_lead": 80, + "synth_pad": 88, + "synth_effect": 96, + "ethnic": 104, + "percussive": 112, + "sfx": 120 + }, + "drum_velocity": 60, + "velocity_quantization": { + "step": 15 + }, + "time_quantization": { + "num_steps": 500, + "step": 10 + }, + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "genre_names": ["jazz", "classical"] + }, + "abs": { + "ignore_instruments": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + }, + "instrument_programs": { + "piano": 0, + "chromatic": 13, + "organ": 16, + "guitar": 24, + "bass": 32, + "strings": 40, + "ensemble": 48, + "brass": 56, + "reed": 64, + "pipe": 73, + "synth_lead": 80, + "synth_pad": 88, + "synth_effect": 96, + "ethnic": 104, + "percussive": 112, + "sfx": 120 + }, + "drum_velocity": 60, + "velocity_quantization": { + "step": 10 + }, + "abs_time_step_ms": 5000, + "max_dur_ms": 5000, + "time_step_ms": 10, + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "genre_names": ["jazz", "classical"] + }, + "lm": { + "tags": ["happy", "sad"] + } + } +} diff --git a/ariautils/midi.py b/ariautils/midi.py new file mode 100644 index 0000000..5977b25 --- /dev/null +++ b/ariautils/midi.py 
@@ -0,0 +1,1150 @@ +"""Utils for data/MIDI processing.""" + +import re +import os +import json +import hashlib +import unicodedata +import mido + +from collections import defaultdict +from pathlib import Path +from typing import ( + List, + Dict, + Any, + Tuple, + Final, + Concatenate, + Callable, + TypeAlias, + Literal, + TypedDict, +) + +from mido.midifiles.units import tick2second +from ariautils.utils import load_config, load_maestro_metadata_json + + +class MetaMessage(TypedDict): + """Meta message type corresponding text or copyright MIDI meta messages.""" + + type: Literal["text", "copyright"] + data: str + + +class TempoMessage(TypedDict): + """Tempo message type corresponding to the set_tempo MIDI message.""" + + type: Literal["tempo"] + data: int + tick: int + + +class PedalMessage(TypedDict): + """Sustain pedal message type corresponding to control_change 64 MIDI messages.""" + + type: Literal["pedal"] + data: Literal[0, 1] # 0 for off, 1 for on + tick: int + channel: int + + +class InstrumentMessage(TypedDict): + """Instrument message type corresponding to program_change MIDI messages.""" + + type: Literal["instrument"] + data: int + tick: int + channel: int + + +class NoteData(TypedDict): + pitch: int + start: int + end: int + velocity: int + + +class NoteMessage(TypedDict): + """Note message type corresponding to paired note_on and note_off MIDI messages.""" + + type: Literal["note"] + data: NoteData + tick: int + channel: int + + +MidiMessage: TypeAlias = ( + MetaMessage | TempoMessage | PedalMessage | InstrumentMessage | NoteMessage +) + + +class MidiDictData(TypedDict): + """Type for MidiDict attributes in dictionary form.""" + + meta_msgs: List[MetaMessage] + tempo_msgs: List[TempoMessage] + pedal_msgs: List[PedalMessage] + instrument_msgs: List[InstrumentMessage] + note_msgs: List[NoteMessage] + ticks_per_beat: int + metadata: Dict[str, Any] + + +class MidiDict: + """Container for MIDI data in dictionary form. 
+ + Args: + meta_msgs (List[MetaMessage]): List of text or copyright MIDI meta messages. + tempo_msgs (List[TempoMessage]): List of tempo change messages. + pedal_msgs (List[PedalMessage]): List of sustain pedal messages. + instrument_msgs (List[InstrumentMessage]): List of program change messages. + note_msgs (List[NoteMessage]): List of note messages from paired note-on/off events. + ticks_per_beat (int): MIDI ticks per beat. + metadata (dict): Optional metadata key-value pairs (e.g., {"genre": "classical"}). + """ + + def __init__( + self, + meta_msgs: List[MetaMessage], + tempo_msgs: List[TempoMessage], + pedal_msgs: List[PedalMessage], + instrument_msgs: List[InstrumentMessage], + note_msgs: List[NoteMessage], + ticks_per_beat: int, + metadata: Dict[str, Any], + ): + self.meta_msgs = meta_msgs + self.tempo_msgs = tempo_msgs + self.pedal_msgs = pedal_msgs + self.instrument_msgs = instrument_msgs + self.note_msgs = sorted(note_msgs, key=lambda msg: msg["tick"]) + self.ticks_per_beat = ticks_per_beat + self.metadata = metadata + + # Tracks if resolve_pedal() has been called. 
+ self.pedal_resolved = False + + # If tempo_msgs is empty, initialize to default + if not self.tempo_msgs: + DEFAULT_TEMPO_MSG: TempoMessage = { + "type": "tempo", + "data": 500000, + "tick": 0, + } + self.tempo_msgs = [DEFAULT_TEMPO_MSG] + # If instrument_msgs is empty, initialize to default (piano) + if not self.instrument_msgs: + DEFAULT_INSTRUMENT_MSG: InstrumentMessage = { + "type": "instrument", + "data": 0, + "tick": 0, + "channel": 0, + } + self.instrument_msgs = [DEFAULT_INSTRUMENT_MSG] + + self.program_to_instrument = self.get_program_to_instrument() + + @classmethod + def get_program_to_instrument(cls) -> Dict[int, str]: + """Return a map of MIDI program to instrument name.""" + + PROGRAM_TO_INSTRUMENT: Final[Dict[int, str]] = ( + {i: "piano" for i in range(0, 7 + 1)} + | {i: "chromatic" for i in range(8, 15 + 1)} + | {i: "organ" for i in range(16, 23 + 1)} + | {i: "guitar" for i in range(24, 31 + 1)} + | {i: "bass" for i in range(32, 39 + 1)} + | {i: "strings" for i in range(40, 47 + 1)} + | {i: "ensemble" for i in range(48, 55 + 1)} + | {i: "brass" for i in range(56, 63 + 1)} + | {i: "reed" for i in range(64, 71 + 1)} + | {i: "pipe" for i in range(72, 79 + 1)} + | {i: "synth_lead" for i in range(80, 87 + 1)} + | {i: "synth_pad" for i in range(88, 95 + 1)} + | {i: "synth_effect" for i in range(96, 103 + 1)} + | {i: "ethnic" for i in range(104, 111 + 1)} + | {i: "percussive" for i in range(112, 119 + 1)} + | {i: "sfx" for i in range(120, 127 + 1)} + ) + + return PROGRAM_TO_INSTRUMENT + + def get_msg_dict(self) -> MidiDictData: + """Returns MidiDict data in dictionary form.""" + + return { + "meta_msgs": self.meta_msgs, + "tempo_msgs": self.tempo_msgs, + "pedal_msgs": self.pedal_msgs, + "instrument_msgs": self.instrument_msgs, + "note_msgs": self.note_msgs, + "ticks_per_beat": self.ticks_per_beat, + "metadata": self.metadata, + } + + def to_midi(self) -> mido.MidiFile: + """Converts this MidiDict to a mido.MidiFile via dict_to_midi.""" + + return dict_to_midi(self.get_msg_dict()) + +
@classmethod + def from_msg_dict(cls, msg_dict: MidiDictData) -> "MidiDict": + """Builds a MidiDict from data in get_msg_dict() form.""" + + assert msg_dict.keys() == { + "meta_msgs", + "tempo_msgs", + "pedal_msgs", + "instrument_msgs", + "note_msgs", + "ticks_per_beat", + "metadata", + } + + return cls(**msg_dict) + + @classmethod + def from_midi(cls, mid_path: str | Path) -> "MidiDict": + """Loads a MIDI file from path and returns MidiDict.""" + + mid = mido.MidiFile(mid_path) + return cls(**midi_to_dict(mid)) + + def calculate_hash(self) -> str: + msg_dict_to_hash = dict(self.get_msg_dict()) + + # Hash only the tempo, pedal, instrument and note messages + msg_dict_to_hash.pop("meta_msgs") + msg_dict_to_hash.pop("ticks_per_beat") + msg_dict_to_hash.pop("metadata") + + return hashlib.md5( + json.dumps(msg_dict_to_hash, sort_keys=True).encode() + ).hexdigest() + + def tick_to_ms(self, tick: int) -> int: + """Calculate the time (in milliseconds) in current file at a MIDI tick.""" + + return get_duration_ms( + start_tick=0, + end_tick=tick, + tempo_msgs=self.tempo_msgs, + ticks_per_beat=self.ticks_per_beat, + ) + + def _build_pedal_intervals(self) -> Dict[int, List[List[int]]]: + """Returns a mapping of channels to sustain pedal intervals.""" + + self.pedal_msgs.sort(key=lambda msg: msg["tick"]) + channel_to_pedal_intervals = defaultdict(list) + pedal_status: Dict[int, int] = {} + + for pedal_msg in self.pedal_msgs: + tick = pedal_msg["tick"] + channel = pedal_msg["channel"] + data = pedal_msg["data"] + + if data == 1 and pedal_status.get(channel, None) is None: + pedal_status[channel] = tick + elif data == 0 and pedal_status.get(channel, None) is not None: + # Close pedal interval + _start_tick = pedal_status[channel] + _end_tick = tick + channel_to_pedal_intervals[channel].append( + [_start_tick, _end_tick] + ) + del pedal_status[channel] + + # Close unclosed pedals at end of last note (tick 0 if no notes) + final_tick = self.note_msgs[-1]["data"]["end"] if self.note_msgs else 0 + for channel, start_tick in pedal_status.items(): +
channel_to_pedal_intervals[channel].append([start_tick, final_tick]) + + return channel_to_pedal_intervals + + def resolve_overlaps(self) -> "MidiDict": + """Resolves any note overlaps (inplace) between notes with the same + pitch and channel. This is achieved by converting a pair of notes with + the same pitch (a < b < c and x, y > 0): + + [a, b+x], [b-y, c] -> [a, b-y], [b-y, c] + + Note that this should not occur if the note messages have not been + modified, e.g., by resolve_pedal(). + """ + + # Organize notes by channel and pitch + note_msgs_c: Dict[int, Dict[int, List[NoteMessage]]] = defaultdict( + lambda: defaultdict(list) + ) + for msg in self.note_msgs: + _channel = msg["channel"] + _pitch = msg["data"]["pitch"] + note_msgs_c[_channel][_pitch].append(msg) + + # We can modify notes by reference as they are dictionaries + for channel, msgs_by_pitch in note_msgs_c.items(): + for pitch, msgs in msgs_by_pitch.items(): + msgs.sort( + key=lambda msg: (msg["data"]["start"], msg["data"]["end"]) + ) + prev_off_tick = -1 + for idx, msg in enumerate(msgs): + on_tick = msg["data"]["start"] + off_tick = msg["data"]["end"] + if prev_off_tick > on_tick: + # Adjust end of previous (idx - 1) msg to remove overlap + msgs[idx - 1]["data"]["end"] = on_tick + prev_off_tick = off_tick + + return self + + def resolve_pedal(self) -> "MidiDict": + """Extend note offsets according to pedal and resolve any note overlaps""" + + # Pedal is only resolved once; repeated calls return without changes + if self.pedal_resolved: + return self + + # Organize note messages by channel + note_msgs_c = defaultdict(list) + for msg in self.note_msgs: + _channel = msg["channel"] + note_msgs_c[_channel].append(msg) + + # We can modify notes by reference as they are dictionaries + channel_to_pedal_intervals = self._build_pedal_intervals() + for channel, msgs in note_msgs_c.items(): + for msg in msgs: + note_end_tick = msg["data"]["end"] + for pedal_interval in channel_to_pedal_intervals[channel]: +
pedal_start, pedal_end = pedal_interval + if pedal_start < note_end_tick < pedal_end: + msg["data"]["end"] = pedal_end + break + + self.resolve_overlaps() + self.pedal_resolved = True + + return self + + # TODO: Needs to be refactored and tested + def remove_redundant_pedals(self) -> "MidiDict": + """Removes redundant pedal messages from the MIDI data in place. + + Removes all pedal on/off message pairs that don't extend any notes. + Makes an exception for pedal off messages that coincide exactly with + note offsets. + """ + + def _is_pedal_useful( + pedal_start_tick: int, + pedal_end_tick: int, + note_msgs: List[NoteMessage], + ) -> bool: + # This logic loops through the note_msgs that could possibly + # be effected by the pedal which starts at pedal_start_tick + # and ends at pedal_end_tick. If there is note effected by the + # pedal, then it returns early. + + note_idx = 0 + note_msg = note_msgs[0] + note_start = note_msg["data"]["start"] + + while note_start <= pedal_end_tick and note_idx < len(note_msgs): + note_msg = note_msgs[note_idx] + note_start, note_end = ( + note_msg["data"]["start"], + note_msg["data"]["end"], + ) + + if pedal_start_tick <= note_end <= pedal_end_tick: + # Found note for which pedal is useful + return True + + note_idx += 1 + + return False + + def _process_channel_pedals(channel: int) -> None: + pedal_msg_idxs_to_remove = [] + pedal_down_tick = None + pedal_down_msg_idx = None + + note_msgs = [ + msg for msg in self.note_msgs if msg["channel"] == channel + ] + + if not note_msgs: + # No notes to process. In this case we remove all pedal_msgs + # and then return early. 
+ for pedal_msg_idx, pedal_msg in enumerate(self.pedal_msgs): + pedal_msg_value, pedal_msg_tick, _channel = ( + pedal_msg["data"], + pedal_msg["tick"], + pedal_msg["channel"], + ) + + if _channel == channel: + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + + # Remove messages + self.pedal_msgs = [ + msg + for _idx, msg in enumerate(self.pedal_msgs) + if _idx not in pedal_msg_idxs_to_remove + ] + return + + for pedal_msg_idx, pedal_msg in enumerate(self.pedal_msgs): + pedal_msg_value, pedal_msg_tick, _channel = ( + pedal_msg["data"], + pedal_msg["tick"], + pedal_msg["channel"], + ) + + # Only process pedal_msgs for specified MIDI channel + if _channel != channel: + continue + + # Remove never-closed pedal messages + if ( + pedal_msg_idx == len(self.pedal_msgs) - 1 + and pedal_msg_value == 1 + ): + # Current msg is last one and ON -> remove curr pedal_msg + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + + # Logic for removing repeated pedal messages and updating + # pedal_down_tick and pedal_down_idx + if pedal_down_tick is None: + if pedal_msg_value == 1: + # Pedal is OFF and current msg is ON -> update + pedal_down_tick = pedal_msg_tick + pedal_down_msg_idx = pedal_msg_idx + continue + else: + # Pedal is OFF and current msg is OFF -> remove curr pedal_msg + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + continue + else: + if pedal_msg_value == 1: + # Pedal is ON and current msg is ON -> remove curr pedal_msg + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + continue + + pedal_is_useful = _is_pedal_useful( + pedal_start_tick=pedal_down_tick, + pedal_end_tick=pedal_msg_tick, + note_msgs=note_msgs, + ) + + if pedal_is_useful is False: + # Pedal hasn't effected any notes -> remove + pedal_msg_idxs_to_remove.append(pedal_down_msg_idx) + pedal_msg_idxs_to_remove.append(pedal_msg_idx) + + # Finished processing pedal, set pedal state to OFF + pedal_down_tick = None + pedal_down_msg_idx = None + + # Remove messages + self.pedal_msgs = [ + msg + for _idx, msg in 
enumerate(self.pedal_msgs) + if _idx not in pedal_msg_idxs_to_remove + ] + + for channel in set([msg["channel"] for msg in self.pedal_msgs]): + _process_channel_pedals(channel) + + return self + + def remove_instruments(self, config: dict) -> "MidiDict": + """Removes all messages with instruments specified in config at: + + data.pre_processing.remove_instruments.args + + Note that drum messages, defined as those which occur on MIDI channel 9 + are not removed. + """ + + programs_to_remove = [ + i + for i in range(0, 127 + 1) + if config[self.program_to_instrument[i]] is True + ] + channels_to_remove = [ + msg["channel"] + for msg in self.instrument_msgs + if msg["data"] in programs_to_remove + ] + + # Remove drums (channel 9) from channels to remove + channels_to_remove = [i for i in channels_to_remove if i != 9] + + # Remove unwanted messages of all types by looping over msg types + _msg_dict: Dict[str, List] = { + "meta_msgs": self.meta_msgs, + "tempo_msgs": self.tempo_msgs, + "pedal_msgs": self.pedal_msgs, + "instrument_msgs": self.instrument_msgs, + "note_msgs": self.note_msgs, + } + + for msgs_name, msgs_list in _msg_dict.items(): + setattr( + self, + msgs_name, + [ + msg + for msg in msgs_list + if msg.get("channel", -1) not in channels_to_remove + ], + ) + + return self + + +# TODO: The sign has been changed.
Make sure this function isn't used anywhere else +def _extract_track_data( + track: mido.MidiTrack, +) -> Tuple[ + List[MetaMessage], + List[TempoMessage], + List[PedalMessage], + List[InstrumentMessage], + List[NoteMessage], +]: + """Converts MIDI messages into format used by MidiDict.""" + + meta_msgs: List[MetaMessage] = [] + tempo_msgs: List[TempoMessage] = [] + pedal_msgs: List[PedalMessage] = [] + instrument_msgs: List[InstrumentMessage] = [] + note_msgs: List[NoteMessage] = [] + + last_note_on = defaultdict(list) + for message in track: + # Meta messages + if message.is_meta is True: + if message.type == "text" or message.type == "copyright": + meta_msgs.append( + { + "type": message.type, + "data": message.text, + } + ) + # Tempo messages + elif message.type == "set_tempo": + tempo_msgs.append( + { + "type": "tempo", + "data": message.tempo, + "tick": message.time, + } + ) + # Instrument messages + elif message.type == "program_change": + instrument_msgs.append( + { + "type": "instrument", + "data": message.program, + "tick": message.time, + "channel": message.channel, + } + ) + # Pedal messages + elif message.type == "control_change" and message.control == 64: + # Consistent with pretty_midi and ableton-live default behavior + pedal_msgs.append( + { + "type": "pedal", + "data": 0 if message.value < 64 else 1, + "tick": message.time, + "channel": message.channel, + } + ) + # Note messages + elif message.type == "note_on" and message.velocity > 0: + last_note_on[(message.note, message.channel)].append( + (message.time, message.velocity) + ) + elif message.type == "note_off" or ( + message.type == "note_on" and message.velocity == 0 + ): + # Ignore non-existent note-ons + if (message.note, message.channel) in last_note_on: + end_tick = message.time + open_notes = last_note_on[(message.note, message.channel)] + + notes_to_close = [ + (start_tick, velocity) + for start_tick, velocity in open_notes + if start_tick != end_tick + ] + notes_to_keep = [ + 
(start_tick, velocity) + for start_tick, velocity in open_notes + if start_tick == end_tick + ] + + for start_tick, velocity in notes_to_close: + note_msgs.append( + { + "type": "note", + "data": { + "pitch": message.note, + "start": start_tick, + "end": end_tick, + "velocity": velocity, + }, + "tick": start_tick, + "channel": message.channel, + } + ) + + if len(notes_to_close) > 0 and len(notes_to_keep) > 0: + # Note-on on the same tick but we already closed + # some previous notes -> it will continue, keep it. + last_note_on[(message.note, message.channel)] = ( + notes_to_keep + ) + else: + # Remove the last note on for this instrument + del last_note_on[(message.note, message.channel)] + + return meta_msgs, tempo_msgs, pedal_msgs, instrument_msgs, note_msgs + + +def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: + """Converts mid.MidiFile into MidiDictData representation. + + Additionally runs metadata extraction according to config specified at: + + data.metadata.functions + + Args: + mid (mido.MidiFile): A mido file object to parse. + + Returns: + MidiDictData: A dictionary containing extracted MIDI data including notes, + time signatures, key signatures, and other musical events. 
+ """ + + metadata_config = load_config()["data"]["metadata"] + # Convert time in mid to absolute + for track in mid.tracks: + curr_tick = 0 + for message in track: + message.time += curr_tick + curr_tick = message.time + + midi_dict_data: MidiDictData = { + "meta_msgs": [], + "tempo_msgs": [], + "pedal_msgs": [], + "instrument_msgs": [], + "note_msgs": [], + "ticks_per_beat": mid.ticks_per_beat, + "metadata": {}, + } + + # Compile track data + for mid_track in mid.tracks: + meta_msgs, tempo_msgs, pedal_msgs, instrument_msgs, note_msgs = ( + _extract_track_data(mid_track) + ) + midi_dict_data["meta_msgs"] += meta_msgs + midi_dict_data["tempo_msgs"] += tempo_msgs + midi_dict_data["pedal_msgs"] += pedal_msgs + midi_dict_data["instrument_msgs"] += instrument_msgs + midi_dict_data["note_msgs"] += note_msgs + + # Sort by tick (for note msgs, this will be the same as data.start_tick) + midi_dict_data["tempo_msgs"] = sorted( + midi_dict_data["tempo_msgs"], key=lambda x: x["tick"] + ) + midi_dict_data["pedal_msgs"] = sorted( + midi_dict_data["pedal_msgs"], key=lambda x: x["tick"] + ) + midi_dict_data["instrument_msgs"] = sorted( + midi_dict_data["instrument_msgs"], key=lambda x: x["tick"] + ) + midi_dict_data["note_msgs"] = sorted( + midi_dict_data["note_msgs"], key=lambda x: x["tick"] + ) + + for metadata_process_name, metadata_process_config in metadata_config[ + "functions" + ].items(): + if metadata_process_config["run"] is True: + metadata_fn = get_metadata_fn( + metadata_process_name=metadata_process_name + ) + fn_args: Dict = metadata_process_config["args"] + + collected_metadata = metadata_fn(mid, midi_dict_data, **fn_args) + if collected_metadata: + for k, v in collected_metadata.items(): + midi_dict_data["metadata"][k] = v + + return midi_dict_data + + +def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: + """Converts MIDI information from dictionary form into a mido.MidiFile. + + This function performs midi_to_dict in reverse. 
+ + Args: + mid_data (dict): MIDI information in dictionary form. + + Returns: + mido.MidiFile: The MIDI parsed from the input data. + """ + + assert mid_data.keys() == { + "meta_msgs", + "tempo_msgs", + "pedal_msgs", + "instrument_msgs", + "note_msgs", + "ticks_per_beat", + "metadata", + }, "Invalid json/dict." + + ticks_per_beat = mid_data["ticks_per_beat"] + + # Add all messages (not ordered) to one track + track = mido.MidiTrack() + end_msgs = defaultdict(list) + + for tempo_msg in mid_data["tempo_msgs"]: + track.append( + mido.MetaMessage( + "set_tempo", tempo=tempo_msg["data"], time=tempo_msg["tick"] + ) + ) + + for pedal_msg in mid_data["pedal_msgs"]: + track.append( + mido.Message( + "control_change", + control=64, + value=pedal_msg["data"] + * 127, # Stored in PedalMessage as 1 or 0 + channel=pedal_msg["channel"], + time=pedal_msg["tick"], + ) + ) + + for instrument_msg in mid_data["instrument_msgs"]: + track.append( + mido.Message( + "program_change", + program=instrument_msg["data"], + channel=instrument_msg["channel"], + time=instrument_msg["tick"], + ) + ) + + for note_msg in mid_data["note_msgs"]: + # Note on + track.append( + mido.Message( + "note_on", + note=note_msg["data"]["pitch"], + velocity=note_msg["data"]["velocity"], + channel=note_msg["channel"], + time=note_msg["data"]["start"], + ) + ) + # Note off + end_msgs[(note_msg["channel"], note_msg["data"]["pitch"])].append( + (note_msg["data"]["start"], note_msg["data"]["end"]) + ) + + # Only add end messages that don't interfere with other notes + for k, v in end_msgs.items(): + channel, pitch = k + for start, end in v: + add = True + for _start, _end in v: + if start < _start < end < _end: + add = False + + if add is True: + track.append( + mido.Message( + "note_on", + note=pitch, + velocity=0, + channel=channel, + time=end, + ) + ) + + # Magic sorting function + def _sort_fn(msg: mido.Message) -> Tuple[int, int]: + if hasattr(msg, "velocity"): + return (msg.time, msg.velocity) + else: + return 
(msg.time, 1000) + + # Sort and convert from abs_time -> delta_time + track = sorted(track, key=_sort_fn) + tick = 0 + for msg in track: + msg.time -= tick + tick += msg.time + + track.append(mido.MetaMessage("end_of_track", time=0)) + mid = mido.MidiFile(type=0) + mid.ticks_per_beat = ticks_per_beat + mid.tracks.append(track) + + return mid + + +def get_duration_ms( + start_tick: int, + end_tick: int, + tempo_msgs: List[TempoMessage], + ticks_per_beat: int, +) -> int: + """Calculates elapsed time (in ms) between start_tick and end_tick.""" + + # Finds idx such that: + # tempo_msg[idx]["tick"] < start_tick <= tempo_msg[idx+1]["tick"] + for idx, curr_msg in enumerate(tempo_msgs): + if start_tick <= curr_msg["tick"]: + break + if idx > 0: # Special case idx == 0 -> Don't -1 + idx -= 1 + + # It is important that we initialise curr_tick & curr_tempo here. In the + # case that there is a single tempo message the following loop will not run. + duration = 0.0 + curr_tick = start_tick + curr_tempo = tempo_msgs[idx]["data"] + + # Sums all tempo intervals before tempo_msgs[-1]["tick"] + for curr_msg, next_msg in zip(tempo_msgs[idx:], tempo_msgs[idx + 1 :]): + curr_tempo = curr_msg["data"] + if end_tick < next_msg["tick"]: + delta_tick = end_tick - curr_tick + else: + delta_tick = next_msg["tick"] - curr_tick + + duration += tick2second( + tick=delta_tick, + tempo=curr_tempo, + ticks_per_beat=ticks_per_beat, + ) + + if end_tick < next_msg["tick"]: + break + else: + curr_tick = next_msg["tick"] + + # Case end_tick > tempo_msgs[-1]["tick"] + if end_tick > tempo_msgs[-1]["tick"]: + curr_tempo = tempo_msgs[-1]["data"] + delta_tick = end_tick - curr_tick + + duration += tick2second( + tick=delta_tick, + tempo=curr_tempo, + ticks_per_beat=ticks_per_beat, + ) + + # Convert from seconds to milliseconds + duration = duration * 1e3 + duration = round(duration) + + return duration + + +def _match_word(text: str, word: str) -> bool: + def to_ascii(s: str) -> str: + # Remove accents + 
normalized = unicodedata.normalize("NFKD", s) + return "".join(c for c in normalized if not unicodedata.combining(c)) + + text = to_ascii(text) + word = to_ascii(word) + + # If name="bach" this pattern will match "bach", "Bach" or "BACH" if + # it is delimited on both sides by "_", whitespace or a string boundary. + pattern = ( + r"(^|[\s_])(" + + re.escape(word.lower()) + + r"|" + + re.escape(word.upper()) + + r"|" + + re.escape(word.capitalize()) + + r")([\s_]|$)" + ) + + if re.search(pattern, text, re.IGNORECASE): + return True + else: + return False + + +def meta_composer_filename( + mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list +) -> Dict[str, str]: + file_name = Path(str(mid.filename)).stem + matched_names_unique = set() + for name in composer_names: + if _match_word(file_name, name): + matched_names_unique.add(name) + + # Only return data if only one composer is found + matched_names = list(matched_names_unique) + if len(matched_names) == 1: + return {"composer": matched_names[0]} + else: + return {} + + +def meta_form_filename( + mid: mido.MidiFile, msg_data: MidiDictData, form_names: list +) -> Dict[str, str]: + file_name = Path(str(mid.filename)).stem + matched_names_unique = set() + for name in form_names: + if _match_word(file_name, name): + matched_names_unique.add(name) + + # Only return data if only one form is found + matched_names = list(matched_names_unique) + if len(matched_names) == 1: + return {"form": matched_names[0]} + else: + return {} + + +def meta_composer_metamsg( + mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list +) -> Dict[str, str]: + matched_names_unique = set() + for msg in msg_data["meta_msgs"]: + for name in composer_names: + if _match_word(msg["data"], name): + matched_names_unique.add(name) + + # Only return data if only one composer is found + matched_names = list(matched_names_unique) + if len(matched_names) == 1: + return {"composer": matched_names[0]} + else: + return {} + + +# TODO: Needs testing +def meta_maestro_json( + mid:
# TODO: Needs testing
def meta_maestro_json(
    mid: mido.MidiFile,
    msg_data: MidiDictData,
    composer_names: list,
    form_names: list,
) -> Dict[str, str]:
    """Looks up composer and form metadata for a MAESTRO file.

    This should only be used when processing MAESTRO. The metadata mapping is
    obtained from load_maestro_metadata_json() and is looked up with the key
    "<file name without extension>.midi"; each entry is read as
    {"composer": str, "title": str}.

    Args:
        mid: MIDI file whose filename selects the metadata entry.
        msg_data: Unused; present to match the other metadata functions.
        composer_names: Candidate composer names matched against the entry's
            "composer" field.
        form_names: Candidate form names matched against the entry's "title"
            field.

    Returns:
        A dict holding "form" and/or "composer", each included only when
        exactly one candidate matched. Empty if the file has no entry.
    """

    _file_name = Path(str(mid.filename)).name
    _file_name_without_ext = os.path.splitext(_file_name)[0]
    metadata = load_maestro_metadata_json().get(
        _file_name_without_ext + ".midi"
    )
    if metadata is None:
        return {}

    matched_forms = {
        form for form in form_names if _match_word(metadata["title"], form)
    }
    matched_composers = {
        composer
        for composer in composer_names
        if _match_word(metadata["composer"], composer)
    }

    # Ambiguous (several matches) or missing (no match) fields are left out.
    res: Dict[str, str] = {}
    if len(matched_forms) == 1:
        res["form"] = next(iter(matched_forms))
    if len(matched_composers) == 1:
        res["composer"] = next(iter(matched_composers))

    return res
def test_max_programs(midi_dict: MidiDict, max: int) -> Tuple[bool, int]:
    """Checks that midi_dict uses no more than {max} distinct programs.

    The programs are the distinct "data" values found in
    midi_dict.instrument_msgs.

    Returns:
        A (passed, num_programs) pair, where passed is False if more than
        {max} programs are present.
    """
    num_programs = len({msg["data"] for msg in midi_dict.instrument_msgs})

    return num_programs <= max, num_programs
def test_note_frequency_per_instrument(
    midi_dict: MidiDict, max_per_second: float, min_per_second: float
) -> Tuple[bool, float]:
    """Checks the average note rate per instrument against the given bounds.

    The rate is the number of note messages per second, divided by the number
    of distinct instruments that midi_dict.instrument_msgs map to through
    midi_dict.program_to_instrument. The time span runs from the "start" of
    the first note message to the "end" of the last one.

    Args:
        midi_dict: MidiDict to inspect.
        max_per_second: Upper bound (inclusive) on notes/second/instrument.
        min_per_second: Lower bound (inclusive) on notes/second/instrument.

    Returns:
        A (passed, rate) pair. Returns (False, 0.0) when there are no notes,
        no instruments, or the measured duration is zero.
    """
    num_instruments = len(
        {
            midi_dict.program_to_instrument[msg["data"]]
            for msg in midi_dict.instrument_msgs
        }
    )

    if not midi_dict.note_msgs:
        return False, 0.0

    # Without any instrument message the per-instrument rate is undefined;
    # fail the test instead of dividing by zero below.
    if num_instruments == 0:
        return False, 0.0

    num_notes = len(midi_dict.note_msgs)
    total_duration_ms = get_duration_ms(
        start_tick=midi_dict.note_msgs[0]["data"]["start"],
        end_tick=midi_dict.note_msgs[-1]["data"]["end"],
        tempo_msgs=midi_dict.tempo_msgs,
        ticks_per_beat=midi_dict.ticks_per_beat,
    )

    if total_duration_ms == 0:
        return False, 0.0

    notes_per_second = (num_notes * 1e3) / total_duration_ms
    note_freq_per_instrument = notes_per_second / num_instruments

    within_bounds = (
        min_per_second <= note_freq_per_instrument <= max_per_second
    )
    return within_bounds, note_freq_per_instrument
def get_logger(name: str) -> logging.Logger:
    """Returns the logger called name, configuring it on first use.

    A logger that has no handlers yet gets propagation disabled, its level
    set to DEBUG and a single stream handler at INFO level with a
    timestamped format. A logger that already has handlers is returned
    untouched, so repeated calls do not stack handlers.
    """
    log = logging.getLogger(name)
    if log.handlers:
        return log

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        logging.Formatter(
            "[%(asctime)s]: [%(levelname)s] [%(name)s] %(message)s"
        )
    )

    log.propagate = False
    log.setLevel(logging.DEBUG)
    log.addHandler(stream_handler)

    return log
["ariautils"] + +[tool.setuptools.package-data] +ariautils = ["config/*.json"] + +[tool.black] +line-length = 80 +target-version = ["py311"] +include = '\.pyi?$' + +[tool.mypy] +python_version = "3.11" +packages = ["ariautils", "tests"] +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = false +strict_equality = true +ignore_missing_imports = true +namespace_packages = true +explicit_package_bases = true + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra -q -s" +testpaths = ["tests"] +python_files = ["test_*.py"] \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/data/arabesque.mid b/tests/assets/data/arabesque.mid new file mode 100644 index 0000000000000000000000000000000000000000..38a0f85844d27d0a1a064115f248a34dc73f5f64 GIT binary patch literal 16975 zcmeHP33%JZm7g5gvV7!#Kw8qE5D5PyP8`d^@iDR`S^CS8WgV7e%SUXU%Wn(+# zpR@}o350ISO`(LVpTb6a-`;C=2&}ADS z+uiRAU%t%z-h1=@Z)VOUfnX`qMjLgd4V-b^i^PQP=Dl;*d8L{+@WGurY6C;-# z2ktnU{cH0W-D3m#@et52p8-PGb^={+ z8_+e2fo@X+?tT;K@!x^>+XopH%*ba{u;4nN6_=r@e-Xs^O+cIO2ig_{y2cB%e;?5I zb^$&8OQ1j9!pJ=T_l(S?_n`641##osKo?vEbklu6_v{3E@+F`@zs#ucBiAr0ta=S7 zF$uKgJ7});fq2)IKu>>(QPJ$rGAdf01PT=aZN3@kdMnT)p9Fe)kWsNVj3#y&6N|H0 z^5Xr2Ku=dPuVh{;qmqWlfX;pZ=*w$?p2Nn?S@Z;>Ic@mDoGYR%GUuTf6X(u@opT2d z0Ns8&^X9((Z;a;EXBo}gbt}*l=Q5gKdOD-|=iQIy@f}P&zTz{Cj=%g2pg)W-I-%pA z7@hE4Gou9+PcmBYg$;~OJmrs!$o+Po&h6*ICvtP`*j#R|+C7h(%f7#Wn@_(DbYT^E z+qa-O=N=5$S_IzaZV*Smfo5m{cpJ8W*ZVAp-4PJm4*W#CpdZG; zd%ys?|3>ic`61AqZ=v~S8AQIe0lfW~L*AF32k({@Al^Iykv-oA@rL8k>>dE`s{4Sh zut4OJEYPkO!MiX5wBv5@vTH#6hnqp%x){yJuLW;>5vDn^2WV(HcmrPt>i2+`dK{>G z1iVf&yH8&0&FoBh$%oNwIt9ed7XzKA&_6x`bWIq%y=_4I;I(;oAZGFoDBh!plDud3 
zL*(WC81N?$^Yh9W<Z>#fW)9k-s5>EtS)m554{{3(dRLJ)h9*Gy-{fj))drt4+{ z-L53R{U*?(BN+F}%Zv(UBbgU0e4L2|tDXe0!3ETg#8mL{JAp3x1bBOI1^SjB=%+|w z1+T+0bIB$qnl&uB*_8lV_fZ~THjv58*W+X`-_Z@;V|&4SeHWv`6F$PI(2n^Pc4K0N z=OFtRe(rLJ-1nalc>z`zooFyBs=b|gMX3jYF2o#)_CE*o;?J2^Jii%@9m%J7Ba{_? zb_mME+v5al*O-+p4FmY0qQ3}mN=R)LXp9lKIgJ_kRrw>Bb;?W;zC^|f-Ke_0K5w-xAqJJ9z+Ko6_R__5o7o_q|vXK}a8 zf6)!%ZD*-wcw{w%G9xp0-L9E6X zOm%kvxyLZvUk()d6Hxs7K%IMlHsAta8tw!d$1U5m1($Twxk(VuAIE@QmjGS%RiGPA z0@|x2Z&h~ge-*rMG=X>5J~ZF?7>GZ(0_d?DfS$sdnV$Pspx>Mg-fPQ&{`fr5B;HyI z@>D&4LJboOPKHRqsY`*DYdoM}4a)k0wWxjz>J0EaUj#2!3e&!ioiAp9UcnL+y!l^@ z%mRySHXQ((=iqVXkKhT!ymA&$#XqBwDEZAE++NJVl^D>5ubY#LfCg7W@)KKucH9AU zu^H&yi1^9uU6@bIo^M4fN0gh&+L7i}|@HfPM!zHUII~j0$H#Sz$5kEIbMK zzrs>H#}ron3?fcEt`!F0427*fW7UPJTfrN;9O!JgP2olGy25Mn!Q0z`X5V)pawp=U z@W3yC9{v<~PlbTqK>biOvy@R$NrZVtAH4)<*$Y4wxP2DY{|dywHW1e@1R5^{I!~dC zTtK_=I#%?Bmw>*$0h0I60D21M7X5ZRh=0aIQL);w;*(Bfk>b-YU|z8eD_HFME>IH7 zQ#^7V&^Zr7^5PX}uE)z#@s~Y7cbOpaBe+fR%g7$;wU*defA3JgxL1h%`R59=!?@HZHJ=axsRe`xN-ZcvSZZM*nxxhwgjZ@_ zksb9&%_D>oRSY!%Y~hxgn}mG2=8~FA2$Ck1LERu_JnRrXYhipC=`hiVu1))t5fLKN zv~3|Z;dZtu*g^Ig0V`~^vaLQVtOvzwC0nN`ZY!B26vbs_qZGwq)f~Bgy)`>pFNED{ z*oCOI%32|hl4`A1h-#}=O`oaYzTK9!(Kj7yB^gs+Ey)T;4fpM}Y>V5T9VP1>E`!-_ zjg9nZ#HiP~@9<`AJPE%fAy<>+$u)D|5##~BFl~hpdR)^9wDS7st=xCF86*^F&yLbo z`?|4kJ-n2@B=^Iqtc}+=E!hY96&M8AsK78lMzT~NZ*I7W`w`uEw-fQ6>Dn|WGI*43 zOxxZm7b86ZU7O}f2AR~PXGh8X0y=}R&Yy|uGRll?rNr>*l1vWh@*vww4BCE=F5z;A z*ltqg7Mh1#i^pNwI!~MIgH7y5O}ar^-8!w1bm_dL(-k>Lg+pjAa%5(_Q;}cS{8NL4 zi$qDNPB-erPKQf3swMVaV&5h2VlDdbO}9&fw79(dH*b8M@H84^zst`QZ<9e&bvN^V zvO~CqHwa3IDeMCg;fZ8N?SB5A!_PkNYJ%}iY`ibKI85b z-hLiZ-^sr1>S8h8$1dSp&k7yuSlGReWynfzw`dyRICpj9qdhFc4$&h5gS@xyBujde zEW?hJ62VQJ$vi2R^!2g~yIijb53v_``&hE6k7d{$`b1K}`amCBcBc zJxS3)pHu!sK|y_=p0)9`b;*lD)00zu6#+C=mn4^jZ_pyjoCCs@**U7DsADSov_t14 zQ%6M7D1~2AL}o{MH4tZhqEe!g!hZMw@*zrHKKc$BPxc%P#)vLyRzaQf0xvH8ElJmr z&%tqU4mxt4WNVepG=+oaxOGl*2qEIa#dP+LTAibdlPuXzGg1cAFv1RKB>R~tNt7uB zt7LYRmVrl$9SM5^j@b@h>x2ibW54ip!>Qp@t`t134}PS)V?Cz2fy3Vi@AGctQT`#O 
z1mQ=KjNwfQe~MQsypa$|Z#+X(O<0$&TRWhAUMnqRt@}Y{xP|+0B zd5Y7k9xv5eIf`3HrUOOgQ$=wgtEzrIxe2?Vjw)F`D$q7`+2_E3}6wZXM`XI9>bhfBN*E+~99X5gz)nPR#>SCfODu5M5 z#hvO;U66cq_o8oGdFhYF2<5V2a_e zvdbNO^q4U>_vleL_2}VBjgKFC^7v))QoA0uNzFzlm(=Lsntl~i{rBdsmF%-BBTk*J z$c}5p0~V{US%s+dYL%qd$H`1x!fFFZZ_P}p3Taj3+hL~EwP+L+t|*GxID21RD^YfR zSc(p<%k1LZx~Rt>>$@I<#HDszP#qG1#&13p2>iH11qtWW={Hv;iU%LM9AkuHS4o=P zg~|E`+~+)~nZ2l)ee57jO&lSvsFe-^hLmX0M|~I$hqbvE!(*imqgiw*D2GTtPEj4trLx z{go>v9|5J3HnegT_ibfNFJHqlrRB1e`c}wSV7@lYw-WPR%Y9oF8WenI~D&@8T+PX(2=w|L>ueMl-6<&?qyOHCHJN;^p913iS=q(RmbtmJ1-Hkj?;m%?oW z^nEf%lZ4qWTKB2PQ&K>q9g^!z8p^RLX^9OENv4w~Qb5FQ1B_`YXg)LrwJS9sBABBY{Q z*t&V?NcWOXv8t0NPkL+#Cc0v{vqdE@P(;!YdzhB~NWRa`xs-nRk`{GoQJWT(RN+iB zz*v?R*0hAd($yvEx&nCZ3s=^og};Ny4>h4&=HtyNxp~ zo#HCTQK&hlWptcT+Z^wmSgMd@RKiiIQ^}W0oY~V7k;*dD<(16Q*1tlc=1}(7IWbn& zp)#mPHBgUgz`^u#C)=!2e0c-xSBc(@MCw7!fhxh~g}j$#RyCsL@If;whH5`*jwa5k z)d75^8D9yq%<3R&ju7{&!l*eSELq;d+EgV`6GhE|O2QUr$(mNw9IY(rZbQwHK+Vz4 z6h}L!6upCL4)ThaR^7x!V_g83X8ioJt{KO3km(dywuGjx*Y*$>C62Jh={UlDDyunj zlM8lfl~F#8YNNP<$5=ycjQf>wrjyH-CpB9upIg;!tXcg?r8>cVTRYQfN@ac8nH?JW zaaAWS^qr{9ySRVmI;L|7ZP3cQQQ~yt{Of_8N$%4pc$H~mHfd$OTo~JWS#wn%?Cj%y zS(@pb`_nj3(>PGUs_ccr6f3Mv(x=Z4FAZya(1&u3W!7+&#>>DfYz^Z@i`}5g%7=r^%Gs^f z%5{gjQ@O450=&~f?>LG_Dh}S4^3KZLmyQefr&B}X-rWC5BOK>uQw}lbeKgx67&GzYqW9lpP;j?J*6<8ApCfMm z_AmUMBzKk{VOIXI_CB|iNvd299z0@o?%-K+_`j>SY)qX^e;flN6T_*&q1AmO!_!8c zo*T7r>L1-(-u=(v`(}vv~9T0;6Xc^VguwfGuhfp%V+MxV|M%jVk%|8C)dLgo$G-p}sW None: + self.logger = get_logger(__name__ + ".TestMidiDict") + + def test_load(self) -> None: + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + midi_dict = MidiDict.from_midi(load_path) + + self.logger.info(f"Num meta_msgs: {len(midi_dict.meta_msgs)}") + self.logger.info(f"Num tempo_msgs: {len(midi_dict.tempo_msgs)}") + self.logger.info(f"Num pedal_msgs: {len(midi_dict.pedal_msgs)}") + self.logger.info( + f"Num instrument_msgs: {len(midi_dict.instrument_msgs)}" + ) + self.logger.info(f"Num note_msgs: {len(midi_dict.note_msgs)}") + 
def test_remove_redundant_pedals(self) -> None:
    """Checks resolve_pedal and remove_redundant_pedals in both orders.

    Loads arabesque.mid, applies the two operations in either order and
    asserts that both orders give the same number of pedal messages and
    identical note messages. The resolve-then-remove result is saved to the
    results directory.
    """
    load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid")
    save_path = RESULTS_DATA_DIRECTORY.joinpath(
        "arabesque_remove_redundant_pedals.mid"
    )
    midi_dict = MidiDict.from_midi(mid_path=load_path)
    self.logger.info(
        f"Num pedal_msgs before remove_redundant_pedals: {len(midi_dict.pedal_msgs)}"
    )

    # resolve_pedal first, then remove_redundant_pedals.
    midi_dict_adj_resolve = (
        MidiDict.from_midi(mid_path=load_path)
        .resolve_pedal()
        .remove_redundant_pedals()
    )
    # remove_redundant_pedals first, then resolve_pedal.
    midi_dict_resolve_adj = (
        MidiDict.from_midi(mid_path=load_path)
        .remove_redundant_pedals()
        .resolve_pedal()
    )

    self.logger.info(
        f"Num pedal_msgs after remove_redundant_pedals: {len(midi_dict_adj_resolve.pedal_msgs)}"
    )
    self.assertEqual(
        len(midi_dict_adj_resolve.pedal_msgs),
        len(midi_dict_resolve_adj.pedal_msgs),
    )

    # zip() stops at the shorter sequence, so compare the lengths first;
    # otherwise a differing number of notes would go unnoticed.
    self.assertEqual(
        len(midi_dict_adj_resolve.note_msgs),
        len(midi_dict_resolve_adj.note_msgs),
    )
    for msg_1, msg_2 in zip(
        midi_dict_adj_resolve.note_msgs, midi_dict_resolve_adj.note_msgs
    ):
        self.assertDictEqual(msg_1, msg_2)

    midi_dict_adj_resolve.to_midi().save(save_path)