(gd$0NqYu-FtKl}dgR9}DldvY(>m|EJqYnk4l+&dX$Ux4n*
zVQ4FPx%Y~N%fCJ;e-p{y?w7xNM*jYX@(&l~A4lY$PRc)TmVc?pzdj}ZR+E2!NB-k|
z`OmXv;mk$(F!wF_D0f~K8P^#fGj8Ng$<5q%%Kcl+xB34Lzdz&m=ecRKz?u(>
zlkyQ`k#W7)Egu&fa-+CcZWiB>Pq_9e<5uxmxy`jZ#f*H$_0Nk%v%vKaTQ15+j77$E
z#>b2sjGK&47@sn3F>cR&UG6Y{#`wJDqFG@6a8Bf-xrQvx)#Umd&oXZC|IInx%e_zM
zy5$zvZnO5z+*9(|+_Uoe9M7`;!!ta~SY%vhe9XAPxXJkB3_LU6V%%ojIkQ(jWB&Zi
zvxoot;>9b`pH$wu^2#rwb8_uiUH-IGmp_|*N#^_Ck|>nHmt;wbA|;VFk#!2d0bbWwU51f!Q3G?V;HjG^2Jd3}(#U
zj2S`vR#`MpWXxE`9LYDMb*1^JjVoh;dBU0zYxY_*Xv|Jywi~nAm++S
z@s0{1X)
zOU;}3zt>zls?WOL?eo!Qu5D-DTD)?!p?ZV)9ViNr;Hih51LMr2&cpXXY@4dc%Ak?R
zD$r1oQ;uu|GDbSBtb;^9tJ;mUdi}JZqMTxQKI-Kb{~~;hwl=594yRHGY`_dVSpvCk
zBbyxctyV|?ya@>M)?Mo5EE$3wrRKn>CXC^Sbq^ZV(ozgPP}hti{+mS3?dWO%7w==<
zRl;eZu>-$MJU295fpc0O8MKRLyEQv;#K7zg<`HVqj5z>`EXY=ItG8^xczZ>I?|#rk6i(*k87lDMqkCyKpcE4vNcRAqIX#bXWlTu{q0t^
z8rf`cd0f56NX04=)qRcNi6EOC(1smOkN{e9QV#Bm^K{6iK%tL=vyxPS)30|}-UtI<
z;{|wbkL3lj1&kPd-wu>C$;)+kx9OM5N(CTsSpXB~m(b9{^w>;<2PAC(+=9FHBt0SR|2XZ)&Ljie-
z>#8d5sgPtVg*g}^kp+arW^|g%WVU^Q>Ts`BnC%Amioba$CD||vq#tA&>2na>=HQCl
zn+!IVX7mjtL>Y0+M8B$8fKRfunsSMQSrju3Bbl5Fj8@~QCSw;l;@Hp$GqWrq>_P~0
z#abmY_ghl(eDo|g;R`!g<7DiNj|&hAi1@1FB=I&QEgB}#*K_FaW%T!a4$lwpd=T{t
z`hFfSdlfI6A^6S_d|zNb!+etYwU!u3FQW5rLgSQrdl@t#>9@-F1M_L_pXI*J=Sl=-
zQ1cD!^E$lFGS^b`OZ@*v5|rIfD+XHM36>~;;-IWgn~c~{SLH(-z8Rgz
z{+;5qZk^W2gqjay#HmC722W{Kq%I*tu>o8zg){tU)ENLU9^!4^Wo_)}lLgh(d~^ow
z)X~ywX#9J`$dJ%yP05;eOC+3+zNz&?=<0{S>aUd~}u-uVcU8!j5O5`dVOKWf4oQ;elyOoF+s!qu=JHkJB2T_Zk0-XYS}%5g${7ir!=eyvEHJ5my2MGaF<
zOZ(N9&FK5dvVG2v@n%#?;M!>wV|aTnB7Gk+n%lq4yxEzh+iI=^8$_XXhNwl>ZAL%f
zeo5EiFch|_;>4MsoeLKJK@pU)0M4xgBjK-)ht$h}&Ln~Sy(D?OiLu_tSmzP#hgn=c
z>)dkO#yR9Sc^wjAandjz2{5S&a6>o(?S+iiKI(d5Gx}ZpE(jzP+;QId1z{j>T~Q!f
zG7upMs_U($Bv%OZAkgorz|Ox3HK+uSLqpWuu9^nFcvQWm88x<$vAW_+RWI}ZKCWA`
z6v8yXR2@HP1mHG*Kcv@*wT*q~VvNl5UF_Nk?w(_-M)kB^2yYuYARUed#n1)1x
zlE@z$6yDRK48pOm5Hjk9M$9}jJq1F<2?vo
zIueHEsFx1#~X{u(gPQ>qX;>YjI
zB<~BF1C3mNqf~9QKzaS8iJG}uWpy9iF_sloNXWKAR)!ioFI8*e*kG|E%Zw=b%YD4P
zuU0_JLbcHf&8{>pAf0^FOK9)PD`v6(EEINt?isicsLDB)*A%2q9
zc3Bulb)XB~ulLl^JdqQh1Q`-s9Y(JHgk9!k+0>w@Y8?i3QI+B!o3a;RN1&-oH<7Z@
zY6jeFu8uh{)V%$~mEZ4Kdg;8hsyVNt6hV?d)ADj_CQb|oF
z4F*;-6fTX|%2$uZ?0%52dS7)=_{p$S9bE6YIvF$MtPnG06F$tPqqRJ$%l9{m)rN&~
z`=!ZR`EpEP85Q1*93m3$?)N^L@|hs
z*5IBtp<$i#+{Zx6GMnL#r)wSfgZH>)HXSGbU+`99hbw$J@Oq{Eioh+RuuQd)MOSNL
zH=6(~L`kLCA-AuA-7tmSOn;-j+9*PKZR~neY=A7--SHUg25?)hHrk>2SnT3}Ti`OA
zU{_qES&7{YmW9iWss#Mg@mmpUFJj9A-^Ic0FW6C7a|D@2)8f
zmk}leyCJ-HL3CB@wj{n-f?hCK>+$}UU^sjLyUJ!uez6`rl6gIb?k1hxROXG%=+f_M@Lfi3jN>m3Uu8QpSlhM+Y=eZQkoTJbjnCw?
zN5D5o;M*0G?ez|tf~^nvIY?%KZvmFo#H+&BJ7W^>H?IL(Pg5%d>q0&*9;e=4D|}0-
zXkrIz4!%A@-Tl=IXs?||z*l#w4#9C~o@_;1IkdHA^7R=q2YlONU{n$oGi$(Ciy#@$
zw5CyiZTQ+mdi5D%X>S5kvj&VykR>E&ye}v4&8-ICZpfj`O>id($?PidReyoYjxPw`
z6mTD(e8=jWw6+^-lpYV?`ygFWLHp`CwwCrFT-SbO<@`wKJ|gLo8prA$TUz%8V~wp3
zKsH?|Yu(e0-XB*ZkJYeHp=VhZ{!kyO5$?q
z{JK32$abCmb
znp7P%=%Khhw9YgI8YRWNQs)W&%?{S+xWv0`*7K7Sq(&iX&{wF_*K%_n()8QhZ{30E
z-rV2UJ-V)Q^B+$%=on}Rx~^=ssR+vdtnS$JYyl}sxTpL6PS1y5BFLv4(U5U-iJ;0)
zeNrP~W}-$<-~CJ5Z!M9Wgxp4mS}CUGh46D00Rrej*UR8%FH3+N%7*qHRBtO~k=?W9
z=I)w}G&f{piJS%=Cy|m==WB~xX%@tG%E2mjqW`ok{wj9zTG81{;^!RdftYIEOX6oQ
zk>l9e08gG4+Ec#xyn(8chn9rdkI8eie=vIxp((r08B}(lW
zc)C%&5X-Fst<%qCeOU!*SHS}j_@S2!~98xLwIhtKLURL5(xGW((O9v9mU1h~`
zS1TwwTJ;bDWsFIO%xdG7a*&3@sNSN9MVE#H}-Wd*~MQ)Vj3B(Ow^I?Dzr1iTlq`qyNlFAnmyA+Tzd_
z7L6>LckLF2_3R|?6ZVojG0|REnt8wJ=M8X9&)RoOv=4@EpbUM&I;ZT~PNVywr|EiJ
zI@D;58x+Ln4JBWt`))#8(qyHbu-rY<;z{31Yc~&m1Lx|v|HgD7!j6;{{d
zr|$gPAROj89t!Z5c8WxQo-a*Kq*VgYCba0m
zqi9?qvD2dFCXGOwlx(6qQ46Ky)w%{W=2@!IY*kUc`c&E?3-w4y^E6x3w6H#~=_`mo
zs_1T3dLAFPvWnc7@RcT{5f&L0w;h;4gNaWgs|GUH-bvn8^m3+T^eI7I2+Hx^i}mYN
z;FK?_79ZwPHK-rnbUHTTq$tW^>J1v}2#|U%wn2#|hPG*kNo!Gc)5F5$KnA@qMNhpQ
zK1m7jr%a&jAV;c!J6Kv#IS_$3wBrnk;G~Cq9#^DlJ?}NRFW)Nw``ijUe@r9q0JeagI+NPxGnc9G^PQ@u}l!zjI!ZSG?LZ4Yc$AZIs9<
zO?h~vlo^7c(6uGX7`HH8YgsJ|It}QTFgb|!dZ8er`Dukv2ox^lCXnU?(##;u4ARUX
z%?#2^Ak7TY%plDPq?tjQ8KgO(tv^Le7c`e+@oEmC4yO3i1COkCUwtSgra>kk0jx5t
z?Dqb}s)rp`YlO-|?f?vr!|)UgPr>jM3{S!E0T`Zw;VBp%hv6w0o`T_V{g5H;)IQ{B
z5TlulGio65NL=w=ZzR_ngp39!-zGCSfhwFY*2@02t9ywHH24jiW%M7>N$h5p6FIo
z*0s!~3l->kS{V#z%1e8U7rEUXksomJ;>LCd_Bn5g?v58*Moc&a~2@vP(3?%PTyOgdpjTT(;X^H
zbirG?@bgH&572KA&Os6dUMO>SAy-7a65^E+uZ(zQ#496S5%J21S4O-N;*}AvjCiF0
z)o8avI|ZbX3molVSfO-GI#QRv+&|m;GwtS%G+H0py6IhCc^6rw$F18QWk7wP-H}5;
zerpXv*P+_4;ddZ78%zWA2V(+ZZbY$e9E}RxXrke6V2bNlrZ`hZ-+J|4O@Bt
zW-RJUw-q@)_Fz`&_AJj+EjU_*)|WSE%5RP-{A~ZEEW9uByJ#SuK8mNsc)A`>AIH;;c)A%+pTyIr@pLPmZpYIdH)Zba
Glm7#1XtnVG
literal 0
HcmV?d00001
diff --git a/tests/test_midi.py b/tests/test_midi.py
index 5d6aa02..70cd57a 100644
--- a/tests/test_midi.py
+++ b/tests/test_midi.py
@@ -1,3 +1,5 @@
+"""Tests for MidiDict."""
+
import unittest
import tempfile
import shutil
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
new file mode 100644
index 0000000..8c7ed13
--- /dev/null
+++ b/tests/test_tokenizer.py
@@ -0,0 +1,112 @@
+"""Tests for tokenizers."""
+
+import unittest
+import copy
+
+from importlib import resources
+from pathlib import Path
+from typing import Final
+
+from ariautils.midi import MidiDict, normalize_midi_dict
+from ariautils.tokenizer import AbsTokenizer
+from ariautils.utils import get_logger
+
+
+TEST_DATA_DIRECTORY: Final[Path] = Path(
+ str(resources.files("tests").joinpath("assets", "data"))
+)
+RESULTS_DATA_DIRECTORY: Final[Path] = Path(
+ str(resources.files("tests").joinpath("assets", "results"))
+)
+
+
+class TestAbsTokenizer(unittest.TestCase):
+ def setUp(self) -> None:
+ self.logger = get_logger(__name__ + ".TestAbsTokenizer")
+
+ def test_normalize_midi_dict(self) -> None:
+ def _test_normalize_midi_dict(
+ _load_path: Path, _save_path: Path
+ ) -> None:
+ tokenizer = AbsTokenizer()
+ midi_dict = MidiDict.from_midi(_load_path)
+ midi_dict_copy = copy.deepcopy(midi_dict)
+
+ normalized_midi_dict = normalize_midi_dict(
+ midi_dict=midi_dict,
+ ignore_instruments=tokenizer.config["ignore_instruments"],
+ instrument_programs=tokenizer.config["instrument_programs"],
+ time_step_ms=tokenizer.time_step,
+ max_duration_ms=tokenizer.max_dur,
+ drum_velocity=tokenizer.config["drum_velocity"],
+ quantize_velocity_fn=tokenizer._quantize_velocity,
+ )
+ normalized_twice_midi_dict = normalize_midi_dict(
+ normalized_midi_dict,
+ ignore_instruments=tokenizer.config["ignore_instruments"],
+ instrument_programs=tokenizer.config["instrument_programs"],
+ time_step_ms=tokenizer.time_step,
+ max_duration_ms=tokenizer.max_dur,
+ drum_velocity=tokenizer.config["drum_velocity"],
+ quantize_velocity_fn=tokenizer._quantize_velocity,
+ )
+ self.assertDictEqual(
+ normalized_midi_dict.get_msg_dict(),
+ normalized_twice_midi_dict.get_msg_dict(),
+ )
+ self.assertDictEqual(
+ midi_dict.get_msg_dict(),
+ midi_dict_copy.get_msg_dict(),
+ )
+ normalized_midi_dict.to_midi().save(_save_path)
+
+ load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid")
+ save_path = RESULTS_DATA_DIRECTORY.joinpath("arabesque_norm.mid")
+ _test_normalize_midi_dict(load_path, save_path)
+ load_path = TEST_DATA_DIRECTORY.joinpath("pop.mid")
+ save_path = RESULTS_DATA_DIRECTORY.joinpath("pop_norm.mid")
+ _test_normalize_midi_dict(load_path, save_path)
+ load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid")
+ save_path = RESULTS_DATA_DIRECTORY.joinpath("basic_norm.mid")
+ _test_normalize_midi_dict(load_path, save_path)
+
+ def test_tokenize_detokenize(self) -> None:
+ def _test_tokenize_detokenize(_load_path: Path) -> None:
+ tokenizer = AbsTokenizer()
+ midi_dict = MidiDict.from_midi(_load_path)
+
+ midi_dict_1 = normalize_midi_dict(
+ midi_dict=midi_dict,
+ ignore_instruments=tokenizer.config["ignore_instruments"],
+ instrument_programs=tokenizer.config["instrument_programs"],
+ time_step_ms=tokenizer.time_step,
+ max_duration_ms=tokenizer.max_dur,
+ drum_velocity=tokenizer.config["drum_velocity"],
+ quantize_velocity_fn=tokenizer._quantize_velocity,
+ )
+
+ midi_dict_2 = normalize_midi_dict(
+ midi_dict=tokenizer.detokenize(
+ tokenizer.tokenize(
+ midi_dict_1, remove_preceding_silence=False
+ )
+ ),
+ ignore_instruments=tokenizer.config["ignore_instruments"],
+ instrument_programs=tokenizer.config["instrument_programs"],
+ time_step_ms=tokenizer.time_step,
+ max_duration_ms=tokenizer.max_dur,
+ drum_velocity=tokenizer.config["drum_velocity"],
+ quantize_velocity_fn=tokenizer._quantize_velocity,
+ )
+
+ self.assertDictEqual(
+ midi_dict_1.get_msg_dict(),
+ midi_dict_2.get_msg_dict(),
+ )
+
+ load_path = TEST_DATA_DIRECTORY.joinpath("arabesque.mid")
+ _test_tokenize_detokenize(_load_path=load_path)
+ load_path = TEST_DATA_DIRECTORY.joinpath("pop.mid")
+ _test_tokenize_detokenize(_load_path=load_path)
+ load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid")
+ _test_tokenize_detokenize(_load_path=load_path)