Skip to content

Commit

Permalink
Add AbsTokenizer tests (#4)
Browse files Browse the repository at this point in the history
* add skeleton

* port midi.py

* update path for maestro metadata json

* add tests and ci

* add space

* update midi tests

* add abstract tokenizer class

* fix mypy and upgrade to pep 585

* rmv import

* fix docstring

* migrate abstokenizer

* fix mypy

* update README

* add abstokenizer tests
  • Loading branch information
loubbrad authored Nov 21, 2024
1 parent 44fa104 commit 4ce0401
Show file tree
Hide file tree
Showing 9 changed files with 334 additions and 20 deletions.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# aria-utils
<p align="center">
<pre>
█████╗ ██████╗ ██╗ █████╗ ██╗ ██╗████████╗██╗██╗ ███████╗
██╔══██╗██╔══██╗██║██╔══██╗ ██║ ██║╚══██╔══╝██║██║ ██╔════╝
███████║██████╔╝██║███████║ ██║ ██║ ██║ ██║██║ ███████╗
██╔══██║██╔══██╗██║██╔══██║ ██║ ██║ ██║ ██║██║ ╚════██║
██║ ██║██║ ██║██║██║ ██║ ╚██████╔╝ ██║ ██║███████╗███████║
╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚══════╝╚══════╝
</pre>
</p>

An extremely lightweight and simple library for pre-processing and tokenizing MIDI files.

Expand Down
159 changes: 157 additions & 2 deletions ariautils/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import json
import hashlib
import copy
import unicodedata
import mido

Expand Down Expand Up @@ -267,6 +268,8 @@ def _build_pedal_intervals(self) -> dict[int, list[list[int]]]:

return channel_to_pedal_intervals

# TODO: This function might not behave correctly when acting on degenerate
# MidiDict objects.
def resolve_overlaps(self) -> "MidiDict":
"""Resolves any note overlaps (inplace) between notes with the same
pitch and channel. This is achieved by converting a pair of notes with
Expand Down Expand Up @@ -465,7 +468,7 @@ def _process_channel_pedals(channel: int) -> None:

return self

def remove_instruments(self, config: dict) -> "MidiDict":
def remove_instruments(self, remove_instruments: dict) -> "MidiDict":
"""Removes all messages with instruments specified in config at:
data.preprocessing.remove_instruments
Expand All @@ -477,7 +480,7 @@ def remove_instruments(self, config: dict) -> "MidiDict":
programs_to_remove = [
i
for i in range(1, 127 + 1)
if config[self.program_to_instrument[i]] is True
if remove_instruments[self.program_to_instrument[i]] is True
]
channels_to_remove = [
msg["channel"]
Expand Down Expand Up @@ -697,6 +700,8 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData:
return midi_dict_data


# POTENTIAL BUG: MidiDict can represent overlapping notes on the same channel.
# When converted to a MIDI files, is this handled correctly? Verify this.
def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile:
"""Converts MIDI information from dictionary form into a mido.MidiFile.
Expand Down Expand Up @@ -1151,3 +1156,153 @@ def get_test_fn(
)
else:
return fn


def normalize_midi_dict(
midi_dict: MidiDict,
ignore_instruments: dict[str, bool],
instrument_programs: dict[str, int],
time_step_ms: int,
max_duration_ms: int,
drum_velocity: int,
quantize_velocity_fn: Callable[[int], int],
) -> MidiDict:
"""Reorganizes a MidiDict to enable consistent comparisons.
This function normalizes a MidiDict for testing purposes, ensuring that
equivalent MidiDicts can be compared on a literal basis. It is useful
for validating tokenization or other transformations applied to MIDI
representations.
Args:
midi_dict (MidiDict): The input MIDI representation to be normalized.
ignore_instruments (dict[str, bool]): A mapping of instrument names to
booleans indicating whether they should be ignored during
normalization.
instrument_programs (dict[str, int]): A mapping of instrument names to
MIDI program numbers used to assign instruments in the output.
time_step_ms (int): The duration of the minimum time step in
milliseconds, used for quantizing note timing.
max_duration_ms (int): The maximum allowable duration for a note in
milliseconds. Notes exceeding this duration will be truncated.
drum_velocity (int): The fixed velocity value assigned to drum notes.
quantize_velocity_fn (Callable[[int], int]): A function that maps input
velocity values to quantized output values.
Returns:
MidiDict: A normalized version of the input `midi_dict`, reorganized
for consistent and comparable structure.
"""

def _create_channel_mappings(
midi_dict: MidiDict, instruments: list[str]
) -> tuple[dict[str, int], dict[int, str]]:

new_instrument_to_channel = {
instrument: (9 if instrument == "drum" else idx + (idx >= 9))
for idx, instrument in enumerate(instruments)
}

old_channel_to_instrument = {
msg["channel"]: midi_dict.program_to_instrument[msg["data"]]
for msg in midi_dict.instrument_msgs
}
old_channel_to_instrument[9] = "drum"

return new_instrument_to_channel, old_channel_to_instrument

def _create_instrument_messages(
instrument_programs: dict[str, int],
instrument_to_channel: dict[str, int],
) -> list[InstrumentMessage]:

return [
{
"type": "instrument",
"data": instrument_programs[k] if k != "drum" else 0,
"tick": 0,
"channel": v,
}
for k, v in instrument_to_channel.items()
]

def _normalize_note_messages(
midi_dict: MidiDict,
old_channel_to_instrument: dict[int, str],
new_instrument_to_channel: dict[str, int],
time_step_ms: int,
max_duration_ms: int,
drum_velocity: int,
quantize_velocity_fn: Callable[[int], int],
) -> list[NoteMessage]:

def _quantize_time(_n: int) -> int:
return round(_n / time_step_ms) * time_step_ms

note_msgs: list[NoteMessage] = []
for msg in midi_dict.note_msgs:
msg_channel = msg["channel"]
instrument = old_channel_to_instrument[msg_channel]
new_msg_channel = new_instrument_to_channel[instrument]

start_tick = _quantize_time(
midi_dict.tick_to_ms(msg["data"]["start"])
)
end_tick = _quantize_time(midi_dict.tick_to_ms(msg["data"]["end"]))
velocity = quantize_velocity_fn(msg["data"]["velocity"])

new_msg = copy.deepcopy(msg)
new_msg["channel"] = new_msg_channel
new_msg["tick"] = start_tick
new_msg["data"]["start"] = start_tick

if new_msg_channel != 9:
new_msg["data"]["end"] = min(
start_tick + max_duration_ms, end_tick
)
new_msg["data"]["velocity"] = velocity
else:
new_msg["data"]["end"] = start_tick + time_step_ms
new_msg["data"]["velocity"] = drum_velocity

note_msgs.append(new_msg)

return note_msgs

midi_dict = copy.deepcopy(midi_dict)
midi_dict.remove_instruments(remove_instruments=ignore_instruments)
midi_dict.resolve_pedal()
midi_dict.pedal_msgs = []

instruments = [k for k, v in ignore_instruments.items() if not v] + ["drum"]
new_instrument_to_channel, old_channel_to_instrument = (
_create_channel_mappings(
midi_dict,
instruments,
)
)

instrument_msgs = _create_instrument_messages(
instrument_programs,
new_instrument_to_channel,
)

note_msgs = _normalize_note_messages(
midi_dict=midi_dict,
old_channel_to_instrument=old_channel_to_instrument,
new_instrument_to_channel=new_instrument_to_channel,
time_step_ms=time_step_ms,
max_duration_ms=max_duration_ms,
drum_velocity=drum_velocity,
quantize_velocity_fn=quantize_velocity_fn,
)

return MidiDict(
meta_msgs=[],
tempo_msgs=[{"type": "tempo", "data": 500000, "tick": 0}],
pedal_msgs=[],
instrument_msgs=instrument_msgs,
note_msgs=note_msgs,
ticks_per_beat=500,
metadata={},
)
5 changes: 3 additions & 2 deletions ariautils/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Includes Tokenizers and pre-processing utilities."""

from ariautils.tokenizer._base import Tokenizer
from ._base import Tokenizer
from .absolute import AbsTokenizer

__all__ = ["Tokenizer"]
__all__ = ["Tokenizer", "AbsTokenizer"]
7 changes: 1 addition & 6 deletions ariautils/tokenizer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@


class Tokenizer:
"""Abstract Tokenizer class for tokenizing MidiDict objects.
Args:
return_tensors (bool, optional): If True, encode will return tensors.
Defaults to False.
"""
"""Abstract Tokenizer class for tokenizing MidiDict objects."""

def __init__(
self,
Expand Down
58 changes: 49 additions & 9 deletions ariautils/tokenizer/absolute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy

from collections import defaultdict
from typing import Final, Callable
from typing import Final, Callable, Any

from ariautils.midi import (
MidiDict,
Expand All @@ -26,25 +26,31 @@

# TODO:
# - Add asserts to the tokenization / detokenization for user error
# - Need to add a tokenization or MidiDict check of how to resolve different
# channels, with the same instrument, have overlaping notes
# - There are tons of edge cases here e.g., what if there are two indetical notes?
# on different channels.


class AbsTokenizer(Tokenizer):
"""MidiDict tokenizer implemented with absolute onset timings.
The tokenizer processes MIDI files in 5000ms segments, with each segment separated by
a special <T> token. Within each segment, note timings are represented relative to the
segment start.
The tokenizer processes MIDI files in 5000ms segments, with each segment
separated by a special <T> token. Within each segment, note timings are
represented relative to the segment start.
Tokenization Schema:
For non-percussion instruments:
- Each note is represented by three consecutive tokens:
1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, and velocity
1. [instrument, pitch, velocity]: Instrument class, MIDI pitch,
and velocity
2. [onset]: Absolute time in milliseconds from segment start
3. [duration]: Note duration in milliseconds
For percussion instruments:
- Each note is represented by two consecutive tokens:
1. [drum, note_number]: Percussion instrument and MIDI note number
1. [drum, note_number]: Percussion instrument and MIDI note
number
2. [onset]: Absolute time in milliseconds from segment start
Notes:
Expand Down Expand Up @@ -140,7 +146,9 @@ def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]:

def _quantize_dur(self, time: int) -> int:
# This function will return values res >= 0 (inc. 0)
return self._find_closest_int(time, self.dur_time_quantizations)
dur = self._find_closest_int(time, self.dur_time_quantizations)

return dur if dur != 0 else self.time_step

def _quantize_onset(self, time: int) -> int:
# This function will return values res >= 0 (inc. 0)
Expand Down Expand Up @@ -337,8 +345,6 @@ def _tokenize_midi_dict(
curr_time_since_onset % self.abs_time_step
)
_note_duration = self._quantize_dur(_note_duration)
if _note_duration == 0:
_note_duration = self.time_step

tokenized_seq.append((_instrument, _pitch, _velocity))
tokenized_seq.append(("onset", _note_onset))
Expand All @@ -349,6 +355,28 @@ def _tokenize_midi_dict(
unformatted_seq=tokenized_seq,
)

def tokenize(
self,
midi_dict: MidiDict,
remove_preceding_silence: bool = True,
**kwargs: Any,
) -> list[Token]:
"""Tokenizes a MidiDict object into a sequence.
Args:
midi_dict (MidiDict): The MidiDict to tokenize.
remove_preceding_silence (bool): If true starts the sequence at
onset=0ms by removing preceding silence. Defaults to False.
Returns:
list[Token]: A sequence of tokens representing the MIDI content.
"""

return self._tokenize_midi_dict(
midi_dict=midi_dict,
remove_preceding_silence=remove_preceding_silence,
)

def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
# NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to
# skip converting between ticks and ms
Expand Down Expand Up @@ -534,6 +562,18 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
metadata={},
)

def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict:
"""Detokenizes a MidiDict object.
Args:
tokenized_seq (list): The sequence of tokens to detokenize.
Returns:
MidiDict: A MidiDict reconstructed from the tokens.
"""

return self._detokenize_midi_dict(tokenized_seq=tokenized_seq)

def export_pitch_aug(
self, aug_range: int
) -> Callable[[list[Token]], list[Token]]:
Expand Down
Binary file added tests/assets/data/basic.mid
Binary file not shown.
Binary file added tests/assets/data/pop.mid
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/test_midi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
""""Tests for MidiDict."""

import unittest
import tempfile
import shutil
Expand Down
Loading

0 comments on commit 4ce0401

Please sign in to comment.