From 4ce0401403d74e704a9c0bb9cd5f5a5f72651a2d Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 21 Nov 2024 21:32:40 +0000 Subject: [PATCH] Add AbsTokenizer tests (#4) * add skeleton * port midi.py * update path for maestro metadata json * add tests and ci * add space * update midi tests * add abstract tokenizer class * fix mypy and upgrade to pep 585 * rmv import * fix docstring * migrate abstokenizer * fix mypy * update README * add abstokenizer tests --- README.md | 11 ++- ariautils/midi.py | 159 +++++++++++++++++++++++++++++++- ariautils/tokenizer/__init__.py | 5 +- ariautils/tokenizer/_base.py | 7 +- ariautils/tokenizer/absolute.py | 58 ++++++++++-- tests/assets/data/basic.mid | Bin 0 -> 762 bytes tests/assets/data/pop.mid | Bin 0 -> 12527 bytes tests/test_midi.py | 2 + tests/test_tokenizer.py | 112 ++++++++++++++++++++++ 9 files changed, 334 insertions(+), 20 deletions(-) create mode 100644 tests/assets/data/basic.mid create mode 100644 tests/assets/data/pop.mid create mode 100644 tests/test_tokenizer.py diff --git a/README.md b/README.md index 45a7e1f..e27fd30 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,13 @@ -# aria-utils +

+

+ █████╗ ██████╗ ██╗ █████╗     ██╗   ██╗████████╗██╗██╗     ███████╗
+██╔══██╗██╔══██╗██║██╔══██╗    ██║   ██║╚══██╔══╝██║██║     ██╔════╝
+███████║██████╔╝██║███████║    ██║   ██║   ██║   ██║██║     ███████╗
+██╔══██║██╔══██╗██║██╔══██║    ██║   ██║   ██║   ██║██║     ╚════██║
+██║  ██║██║  ██║██║██║  ██║    ╚██████╔╝   ██║   ██║███████╗███████║
+╚═╝  ╚═╝╚═╝  ╚═╝╚═╝╚═╝  ╚═╝     ╚═════╝    ╚═╝   ╚═╝╚══════╝╚══════╝
+
+

An extremely lightweight and simple library for pre-processing and tokenizing MIDI files. diff --git a/ariautils/midi.py b/ariautils/midi.py index 1813bdf..eab984a 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -4,6 +4,7 @@ import os import json import hashlib +import copy import unicodedata import mido @@ -267,6 +268,8 @@ def _build_pedal_intervals(self) -> dict[int, list[list[int]]]: return channel_to_pedal_intervals + # TODO: This function might not behave correctly when acting on degenerate + # MidiDict objects. def resolve_overlaps(self) -> "MidiDict": """Resolves any note overlaps (inplace) between notes with the same pitch and channel. This is achieved by converting a pair of notes with @@ -465,7 +468,7 @@ def _process_channel_pedals(channel: int) -> None: return self - def remove_instruments(self, config: dict) -> "MidiDict": + def remove_instruments(self, remove_instruments: dict) -> "MidiDict": """Removes all messages with instruments specified in config at: data.preprocessing.remove_instruments @@ -477,7 +480,7 @@ def remove_instruments(self, config: dict) -> "MidiDict": programs_to_remove = [ i for i in range(1, 127 + 1) - if config[self.program_to_instrument[i]] is True + if remove_instruments[self.program_to_instrument[i]] is True ] channels_to_remove = [ msg["channel"] @@ -697,6 +700,8 @@ def midi_to_dict(mid: mido.MidiFile) -> MidiDictData: return midi_dict_data +# POTENTIAL BUG: MidiDict can represent overlapping notes on the same channel. +# When converted to a MIDI files, is this handled correctly? Verify this. def dict_to_midi(mid_data: MidiDictData) -> mido.MidiFile: """Converts MIDI information from dictionary form into a mido.MidiFile. @@ -1151,3 +1156,153 @@ def get_test_fn( ) else: return fn + + +def normalize_midi_dict( + midi_dict: MidiDict, + ignore_instruments: dict[str, bool], + instrument_programs: dict[str, int], + time_step_ms: int, + max_duration_ms: int, + drum_velocity: int, + quantize_velocity_fn: Callable[[int], int], +) -> MidiDict: + """Reorganizes a MidiDict to enable consistent comparisons. + + This function normalizes a MidiDict for testing purposes, ensuring that + equivalent MidiDicts can be compared on a literal basis. It is useful + for validating tokenization or other transformations applied to MIDI + representations. + + Args: + midi_dict (MidiDict): The input MIDI representation to be normalized. + ignore_instruments (dict[str, bool]): A mapping of instrument names to + booleans indicating whether they should be ignored during + normalization. + instrument_programs (dict[str, int]): A mapping of instrument names to + MIDI program numbers used to assign instruments in the output. + time_step_ms (int): The duration of the minimum time step in + milliseconds, used for quantizing note timing. + max_duration_ms (int): The maximum allowable duration for a note in + milliseconds. Notes exceeding this duration will be truncated. + drum_velocity (int): The fixed velocity value assigned to drum notes. + quantize_velocity_fn (Callable[[int], int]): A function that maps input + velocity values to quantized output values. + + Returns: + MidiDict: A normalized version of the input `midi_dict`, reorganized + for consistent and comparable structure. + """ + + def _create_channel_mappings( + midi_dict: MidiDict, instruments: list[str] + ) -> tuple[dict[str, int], dict[int, str]]: + + new_instrument_to_channel = { + instrument: (9 if instrument == "drum" else idx + (idx >= 9)) + for idx, instrument in enumerate(instruments) + } + + old_channel_to_instrument = { + msg["channel"]: midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + } + old_channel_to_instrument[9] = "drum" + + return new_instrument_to_channel, old_channel_to_instrument + + def _create_instrument_messages( + instrument_programs: dict[str, int], + instrument_to_channel: dict[str, int], + ) -> list[InstrumentMessage]: + + return [ + { + "type": "instrument", + "data": instrument_programs[k] if k != "drum" else 0, + "tick": 0, + "channel": v, + } + for k, v in instrument_to_channel.items() + ] + + def _normalize_note_messages( + midi_dict: MidiDict, + old_channel_to_instrument: dict[int, str], + new_instrument_to_channel: dict[str, int], + time_step_ms: int, + max_duration_ms: int, + drum_velocity: int, + quantize_velocity_fn: Callable[[int], int], + ) -> list[NoteMessage]: + + def _quantize_time(_n: int) -> int: + return round(_n / time_step_ms) * time_step_ms + + note_msgs: list[NoteMessage] = [] + for msg in midi_dict.note_msgs: + msg_channel = msg["channel"] + instrument = old_channel_to_instrument[msg_channel] + new_msg_channel = new_instrument_to_channel[instrument] + + start_tick = _quantize_time( + midi_dict.tick_to_ms(msg["data"]["start"]) + ) + end_tick = _quantize_time(midi_dict.tick_to_ms(msg["data"]["end"])) + velocity = quantize_velocity_fn(msg["data"]["velocity"]) + + new_msg = copy.deepcopy(msg) + new_msg["channel"] = new_msg_channel + new_msg["tick"] = start_tick + new_msg["data"]["start"] = start_tick + + if new_msg_channel != 9: + new_msg["data"]["end"] = min( + start_tick + max_duration_ms, end_tick + ) + new_msg["data"]["velocity"] = velocity + else: + new_msg["data"]["end"] = start_tick + time_step_ms + new_msg["data"]["velocity"] = drum_velocity + + note_msgs.append(new_msg) + + return note_msgs + + midi_dict = copy.deepcopy(midi_dict) + midi_dict.remove_instruments(remove_instruments=ignore_instruments) + midi_dict.resolve_pedal() + midi_dict.pedal_msgs = [] + + instruments = [k for k, v in ignore_instruments.items() if not v] + ["drum"] + new_instrument_to_channel, old_channel_to_instrument = ( + _create_channel_mappings( + midi_dict, + instruments, + ) + ) + + instrument_msgs = _create_instrument_messages( + instrument_programs, + new_instrument_to_channel, + ) + + note_msgs = _normalize_note_messages( + midi_dict=midi_dict, + old_channel_to_instrument=old_channel_to_instrument, + new_instrument_to_channel=new_instrument_to_channel, + time_step_ms=time_step_ms, + max_duration_ms=max_duration_ms, + drum_velocity=drum_velocity, + quantize_velocity_fn=quantize_velocity_fn, + ) + + return MidiDict( + meta_msgs=[], + tempo_msgs=[{"type": "tempo", "data": 500000, "tick": 0}], + pedal_msgs=[], + instrument_msgs=instrument_msgs, + note_msgs=note_msgs, + ticks_per_beat=500, + metadata={}, + ) diff --git a/ariautils/tokenizer/__init__.py b/ariautils/tokenizer/__init__.py index 792d910..4596a44 100644 --- a/ariautils/tokenizer/__init__.py +++ b/ariautils/tokenizer/__init__.py @@ -1,5 +1,6 @@ """Includes Tokenizers and pre-processing utilities.""" -from ariautils.tokenizer._base import Tokenizer +from ._base import Tokenizer +from .absolute import AbsTokenizer -__all__ = ["Tokenizer"] +__all__ = ["Tokenizer", "AbsTokenizer"] diff --git a/ariautils/tokenizer/_base.py b/ariautils/tokenizer/_base.py index 35e51ab..f45fc4a 100644 --- a/ariautils/tokenizer/_base.py +++ b/ariautils/tokenizer/_base.py @@ -17,12 +17,7 @@ class Tokenizer: - """Abstract Tokenizer class for tokenizing MidiDict objects. - - Args: - return_tensors (bool, optional): If True, encode will return tensors. - Defaults to False. - """ + """Abstract Tokenizer class for tokenizing MidiDict objects.""" def __init__( self, diff --git a/ariautils/tokenizer/absolute.py b/ariautils/tokenizer/absolute.py index 6ac465b..dbd9878 100644 --- a/ariautils/tokenizer/absolute.py +++ b/ariautils/tokenizer/absolute.py @@ -6,7 +6,7 @@ import copy from collections import defaultdict -from typing import Final, Callable +from typing import Final, Callable, Any from ariautils.midi import ( MidiDict, @@ -26,25 +26,31 @@ # TODO: # - Add asserts to the tokenization / detokenization for user error +# - Need to add a tokenization or MidiDict check of how to resolve different +# channels, with the same instrument, have overlaping notes +# - There are tons of edge cases here e.g., what if there are two indetical notes? +# on different channels. class AbsTokenizer(Tokenizer): """MidiDict tokenizer implemented with absolute onset timings. - The tokenizer processes MIDI files in 5000ms segments, with each segment separated by - a special token. Within each segment, note timings are represented relative to the - segment start. + The tokenizer processes MIDI files in 5000ms segments, with each segment + separated by a special token. Within each segment, note timings are + represented relative to the segment start. Tokenization Schema: For non-percussion instruments: - Each note is represented by three consecutive tokens: - 1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, and velocity + 1. [instrument, pitch, velocity]: Instrument class, MIDI pitch, + and velocity 2. [onset]: Absolute time in milliseconds from segment start 3. [duration]: Note duration in milliseconds For percussion instruments: - Each note is represented by two consecutive tokens: - 1. [drum, note_number]: Percussion instrument and MIDI note number + 1. [drum, note_number]: Percussion instrument and MIDI note + number 2. [onset]: Absolute time in milliseconds from segment start Notes: @@ -140,7 +146,9 @@ def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]: def _quantize_dur(self, time: int) -> int: # This function will return values res >= 0 (inc. 0) - return self._find_closest_int(time, self.dur_time_quantizations) + dur = self._find_closest_int(time, self.dur_time_quantizations) + + return dur if dur != 0 else self.time_step def _quantize_onset(self, time: int) -> int: # This function will return values res >= 0 (inc. 0) @@ -337,8 +345,6 @@ def _tokenize_midi_dict( curr_time_since_onset % self.abs_time_step ) _note_duration = self._quantize_dur(_note_duration) - if _note_duration == 0: - _note_duration = self.time_step tokenized_seq.append((_instrument, _pitch, _velocity)) tokenized_seq.append(("onset", _note_onset)) @@ -349,6 +355,28 @@ def _tokenize_midi_dict( unformatted_seq=tokenized_seq, ) + def tokenize( + self, + midi_dict: MidiDict, + remove_preceding_silence: bool = True, + **kwargs: Any, + ) -> list[Token]: + """Tokenizes a MidiDict object into a sequence. + + Args: + midi_dict (MidiDict): The MidiDict to tokenize. + remove_preceding_silence (bool): If true starts the sequence at + onset=0ms by removing preceding silence. Defaults to False. + + Returns: + list[Token]: A sequence of tokens representing the MIDI content. + """ + + return self._tokenize_midi_dict( + midi_dict=midi_dict, + remove_preceding_silence=remove_preceding_silence, + ) + def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: # NOTE: These values chosen so that 1000 ticks = 1000ms, allowing us to # skip converting between ticks and ms @@ -534,6 +562,18 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: metadata={}, ) + def detokenize(self, tokenized_seq: list[Token], **kwargs: Any) -> MidiDict: + """Detokenizes a MidiDict object. + + Args: + tokenized_seq (list): The sequence of tokens to detokenize. + + Returns: + MidiDict: A MidiDict reconstructed from the tokens. + """ + + return self._detokenize_midi_dict(tokenized_seq=tokenized_seq) + def export_pitch_aug( self, aug_range: int ) -> Callable[[list[Token]], list[Token]]: diff --git a/tests/assets/data/basic.mid b/tests/assets/data/basic.mid new file mode 100644 index 0000000000000000000000000000000000000000..c44fe36a65b7c5e7d536a0c67832c1ddb7ba15d7 GIT binary patch literal 762 zcmbV|u}Z^G7==&T(2BY_=;%UbR#u+fEpgTew{fT0uL%Wy)Mxi4UE)Fcr#NkkDzy-`I{Zn`7#BQ+@ zyJ=5>&^9?u9!O{#tL;C~H`&5q1NNu)d;8vJJ@DyU-}*?uYtL*k0tp;9pO5xuAI(gd$0NqYu-FtKl}dgR9}DldvY(>m|EJqYnk4l+&dX$Ux4n* zVQ4FPx%Y~N%fCJ;e-p{y?w7xNM*jYX@(&l~A4lY$PRc)TmVc?pzdj}ZR+E2!NB-k| z`OmXv;mk$(F!wF_D0f~K8P^#fGj8Ng$<5q%%Kcl+xB34Lzdz&m=ecRKz?u(> zlkyQ`k#W7)Egu&fa-+CcZWiB>Pq_9e<5uxmxy`jZ#f*H$_0Nk%v%vKaTQ15+j77$E z#>b2sjGK&47@sn3F>cR&UG6Y{#`wJDqFG@6a8Bf-xrQvx)#Umd&oXZC|IInx%e_zM zy5$zvZnO5z+*9(|+_Uoe9M7`;!!ta~SY%vhe9XAPxXJkB3_LU6V%%ojIkQ(jWB&Zi zvxoot;>9b`pH$wu^2#rwb8_uiUH-IGmp_|*N#^_Ck|>nHmt;wbA|;VFk#!2d0bbWwU51f!Q3G?V;HjG^2Jd3}(#U zj2S`vR#`MpWXxE`9LYDMb*1^JjVoh;dBU0zYxY_*Xv|Jywi~nAm++S z@s0{1X) zOU;}3zt>zls?WOL?eo!Qu5D-DTD)?!p?ZV)9ViNr;Hih51LMr2&cpXXY@4dc%Ak?R zD$r1oQ;uu|GDbSBtb;^9tJ;mUdi}JZqMTxQKI-Kb{~~;hwl=594yRHGY`_dVSpvCk zBbyxctyV|?ya@>M)?Mo5EE$3wrRKn>CXC^Sbq^ZV(ozgPP}hti{+mS3?dWO%7w==< zRl;eZu>-$MJU295fpc0O8MKRLyEQv;#K7zg<`HVqj5z>`EXY=ItG8^xczZ>I?|#rk6i(*k87lDMqkCyKpcE4vNcRAqIX#bXWlTu{q0t^ z8rf`cd0f56NX04=)qRcNi6EOC(1smOkN{e9QV#Bm^K{6iK%tL=vyxPS)30|}-UtI< z;{|wbkL3lj1&kPd-wu>C$;)+kx9OM5N(CTsSpXB~m(b9{^w>;<2PAC(+=9FHBt0SR|2XZ)&Ljie- z>#8d5sgPtVg*g}^kp+arW^|g%WVU^Q>Ts`BnC%Amioba$CD||vq#tA&>2na>=HQCl zn+!IVX7mjtL>Y0+M8B$8fKRfunsSMQSrju3Bbl5Fj8@~QCSw;l;@Hp$GqWrq>_P~0 z#abmY_ghl(eDo|g;R`!g<7DiNj|&hAi1@1FB=I&QEgB}#*K_FaW%T!a4$lwpd=T{t z`hFfSdlfI6A^6S_d|zNb!+etYwU!u3FQW5rLgSQrdl@t#>9@-F1M_L_pXI*J=Sl=- zQ1cD!^E$lFGS^b`OZ@*v5|rIfD+XHM36>~;;-IWgn~c~{SLH(-z8Rgz z{+;5qZk^W2gqjay#HmC722W{Kq%I*tu>o8zg){tU)ENLU9^!4^Wo_)}lLgh(d~^ow z)X~ywX#9J`$dJ%yP05;eOC+3+zNz&?=<0{S>aUd~}u-uVcU8!j5O5`dVOKWf4oQ;elyOoF+s!qu=JHkJB2T_Zk0-XYS}%5g${7ir!=eyvEHJ5my2MGaF< zOZ(N9&FK5dvVG2v@n%#?;M!>wV|aTnB7Gk+n%lq4yxEzh+iI=^8$_XXhNwl>ZAL%f zeo5EiFch|_;>4MsoeLKJK@pU)0M4xgBjK-)ht$h}&Ln~Sy(D?OiLu_tSmzP#hgn=c z>)dkO#yR9Sc^wjAandjz2{5S&a6>o(?S+iiKI(d5Gx}ZpE(jzP+;QId1z{j>T~Q!f zG7upMs_U($Bv%OZAkgorz|Ox3HK+uSLqpWuu9^nFcvQWm88x<$vAW_+RWI}ZKCWA` z6v8yXR2@HP1mHG*Kcv@*wT*q~VvNl5UF_Nk?w(_-M)kB^2yYuYARUed#n1)1x zlE@z$6yDRK48pOm5Hjk9M$9}jJq1F<2?vo zIueHEsFx1#~X{u(gPQ>qX;>YjI zB<~BF1C3mNqf~9QKzaS8iJG}uWpy9iF_sloNXWKAR)!ioFI8*e*kG|E%Zw=b%YD4P zuU0_JLbcHf&8{>p&#Af0^FOK9)PD`v6(EEINt?isicsLDB)*A%2q9 zc3Bulb)XB~ulLl^JdqQh1Q`-s9Y(JHgk9!k+0>w@Y8?i3QI+B!o3a;RN1&-oH<7Z@ zY6jeFu8uh{)V%$~mEZ4Kdg;8hsyVNt6hV?d)ADj_CQb|oF z4F*;-6fTX|%2$uZ?0%52dS7)=_{p$S9bE6YIvF$MtPnG06F$tPqqRJ$%l9{m)rN&~ z`=!ZR`EpEP85Q1*93m3$?)N^L@|hs z*5IBtp<$i#+{Zx6GMnL#r)wSfgZH>)HXSGbU+`99hbw$J@Oq{Eioh+RuuQd)MOSNL zH=6(~L`kLCA-AuA-7tmSOn;-j+9*PKZR~neY=A7--SHUg25?)hHrk>2SnT3}Ti`OA zU{_qES&7{YmW9iWss#Mg@mmpUFJj9A-^Ic0FW6C7a|D@2)8f zmk}leyCJ-HL3CB@wj{n-f?hCK>+$}UU^sjLyUJ!uez6`rl6gIb?k1hxROXG%=+f_M@Lfi3jN>m3Uu8QpSlhM+Y=eZQkoTJbjnCw? zN5D5o;M*0G?ez|tf~^nvIY?%KZvmFo#H+&BJ7W^>H?IL(Pg5%d>q0&*9;e=4D|}0- zXkrIz4!%A@-Tl=IXs?||z*l#w4#9C~o@_;1IkdHA^7R=q2YlONU{n$oGi$(Ciy#@$ zw5CyiZTQ+mdi5D%X>S5kvj&VykR>E&ye}v4&8-ICZpfj`O>id($?PidReyoYjxPw` z6mTD(e8=jWw6+^-lpYV?`ygFWLHp`CwwCrFT-SbO<@`wKJ|gLo8prA$TUz%8V~wp3 zKsH?|Yu(e0-XB*ZkJYeHp=VhZ{!kyO5$?q z{JK32$abCmb znp7P%=%Khhw9YgI8YRWNQs)W&%?{S+xWv0`*7K7Sq(&iX&{wF_*K%_n()8QhZ{30E z-rV2UJ-V)Q^B+$%=on}Rx~^=ssR+vdtnS$JYyl}sxTpL6PS1y5BFLv4(U5U-iJ;0) zeNrP~W}-$<-~CJ5Z!M9Wgxp4mS}CUGh46D00Rrej*UR8%FH3+N%7*qHRBtO~k=?W9 z=I)w}G&f{piJS%=Cy|m==WB~xX%@tG%E2mjqW`ok{wj9zTG81{;^!RdftYIEOX6oQ zk>l9e08gG4+Ec#xyn(8chn9rdkI8eie=vIxp((r08B}(lW zc)C%&5X-Fst<%qCeOU!*SHS}j_@S2!~98xLwIhtKLURL5(xGW((O9v9mU1h~` zS1TwwTJ;bDWsFIO%xdG7a*&3@sNSN9MVE#H}-Wd*~MQ)Vj3B(Ow^I?Dzr1iTlq`qyNlFAnmyA+Tzd_ z7L6>LckLF2_3R|?6ZVojG0|REnt8wJ=M8X9&)RoOv=4@EpbUM&I;ZT~PNVywr|EiJ zI@D;58x+Ln4JBWt`))#8(qyHbu-rY<;z{31Yc~&m1Lx|v|HgD7!j6;{{d zr|$gPAROj89t!Z5c8WxQo-a*Kq*VgYCba0m zqi9?qvD2dFCXGOwlx(6qQ46Ky)w%{W=2@!IY*kUc`c&E?3-w4y^E6x3w6H#~=_`mo zs_1T3dLAFPvWnc7@RcT{5f&L0w;h;4gNaWgs|GUH-bvn8^m3+T^eI7I2+Hx^i}mYN z;FK?_79ZwPHK-rnbUHTTq$tW^>J1v}2#|U%wn2#|hPG*kNo!Gc)5F5$KnA@qMNhpQ zK1m7jr%a&jAV;c!J6Kv#IS_$3wBrnk;G~Cq9#^DlJ?}NRFW)Nw``ijUe@r9q0JeagI+NPxGnc9G^PQ@u}l!zjI!ZSG?LZ4Yc$AZIs9< zO?h~vlo^7c(6uGX7`HH8YgsJ|It}QTFgb|!dZ8er`Dukv2ox^lCXnU?(##;u4ARUX z%?#2^Ak7TY%plDPq?tjQ8KgO(tv^Le7c`e+@oEmC4yO3i1COkCUwtSgra>kk0jx5t z?Dqb}s)rp`YlO-|?f?vr!|)UgPr>jM3{S!E0T`Zw;VBp%hv6w0o`T_V{g5H;)IQ{B z5TlulGio65NL=w=ZzR_ngp39!-zGCSfhwFY*2@02t9ywHH24jiW%M7>N$h5p6FIo z*0s!~3l->kS{V#z%1e8U7rEUXksomJ;>LCd_Bn5g?v58*Moc&a~2@vP(3?%PTyOgdpjTT(;X^H zbirG?@bgH&572KA&Os6dUMO>SAy-7a65^E+uZ(zQ#496S5%J21S4O-N;*}AvjCiF0 z)o8avI|ZbX3molVSfO-GI#QRv+&|m;GwtS%G+H0py6IhCc^6rw$F18QWk7wP-H}5; zerpXv*P+_4;ddZ78%zWA2V(+ZZbY$e9E}RxXrke6V2bNlrZ`hZ-+J|4O@Bt zW-RJUw-q@)_Fz`&_AJj+EjU_*)|WSE%5RP-{A~ZEEW9uByJ#SuK8mNsc)A`>AIH;;c)A%+pTyIr@pLPmZpYIdH)Zba Glm7#1XtnVG literal 0 HcmV?d00001 diff --git a/tests/test_midi.py b/tests/test_midi.py index 5d6aa02..70cd57a 100644 --- a/tests/test_midi.py +++ b/tests/test_midi.py @@ -1,3 +1,5 @@ +""""Tests for MidiDict.""" + import unittest import tempfile import shutil diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000..8c7ed13 --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,112 @@ +"""Tests for tokenizers.""" + +import unittest +import copy + +from importlib import resources +from pathlib import Path +from typing import Final + +from ariautils.midi import MidiDict, normalize_midi_dict +from ariautils.tokenizer import AbsTokenizer +from ariautils.utils import get_logger + + +TEST_DATA_DIRECTORY: Final[Path] = Path( + str(resources.files("tests").joinpath("assets", "data")) +) +RESULTS_DATA_DIRECTORY: Final[Path] = Path( + str(resources.files("tests").joinpath("assets", "results")) +) + + +class TestAbsTokenizer(unittest.TestCase): + def setUp(self) -> None: + self.logger = get_logger(__name__ + ".TestAbsTokenizer") + + def test_normalize_midi_dict(self) -> None: + def _test_normalize_midi_dict( + _load_path: Path, _save_path: Path + ) -> None: + tokenizer = AbsTokenizer() + midi_dict = MidiDict.from_midi(_load_path) + midi_dict_copy = copy.deepcopy(midi_dict) + + normalized_midi_dict = normalize_midi_dict( + midi_dict=midi_dict, + ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) + normalized_twice_midi_dict = normalize_midi_dict( + normalized_midi_dict, + ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) + self.assertDictEqual( + normalized_midi_dict.get_msg_dict(), + normalized_twice_midi_dict.get_msg_dict(), + ) + self.assertDictEqual( + midi_dict.get_msg_dict(), + midi_dict_copy.get_msg_dict(), + ) + normalized_midi_dict.to_midi().save(_save_path) + + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("arabesque_norm.mid") + _test_normalize_midi_dict(load_path, save_path) + load_path = TEST_DATA_DIRECTORY.joinpath("pop.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("pop_norm.mid") + _test_normalize_midi_dict(load_path, save_path) + load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("basic_norm.mid") + _test_normalize_midi_dict(load_path, save_path) + + def test_tokenize_detokenize(self) -> None: + def _test_tokenize_detokenize(_load_path: Path) -> None: + tokenizer = AbsTokenizer() + midi_dict = MidiDict.from_midi(_load_path) + + midi_dict_1 = normalize_midi_dict( + midi_dict=midi_dict, + ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) + + midi_dict_2 = normalize_midi_dict( + midi_dict=tokenizer.detokenize( + tokenizer.tokenize( + midi_dict_1, remove_preceding_silence=False + ) + ), + ignore_instruments=tokenizer.config["ignore_instruments"], + instrument_programs=tokenizer.config["instrument_programs"], + time_step_ms=tokenizer.time_step, + max_duration_ms=tokenizer.max_dur, + drum_velocity=tokenizer.config["drum_velocity"], + quantize_velocity_fn=tokenizer._quantize_velocity, + ) + + self.assertDictEqual( + midi_dict_1.get_msg_dict(), + midi_dict_2.get_msg_dict(), + ) + + load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid") + _test_tokenize_detokenize(_load_path=load_path) + load_path = TEST_DATA_DIRECTORY.joinpath("pop.mid") + _test_tokenize_detokenize(_load_path=load_path) + load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid") + _test_tokenize_detokenize(_load_path=load_path)