diff --git a/README.md b/README.md index 45a7e1f..e27fd30 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,13 @@ -# aria-utils +
+
+ █████╗ ██████╗ ██╗ █████╗ ██╗ ██╗████████╗██╗██╗ ███████╗ +██╔══██╗██╔══██╗██║██╔══██╗ ██║ ██║╚══██╔══╝██║██║ ██╔════╝ +███████║██████╔╝██║███████║ ██║ ██║ ██║ ██║██║ ███████╗ +██╔══██║██╔══██╗██║██╔══██║ ██║ ██║ ██║ ██║██║ ╚════██║ +██║ ██║██║ ██║██║██║ ██║ ╚██████╔╝ ██║ ██║███████╗███████║ +╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚══════╝╚══════╝ ++ An extremely lightweight and simple library for pre-processing and tokenizing MIDI files. diff --git a/ariautils/midi.py b/ariautils/midi.py index 1813bdf..eab984a 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -4,6 +4,7 @@ import os import json import hashlib +import copy import unicodedata import mido @@ -267,6 +268,8 @@ def _build_pedal_intervals(self) -> dict[int, list[list[int]]]: return channel_to_pedal_intervals + # TODO: This function might not behave correctly when acting on degenerate + # MidiDict objects. def resolve_overlaps(self) -> "MidiDict": """Resolves any note overlaps (inplace) between notes with the same pitch and channel. This is achieved by converting a pair of notes with @@ -465,7 +468,7 @@ def _process_channel_pedals(channel: int) -> None: return self - def remove_instruments(self, config: dict) -> "MidiDict": + def remove_instruments(self, remove_instruments: dict) -> "MidiDict": """Removes all messages with instruments specified in config at: data.preprocessing.remove_instruments @@ -477,7 +480,7 @@ def remove_instruments(self, config: dict) -> "MidiDict": programs_to_remove = [ i for i in range(1, 127 + 1) - if config[self.program_to_instrument[i]] is True + if remove_instruments[self.program_to_instrument[i]] is True ] channels_to_remove = [ msg["channel"] @@ -697,6 +700,8 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: return midi_dict_data +# POTENTIAL BUG: MidiDict can represent overlapping notes on the same channel. +# When converted to a MIDI file, is this handled correctly? Verify this. 
def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: """Converts MIDI information from dictionary form into a mido.MidiFile. @@ -1151,3 +1156,153 @@ def get_test_fn( ) else: return fn + + +def normalize_midi_dict( + midi_dict: MidiDict, + ignore_instruments: dict[str, bool], + instrument_programs: dict[str, int], + time_step_ms: int, + max_duration_ms: int, + drum_velocity: int, + quantize_velocity_fn: Callable[[int], int], +) -> MidiDict: + """Reorganizes a MidiDict to enable consistent comparisons. + + This function normalizes a MidiDict for testing purposes, ensuring that + equivalent MidiDicts can be compared on a literal basis. It is useful + for validating tokenization or other transformations applied to MIDI + representations. + + Args: + midi_dict (MidiDict): The input MIDI representation to be normalized. + ignore_instruments (dict[str, bool]): A mapping of instrument names to + booleans indicating whether they should be ignored during + normalization. + instrument_programs (dict[str, int]): A mapping of instrument names to + MIDI program numbers used to assign instruments in the output. + time_step_ms (int): The duration of the minimum time step in + milliseconds, used for quantizing note timing. + max_duration_ms (int): The maximum allowable duration for a note in + milliseconds. Notes exceeding this duration will be truncated. + drum_velocity (int): The fixed velocity value assigned to drum notes. + quantize_velocity_fn (Callable[[int], int]): A function that maps input + velocity values to quantized output values. + + Returns: + MidiDict: A normalized version of the input `midi_dict`, reorganized + for consistent and comparable structure. 
+ """ + + def _create_channel_mappings( + midi_dict: MidiDict, instruments: list[str] + ) -> tuple[dict[str, int], dict[int, str]]: + + new_instrument_to_channel = { + instrument: (9 if instrument == "drum" else idx + (idx >= 9)) + for idx, instrument in enumerate(instruments) + } + + old_channel_to_instrument = { + msg["channel"]: midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + } + old_channel_to_instrument[9] = "drum" + + return new_instrument_to_channel, old_channel_to_instrument + + def _create_instrument_messages( + instrument_programs: dict[str, int], + instrument_to_channel: dict[str, int], + ) -> list[InstrumentMessage]: + + return [ + { + "type": "instrument", + "data": instrument_programs[k] if k != "drum" else 0, + "tick": 0, + "channel": v, + } + for k, v in instrument_to_channel.items() + ] + + def _normalize_note_messages( + midi_dict: MidiDict, + old_channel_to_instrument: dict[int, str], + new_instrument_to_channel: dict[str, int], + time_step_ms: int, + max_duration_ms: int, + drum_velocity: int, + quantize_velocity_fn: Callable[[int], int], + ) -> list[NoteMessage]: + + def _quantize_time(_n: int) -> int: + return round(_n / time_step_ms) * time_step_ms + + note_msgs: list[NoteMessage] = [] + for msg in midi_dict.note_msgs: + msg_channel = msg["channel"] + instrument = old_channel_to_instrument[msg_channel] + new_msg_channel = new_instrument_to_channel[instrument] + + start_tick = _quantize_time( + midi_dict.tick_to_ms(msg["data"]["start"]) + ) + end_tick = _quantize_time(midi_dict.tick_to_ms(msg["data"]["end"])) + velocity = quantize_velocity_fn(msg["data"]["velocity"]) + + new_msg = copy.deepcopy(msg) + new_msg["channel"] = new_msg_channel + new_msg["tick"] = start_tick + new_msg["data"]["start"] = start_tick + + if new_msg_channel != 9: + new_msg["data"]["end"] = min( + start_tick + max_duration_ms, end_tick + ) + new_msg["data"]["velocity"] = velocity + else: + new_msg["data"]["end"] = start_tick + 
time_step_ms + new_msg["data"]["velocity"] = drum_velocity + + note_msgs.append(new_msg) + + return note_msgs + + midi_dict = copy.deepcopy(midi_dict) + midi_dict.remove_instruments(remove_instruments=ignore_instruments) + midi_dict.resolve_pedal() + midi_dict.pedal_msgs = [] + + instruments = [k for k, v in ignore_instruments.items() if not v] + ["drum"] + new_instrument_to_channel, old_channel_to_instrument = ( + _create_channel_mappings( + midi_dict, + instruments, + ) + ) + + instrument_msgs = _create_instrument_messages( + instrument_programs, + new_instrument_to_channel, + ) + + note_msgs = _normalize_note_messages( + midi_dict=midi_dict, + old_channel_to_instrument=old_channel_to_instrument, + new_instrument_to_channel=new_instrument_to_channel, + time_step_ms=time_step_ms, + max_duration_ms=max_duration_ms, + drum_velocity=drum_velocity, + quantize_velocity_fn=quantize_velocity_fn, + ) + + return MidiDict( + meta_msgs=[], + tempo_msgs=[{"type": "tempo", "data": 500000, "tick": 0}], + pedal_msgs=[], + instrument_msgs=instrument_msgs, + note_msgs=note_msgs, + ticks_per_beat=500, + metadata={}, + ) diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py index 792d910..4596a44 100644 --- a/ariautils/tokenizer/__init__.py +++ b/ariautils/tokenizer/__init__.py @@ -1,5 +1,6 @@ """Includes Tokenizers and pre-processing utilities.""" -from ariautils.tokenizer._base import Tokenizer +from ._base import Tokenizer +from .absolute import AbsTokenizer -__all__ = ["Tokenizer"] +__all__ = ["Tokenizer", "AbsTokenizer"] diff --git a/ariautils/tokenizer/_base.py b/ariautils/tokenizer/_base.py index 35e51ab..f45fc4a 100644 --- a/ariautils/tokenizer/_base.py +++ b/ariautils/tokenizer/_base.py @@ -17,12 +17,7 @@ class Tokenizer: - """Abstract Tokenizer class for tokenizing MidiDict objects. - - Args: - return_tensors (bool, optional): If True, encode will return tensors. - Defaults to False. 
- """ + """Abstract Tokenizer class for tokenizing MidiDict objects.""" def __init__( self, diff --git a/ariautils/tokenizer/absolute.py b/ariautils/tokenizer/absolute.py index 6ac465b..dbd9878 100644 --- a/ariautils/tokenizer/absolute.py +++ b/ariautils/tokenizer/absolute.py @@ -6,7 +6,7 @@ import copy from collections import defaultdict -from typing import Final, Callable +from typing import Final, Callable, Any from ariautils.midi import ( MidiDict, @@ -26,25 +26,31 @@ # TODO: # - Add asserts to the tokenization / detokenization for user error +# - Need to add a tokenization or MidiDict check of how to resolve different +# channels, with the same instrument, having overlapping notes +# - There are tons of edge cases here, e.g., what if there are two identical notes +# on different channels? class AbsTokenizer(Tokenizer): """MidiDict tokenizer implemented with absolute onset timings. - The tokenizer processes MIDI files in 5000ms segments, with each segment separated by - a special