diff --git a/README.md b/README.md index 45a7e1f..e27fd30 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,13 @@ -# aria-utils +

+

+ █████╗ ██████╗ ██╗ █████╗     ██╗   ██╗████████╗██╗██╗     ███████╗
+██╔══██╗██╔══██╗██║██╔══██╗    ██║   ██║╚══██╔══╝██║██║     ██╔════╝
+███████║██████╔╝██║███████║    ██║   ██║   ██║   ██║██║     ███████╗
+██╔══██║██╔══██╗██║██╔══██║    ██║   ██║   ██║   ██║██║     ╚════██║
+██║  ██║██║  ██║██║██║  ██║    ╚██████╔╝   ██║   ██║███████╗███████║
+╚═╝  ╚═╝╚═╝  ╚═╝╚═╝╚═╝  ╚═╝     ╚═════╝    ╚═╝   ╚═╝╚══════╝╚══════╝
+
+

An extremely lightweight and simple library for pre-processing and tokenizing MIDI files. diff --git a/ariautils/midi.py b/ariautils/midi.py index 1813bdf..eab984a 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -4,6 +4,7 @@ import os import json import hashlib +import copy import unicodedata import mido @@ -267,6 +268,8 @@ def _build_pedal_intervals(self) -> dict[int, list[list[int]]]: return channel_to_pedal_intervals + # TODO: This function might not behave correctly when acting on degenerate + # MidiDict objects. def resolve_overlaps(self) -> "MidiDict": """Resolves any note overlaps (inplace) between notes with the same pitch and channel. This is achieved by converting a pair of notes with @@ -465,7 +468,7 @@ def _process_channel_pedals(channel: int) -> None: return self - def remove_instruments(self, config: dict) -> "MidiDict": + def remove_instruments(self, remove_instruments: dict) -> "MidiDict": """Removes all messages with instruments specified in config at: data.preprocessing.remove_instruments @@ -477,7 +480,7 @@ def remove_instruments(self, config: dict) -> "MidiDict": programs_to_remove = [ i for i in range(1, 127 + 1) - if config[self.program_to_instrument[i]] is True + if remove_instruments[self.program_to_instrument[i]] is True ] channels_to_remove = [ msg["channel"] @@ -697,6 +700,8 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: return midi_dict_data +# POTENTIAL BUG: MidiDict can represent overlapping notes on the same channel. +# When converted to a MIDI file, is this handled correctly? Verify this. def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: """Converts MIDI information from dictionary form into a mido.MidiFile.
@@ -1151,3 +1156,153 @@ def get_test_fn( ) else: return fn + + +def normalize_midi_dict( + midi_dict: MidiDict, + ignore_instruments: dict[str, bool], + instrument_programs: dict[str, int], + time_step_ms: int, + max_duration_ms: int, + drum_velocity: int, + quantize_velocity_fn: Callable[[int], int], +) -> MidiDict: + """Reorganizes a MidiDict to enable consistent comparisons. + + This function normalizes a MidiDict for testing purposes, ensuring that + equivalent MidiDicts can be compared on a literal basis. It is useful + for validating tokenization or other transformations applied to MIDI + representations. + + Args: + midi_dict (MidiDict): The input MIDI representation to be normalized. + ignore_instruments (dict[str, bool]): A mapping of instrument names to + booleans indicating whether they should be ignored during + normalization. + instrument_programs (dict[str, int]): A mapping of instrument names to + MIDI program numbers used to assign instruments in the output. + time_step_ms (int): The duration of the minimum time step in + milliseconds, used for quantizing note timing. + max_duration_ms (int): The maximum allowable duration for a note in + milliseconds. Notes exceeding this duration will be truncated. + drum_velocity (int): The fixed velocity value assigned to drum notes. + quantize_velocity_fn (Callable[[int], int]): A function that maps input + velocity values to quantized output values. + + Returns: + MidiDict: A normalized version of the input `midi_dict`, reorganized + for consistent and comparable structure. 
+ """ + + def _create_channel_mappings( + midi_dict: MidiDict, instruments: list[str] + ) -> tuple[dict[str, int], dict[int, str]]: + + new_instrument_to_channel = { + instrument: (9 if instrument == "drum" else idx + (idx >= 9)) + for idx, instrument in enumerate(instruments) + } + + old_channel_to_instrument = { + msg["channel"]: midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + } + old_channel_to_instrument[9] = "drum" + + return new_instrument_to_channel, old_channel_to_instrument + + def _create_instrument_messages( + instrument_programs: dict[str, int], + instrument_to_channel: dict[str, int], + ) -> list[InstrumentMessage]: + + return [ + { + "type": "instrument", + "data": instrument_programs[k] if k != "drum" else 0, + "tick": 0, + "channel": v, + } + for k, v in instrument_to_channel.items() + ] + + def _normalize_note_messages( + midi_dict: MidiDict, + old_channel_to_instrument: dict[int, str], + new_instrument_to_channel: dict[str, int], + time_step_ms: int, + max_duration_ms: int, + drum_velocity: int, + quantize_velocity_fn: Callable[[int], int], + ) -> list[NoteMessage]: + + def _quantize_time(_n: int) -> int: + return round(_n / time_step_ms) * time_step_ms + + note_msgs: list[NoteMessage] = [] + for msg in midi_dict.note_msgs: + msg_channel = msg["channel"] + instrument = old_channel_to_instrument[msg_channel] + new_msg_channel = new_instrument_to_channel[instrument] + + start_tick = _quantize_time( + midi_dict.tick_to_ms(msg["data"]["start"]) + ) + end_tick = _quantize_time(midi_dict.tick_to_ms(msg["data"]["end"])) + velocity = quantize_velocity_fn(msg["data"]["velocity"]) + + new_msg = copy.deepcopy(msg) + new_msg["channel"] = new_msg_channel + new_msg["tick"] = start_tick + new_msg["data"]["start"] = start_tick + + if new_msg_channel != 9: + new_msg["data"]["end"] = min( + start_tick + max_duration_ms, end_tick + ) + new_msg["data"]["velocity"] = velocity + else: + new_msg["data"]["end"] = start_tick + 
time_step_ms + new_msg["data"]["velocity"] = drum_velocity + + note_msgs.append(new_msg) + + return note_msgs + + midi_dict = copy.deepcopy(midi_dict) + midi_dict.remove_instruments(remove_instruments=ignore_instruments) + midi_dict.resolve_pedal() + midi_dict.pedal_msgs = [] + + instruments = [k for k, v in ignore_instruments.items() if not v] + ["drum"] + new_instrument_to_channel, old_channel_to_instrument = ( + _create_channel_mappings( + midi_dict, + instruments, + ) + ) + + instrument_msgs = _create_instrument_messages( + instrument_programs, + new_instrument_to_channel, + ) + + note_msgs = _normalize_note_messages( + midi_dict=midi_dict, + old_channel_to_instrument=old_channel_to_instrument, + new_instrument_to_channel=new_instrument_to_channel, + time_step_ms=time_step_ms, + max_duration_ms=max_duration_ms, + drum_velocity=drum_velocity, + quantize_velocity_fn=quantize_velocity_fn, + ) + + return MidiDict( + meta_msgs=[], + tempo_msgs=[{"type": "tempo", "data": 500000, "tick": 0}], + pedal_msgs=[], + instrument_msgs=instrument_msgs, + note_msgs=note_msgs, + ticks_per_beat=500, + metadata={}, + ) diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py index 792d910..4596a44 100644 --- a/ariautils/tokenizer/__init__.py +++ b/ariautils/tokenizer/__init__.py @@ -1,5 +1,6 @@ """Includes Tokenizers and pre-processing utilities.""" -from ariautils.tokenizer._base import Tokenizer +from ._base import Tokenizer +from .absolute import AbsTokenizer -__all__ = ["Tokenizer"] +__all__ = ["Tokenizer", "AbsTokenizer"] diff --git a/ariautils/tokenizer/_base.py b/ariautils/tokenizer/_base.py index 35e51ab..f45fc4a 100644 --- a/ariautils/tokenizer/_base.py +++ b/ariautils/tokenizer/_base.py @@ -17,12 +17,7 @@ class Tokenizer: - """Abstract Tokenizer class for tokenizing MidiDict objects. - - Args: - return_tensors (bool, optional): If True, encode will return tensors. - Defaults to False. 
- """ + """Abstract Tokenizer class for tokenizing MidiDict objects.""" def __init__( self, diff --git a/ariautils/tokenizer/absolute.py b/ariautils/tokenizer/absolute.py index 6ac465b..dbd9878 100644 --- a/ariautils/tokenizer/absolute.py +++ b/ariautils/tokenizer/absolute.py @@ -6,7 +6,7 @@ import copy from collections import defaultdict -from typing import Final, Callable +from typing import Final, Callable, Any from ariautils.midi import ( MidiDict, @@ -26,25 +26,31 @@ # TODO: # - Add asserts to the tokenization / detokenization for user error +# - Need to add a tokenization or MidiDict check of how to resolve different +# channels, with the same instrument, have overlapping notes +# - There are tons of edge cases here e.g., what if there are two identical notes? +# on different channels. class AbsTokenizer(Tokenizer): """MidiDict tokenizer implemented with absolute onset timings. - The tokenizer processes MIDI files in 5000ms segments, with each segment separated by - a special token. Within each segment, note timings are represented relative to the - segment start. + The tokenizer processes MIDI files in 5000ms segments, with each segment + separated by a special token. Within each segment, note timings are + represented relative to the segment start. Tokenization Schema: For non-percussion instruments: - Each note is represented by three consecutive tokens: - 1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, and velocity + 1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, + and velocity 2. [onset]: Absolute time in milliseconds from segment start 3. [duration]: Note duration in milliseconds For percussion instruments: - Each note is represented by two consecutive tokens: - 1. [drum, note_number]: Percussion instrument and MIDI note number + 1. [drum, note_number]: Percussion instrument and MIDI note + number 2.
[onset]: Absolute time in milliseconds from segment start Notes: @@ -140,7 +146,9 @@ def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]: def _quantize_dur(self, time: int) -> int: # This function will return values res >= 0 (inc. 0) - return self._find_closest_int(time, self.dur_time_quantizations) + dur = self._find_closest_int(time, self.dur_time_quantizations) + + return dur if dur != 0 else self.time_step def _quantize_onset(self, time: int) -> int: # This function will return values res >= 0 (inc. 0) @@ -337,8 +345,6 @@ def _tokenize_midi_dict( curr_time_since_onset % self.abs_time_step ) _note_duration = self._quantize_dur(_note_duration) - if _note_duration == 0: - _note_duration = self.time_step tokenized_seq.append((_instrument, _pitch, _velocity)) tokenized_seq.append(("onset", _note_onset)) @@ -349,6 +355,28 @@ def _tokenize_midi_dict( unformatted_seq=tokenized_seq, ) + def tokenize( + self, + midi_dict: MidiDict, + remove_preceding_silence: bool = True, + **kwargs: Any, + ) -> list[Token]: + """Tokenizes a MidiDict object into a sequence. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + remove_preceding_silence (bool): If true starts the sequence at + onset=0ms by removing preceding silence. Defaults to True. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + return self._tokenize_midi_dict( + midi_dict=midi_dict, + remove_preceding_silence=remove_preceding_silence, + ) + def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: # NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to # skip converting between ticks and ms @@ -534,6 +562,18 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: metadata={}, ) + def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict: + """Detokenizes a MidiDict object. + + Args: + tokenized_seq (list): The sequence of tokens to detokenize.
+ + Returns: + MidiDict: A MidiDict reconstructed from the tokens. + """ + + return self._detokenize_midi_dict(tokenized_seq=tokenized_seq) + def export_pitch_aug( self, aug_range: int ) -> Callable[[list[Token]], list[Token]]: diff --git a/tests/assets/data/basic.mid b/tests/assets/data/basic.mid new file mode 100644 index 0000000..c44fe36 Binary files /dev/null and b/tests/assets/data/basic.mid differ diff --git a/tests/assets/data/pop.mid b/tests/assets/data/pop.mid new file mode 100644 index 0000000..be83c69 Binary files /dev/null and b/tests/assets/data/pop.mid differ diff --git a/tests/test_midi.py b/tests/test_midi.py index 5d6aa02..70cd57a 100644 --- a/tests/test_midi.py +++ b/tests/test_midi.py @@ -1,3 +1,5 @@ +"""Tests for MidiDict.""" + import unittest import tempfile import shutil diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000..8c7ed13 --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,112 @@ +"""Tests for tokenizers.""" + +import unittest +import copy + +from importlib import resources +from pathlib import Path +from typing import Final + +from ariautils.midi import MidiDict, normalize_midi_dict +from ariautils.tokenizer import AbsTokenizer +from ariautils.utils import get_logger + + +TEST_DATA_DIRECTORY: Final[Path] = Path( + str(resources.files("tests").joinpath("assets", "data")) +) +RESULTS_DATA_DIRECTORY: Final[Path] = Path( + str(resources.files("tests").joinpath("assets", "results")) +) + + +class TestAbsTokenizer(unittest.TestCase): + def setUp(self) -> None: + self.logger = get_logger(__name__ + ".TestAbsTokenizer") + + def test_normalize_midi_dict(self) -> None: + def _test_normalize_midi_dict( + _load_path: Path, _save_path: Path + ) -> None: + tokenizer = AbsTokenizer() + midi_dict = MidiDict.from_midi(_load_path) + midi_dict_copy = copy.deepcopy(midi_dict) + + normalized_midi_dict = normalize_midi_dict( + midi_dict=midi_dict, +
ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) + normalized_twice_midi_dict = normalize_midi_dict( + normalized_midi_dict, + ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) + self.assertDictEqual( + normalized_midi_dict.get_msg_dict(), + normalized_twice_midi_dict.get_msg_dict(), + ) + self.assertDictEqual( + midi_dict.get_msg_dict(), + midi_dict_copy.get_msg_dict(), + ) + normalized_midi_dict.to_midi().save(_save_path) + + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("arabesque_norm.mid") + _test_normalize_midi_dict(load_path, save_path) + load_path = TEST_DATA_DIRECTORY.joinpath("pop.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("pop_norm.mid") + _test_normalize_midi_dict(load_path, save_path) + load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("basic_norm.mid") + _test_normalize_midi_dict(load_path, save_path) + + def test_tokenize_detokenize(self) -> None: + def _test_tokenize_detokenize(_load_path: Path) -> None: + tokenizer = AbsTokenizer() + midi_dict = MidiDict.from_midi(_load_path) + + midi_dict_1 = normalize_midi_dict( + midi_dict=midi_dict, + ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) 
+ + midi_dict_2 = normalize_midi_dict( + midi_dict=tokenizer.detokenize( + tokenizer.tokenize( + midi_dict_1, remove_preceding_silence=False + ) + ), + ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) + + self.assertDictEqual( + midi_dict_1.get_msg_dict(), + midi_dict_2.get_msg_dict(), + ) + + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + _test_tokenize_detokenize(_load_path=load_path) + load_path = TEST_DATA_DIRECTORY.joinpath("pop.mid") + _test_tokenize_detokenize(_load_path=load_path) + load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid") + _test_tokenize_detokenize(_load_path=load_path)