Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AbsTokenizer tests #4

Merged
merged 16 commits into from
Nov 21, 2024
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# aria-utils
<p align="center">
<pre>
█████╗ ██████╗ ██╗ █████╗ ██╗ ██╗████████╗██╗██╗ ███████╗
██╔══██╗██╔══██╗██║██╔══██╗ ██║ ██║╚══██╔══╝██║██║ ██╔════╝
███████║██████╔╝██║███████║ ██║ ██║ ██║ ██║██║ ███████╗
██╔══██║██╔══██╗██║██╔══██║ ██║ ██║ ██║ ██║██║ ╚════██║
██║ ██║██║ ██║██║██║ ██║ ╚██████╔╝ ██║ ██║███████╗███████║
╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚══════╝╚══════╝
</pre>
</p>

An extremely lightweight and simple library for pre-processing and tokenizing MIDI files.

Expand Down
159 changes: 157 additions & 2 deletions ariautils/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import json
import hashlib
import copy
import unicodedata
import mido

Expand Down Expand Up @@ -267,6 +268,8 @@ def _build_pedal_intervals(self) -> dict[int, list[list[int]]]:

return channel_to_pedal_intervals

# TODO: This function might not behave correctly when acting on degenerate
# MidiDict objects.
def resolve_overlaps(self) -> "MidiDict":
"""Resolves any note overlaps (inplace) between notes with the same
pitch and channel. This is achieved by converting a pair of notes with
Expand Down Expand Up @@ -465,7 +468,7 @@ def _process_channel_pedals(channel: int) -> None:

return self

def remove_instruments(self, config: dict) -> "MidiDict":
def remove_instruments(self, remove_instruments: dict) -> "MidiDict":
"""Removes all messages with instruments specified in config at:

data.preprocessing.remove_instruments
Expand All @@ -477,7 +480,7 @@ def remove_instruments(self, config: dict) -> "MidiDict":
programs_to_remove = [
i
for i in range(1, 127 + 1)
if config[self.program_to_instrument[i]] is True
if remove_instruments[self.program_to_instrument[i]] is True
]
channels_to_remove = [
msg["channel"]
Expand Down Expand Up @@ -697,6 +700,8 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData:
return midi_dict_data


# POTENTIAL BUG: MidiDict can represent overlapping notes on the same channel.
# When converted to a MIDI files, is this handled correctly? Verify this.
def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile:
"""Converts MIDI information from dictionary form into a mido.MidiFile.

Expand Down Expand Up @@ -1151,3 +1156,153 @@ def get_test_fn(
)
else:
return fn


def normalize_midi_dict(
midi_dict: MidiDict,
ignore_instruments: dict[str, bool],
instrument_programs: dict[str, int],
time_step_ms: int,
max_duration_ms: int,
drum_velocity: int,
quantize_velocity_fn: Callable[[int], int],
) -> MidiDict:
"""Reorganizes a MidiDict to enable consistent comparisons.

This function normalizes a MidiDict for testing purposes, ensuring that
equivalent MidiDicts can be compared on a literal basis. It is useful
for validating tokenization or other transformations applied to MIDI
representations.

Args:
midi_dict (MidiDict): The input MIDI representation to be normalized.
ignore_instruments (dict[str, bool]): A mapping of instrument names to
booleans indicating whether they should be ignored during
normalization.
instrument_programs (dict[str, int]): A mapping of instrument names to
MIDI program numbers used to assign instruments in the output.
time_step_ms (int): The duration of the minimum time step in
milliseconds, used for quantizing note timing.
max_duration_ms (int): The maximum allowable duration for a note in
milliseconds. Notes exceeding this duration will be truncated.
drum_velocity (int): The fixed velocity value assigned to drum notes.
quantize_velocity_fn (Callable[[int], int]): A function that maps input
velocity values to quantized output values.

Returns:
MidiDict: A normalized version of the input `midi_dict`, reorganized
for consistent and comparable structure.
"""

def _create_channel_mappings(
midi_dict: MidiDict, instruments: list[str]
) -> tuple[dict[str, int], dict[int, str]]:

new_instrument_to_channel = {
instrument: (9 if instrument == "drum" else idx + (idx >= 9))
for idx, instrument in enumerate(instruments)
}

old_channel_to_instrument = {
msg["channel"]: midi_dict.program_to_instrument[msg["data"]]
for msg in midi_dict.instrument_msgs
}
old_channel_to_instrument[9] = "drum"

return new_instrument_to_channel, old_channel_to_instrument

def _create_instrument_messages(
instrument_programs: dict[str, int],
instrument_to_channel: dict[str, int],
) -> list[InstrumentMessage]:

return [
{
"type": "instrument",
"data": instrument_programs[k] if k != "drum" else 0,
"tick": 0,
"channel": v,
}
for k, v in instrument_to_channel.items()
]

def _normalize_note_messages(
midi_dict: MidiDict,
old_channel_to_instrument: dict[int, str],
new_instrument_to_channel: dict[str, int],
time_step_ms: int,
max_duration_ms: int,
drum_velocity: int,
quantize_velocity_fn: Callable[[int], int],
) -> list[NoteMessage]:

def _quantize_time(_n: int) -> int:
return round(_n / time_step_ms) * time_step_ms

note_msgs: list[NoteMessage] = []
for msg in midi_dict.note_msgs:
msg_channel = msg["channel"]
instrument = old_channel_to_instrument[msg_channel]
new_msg_channel = new_instrument_to_channel[instrument]

start_tick = _quantize_time(
midi_dict.tick_to_ms(msg["data"]["start"])
)
end_tick = _quantize_time(midi_dict.tick_to_ms(msg["data"]["end"]))
velocity = quantize_velocity_fn(msg["data"]["velocity"])

new_msg = copy.deepcopy(msg)
new_msg["channel"] = new_msg_channel
new_msg["tick"] = start_tick
new_msg["data"]["start"] = start_tick

if new_msg_channel != 9:
new_msg["data"]["end"] = min(
start_tick + max_duration_ms, end_tick
)
new_msg["data"]["velocity"] = velocity
else:
new_msg["data"]["end"] = start_tick + time_step_ms
new_msg["data"]["velocity"] = drum_velocity

note_msgs.append(new_msg)

return note_msgs

midi_dict = copy.deepcopy(midi_dict)
midi_dict.remove_instruments(remove_instruments=ignore_instruments)
midi_dict.resolve_pedal()
midi_dict.pedal_msgs = []

instruments = [k for k, v in ignore_instruments.items() if not v] + ["drum"]
new_instrument_to_channel, old_channel_to_instrument = (
_create_channel_mappings(
midi_dict,
instruments,
)
)

instrument_msgs = _create_instrument_messages(
instrument_programs,
new_instrument_to_channel,
)

note_msgs = _normalize_note_messages(
midi_dict=midi_dict,
old_channel_to_instrument=old_channel_to_instrument,
new_instrument_to_channel=new_instrument_to_channel,
time_step_ms=time_step_ms,
max_duration_ms=max_duration_ms,
drum_velocity=drum_velocity,
quantize_velocity_fn=quantize_velocity_fn,
)

return MidiDict(
meta_msgs=[],
tempo_msgs=[{"type": "tempo", "data": 500000, "tick": 0}],
pedal_msgs=[],
instrument_msgs=instrument_msgs,
note_msgs=note_msgs,
ticks_per_beat=500,
metadata={},
)
5 changes: 3 additions & 2 deletions ariautils/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Includes Tokenizers and pre-processing utilities."""

from ariautils.tokenizer._base import Tokenizer
from ._base import Tokenizer
from .absolute import AbsTokenizer

__all__ = ["Tokenizer"]
__all__ = ["Tokenizer", "AbsTokenizer"]
7 changes: 1 addition & 6 deletions ariautils/tokenizer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@


class Tokenizer:
"""Abstract Tokenizer class for tokenizing MidiDict objects.

Args:
return_tensors (bool, optional): If True, encode will return tensors.
Defaults to False.
"""
"""Abstract Tokenizer class for tokenizing MidiDict objects."""

def __init__(
self,
Expand Down
58 changes: 49 additions & 9 deletions ariautils/tokenizer/absolute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy

from collections import defaultdict
from typing import Final, Callable
from typing import Final, Callable, Any

from ariautils.midi import (
MidiDict,
Expand All @@ -26,25 +26,31 @@

# TODO:
# - Add asserts to the tokenization / detokenization for user error
# - Need to add a tokenization or MidiDict check of how to resolve different
# channels, with the same instrument, have overlaping notes
# - There are tons of edge cases here e.g., what if there are two indetical notes?
# on different channels.


class AbsTokenizer(Tokenizer):
"""MidiDict tokenizer implemented with absolute onset timings.

The tokenizer processes MIDI files in 5000ms segments, with each segment separated by
a special <T> token. Within each segment, note timings are represented relative to the
segment start.
The tokenizer processes MIDI files in 5000ms segments, with each segment
separated by a special <T> token. Within each segment, note timings are
represented relative to the segment start.

Tokenization Schema:
For non-percussion instruments:
- Each note is represented by three consecutive tokens:
1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, and velocity
1. [instrument, pitch, velocity]: Instrument class, MIDI pitch,
and velocity
2. [onset]: Absolute time in milliseconds from segment start
3. [duration]: Note duration in milliseconds

For percussion instruments:
- Each note is represented by two consecutive tokens:
1. [drum, note_number]: Percussion instrument and MIDI note number
1. [drum, note_number]: Percussion instrument and MIDI note
number
2. [onset]: Absolute time in milliseconds from segment start

Notes:
Expand Down Expand Up @@ -140,7 +146,9 @@ def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]:

def _quantize_dur(self, time: int) -> int:
# This function will return values res >= 0 (inc. 0)
return self._find_closest_int(time, self.dur_time_quantizations)
dur = self._find_closest_int(time, self.dur_time_quantizations)

return dur if dur != 0 else self.time_step

def _quantize_onset(self, time: int) -> int:
# This function will return values res >= 0 (inc. 0)
Expand Down Expand Up @@ -337,8 +345,6 @@ def _tokenize_midi_dict(
curr_time_since_onset % self.abs_time_step
)
_note_duration = self._quantize_dur(_note_duration)
if _note_duration == 0:
_note_duration = self.time_step

tokenized_seq.append((_instrument, _pitch, _velocity))
tokenized_seq.append(("onset", _note_onset))
Expand All @@ -349,6 +355,28 @@ def _tokenize_midi_dict(
unformatted_seq=tokenized_seq,
)

def tokenize(
self,
midi_dict: MidiDict,
remove_preceding_silence: bool = True,
**kwargs: Any,
) -> list[Token]:
"""Tokenizes a MidiDict object into a sequence.

Args:
midi_dict (MidiDict): The MidiDict to tokenize.
remove_preceding_silence (bool): If true starts the sequence at
onset=0ms by removing preceding silence. Defaults to False.

Returns:
list[Token]: A sequence of tokens representing the MIDI content.
"""

return self._tokenize_midi_dict(
midi_dict=midi_dict,
remove_preceding_silence=remove_preceding_silence,
)

def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
# NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to
# skip converting between ticks and ms
Expand Down Expand Up @@ -534,6 +562,18 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
metadata={},
)

def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict:
"""Detokenizes a MidiDict object.

Args:
tokenized_seq (list): The sequence of tokens to detokenize.

Returns:
MidiDict: A MidiDict reconstructed from the tokens.
"""

return self._detokenize_midi_dict(tokenized_seq=tokenized_seq)

def export_pitch_aug(
self, aug_range: int
) -> Callable[[list[Token]], list[Token]]:
Expand Down
Binary file added tests/assets/data/basic.mid
Binary file not shown.
Binary file added tests/assets/data/pop.mid
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/test_midi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
""""Tests for MidiDict."""

import unittest
import tempfile
import shutil
Expand Down
Loading