From 6f1b77e571b933bc7e3f5d671a9098163de16404 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Mon, 29 Mar 2021 20:38:54 +0300 Subject: [PATCH] new experiments and minor --- dataset.py | 13 ++- experiments/ablations.py | 83 ++++++++++++++++++- experiments/metrics.py | 3 +- experiments/partial-modern.png | Bin 0 -> 24875 bytes experiments/partial_modern.py | 130 ++++++++++++++++++++++++++++++ experiments/pretrain.py | 4 +- experiments/train.py | 35 +++++--- index.html | 4 +- main.ipynb | 143 +++++++++++++++++++++++++++------ 9 files changed, 368 insertions(+), 47 deletions(-) create mode 100644 experiments/partial-modern.png create mode 100644 experiments/partial_modern.py diff --git a/dataset.py b/dataset.py index 63607e07..6d0055e2 100644 --- a/dataset.py +++ b/dataset.py @@ -1,7 +1,9 @@ from typing import Tuple, List - +import random import numpy as np +from cachier import cachier + import hebrew import utils @@ -122,11 +124,12 @@ def print_stats(self): print(self.shapes()) +@cachier() def read_corpora(base_paths): - return [(filename, list(hebrew.iterate_file(filename))) for filename in utils.iterate_files(base_paths)] + return tuple([(filename, list(hebrew.iterate_file(filename))) for filename in utils.iterate_files(base_paths)]) -def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True) -> Tuple[Data, Data]: +def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True, subtraining_rate=1) -> Tuple[Data, Data]: corpus = [(filename, Data.from_text(heb_items, maxlen)) for (filename, heb_items) in corpora] validation_data = None @@ -147,7 +150,9 @@ def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True) -> Tup validation_data = Data.concatenate(validation) validation_data.filenames = tuple(validation_filenames) - train = Data.concatenate([c for (_, c) in corpus]) + cs = [c for (_, c) in corpus] + random.shuffle(cs) + train = Data.concatenate(cs[:int(subtraining_rate * len(corpus))]) if shuffle: train.shuffle() return train, validation_data diff --git a/experiments/ablations.py b/experiments/ablations.py index 66cedde8..40e8d079 100644 --- a/experiments/ablations.py +++ b/experiments/ablations.py @@ -85,6 +85,13 @@ def epoch_params(self, data): class ModernOnly(TrainingParams): + corpus = { + 'modern': (80, [ + 'hebrew_diacritized/modern', + 'hebrew_diacritized/dictaTestCorpus' + ]) + } + def epoch_params(self, data): lrs = [30e-4, 30e-4, 30e-4, 8e-4, 1e-4] yield ('modern', len(lrs), tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: lrs[epoch])) @@ -113,6 +120,67 @@ def name(self): return f'Batch({self.batch_size})' +class Subtraining(ModernOnly): + def __init__(self, subtraining_rate): + self.subtraining_rate = {'modern': subtraining_rate} + + def initialize_weights(self, model): + model.load_weights('./checkpoints/mix') + + @property + def name(self): + return f'Subtraining({self.subtraining_rate["modern"]})' + + +class MultiMaxlen(ModernOnly): + def __init__(self, maxlens, lrs): + self.maxlens = maxlens + self.lrs = lrs + files = [ + 'hebrew_diacritized/modern', + 'hebrew_diacritized/dictaTestCorpus' + ] + self.corpus = {f'modern_{maxlen}': (maxlen, files) for maxlen in maxlens} + + def initialize_weights(self, model): + model.load_weights('./checkpoints/mix') + + def epoch_params(self, data): + for maxlen, lr in zip(self.maxlens, self.lrs): + yield (f'modern_{maxlen}', 1, tf.keras.callbacks.LearningRateScheduler(lambda epoch, _lr: lr)) + + @property + def name(self): + maxlens = ", ".join(str(x) for x in self.maxlens) + lrs = ", ".join(str(x) for x in self.lrs) + return f'MultiMaxlen({maxlens}; {lrs})' + + +class Crf(TrainingParams): + def build_model(self): + from tf2crf import CRF, ModelWithCRFLoss + from train import LETTERS_SIZE, NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE + layers = tf.keras.layers + + inp = tf.keras.Input(shape=(None,), batch_size=None) + embed = layers.Embedding(LETTERS_SIZE, self.units, mask_zero=True)(inp) + + layer = layers.Bidirectional(layers.LSTM(self.units, return_sequences=True, dropout=0.1), merge_mode='sum')(embed) + layer = layers.Bidirectional(layers.LSTM(self.units, return_sequences=True, dropout=0.1), merge_mode='sum')(layer) + layer = layers.Dense(self.units)(layer) + + layer = CRF()(layer) + + outputs = [ + layers.Dense(NIQQUD_SIZE, name='N')(layer), + layers.Dense(DAGESH_SIZE, name='D')(layer), + layers.Dense(SIN_SIZE, name='S')(layer), + ] + base_model = tf.keras.Model(inputs=inp, outputs=outputs) + model = ModelWithCRFLoss(base_model, sparse_target=True) + return model + + def calculate_metrics(model): import nakdimon for file in Path('tests/validation/expected/modern/').glob('*'): @@ -123,16 +191,25 @@ def calculate_metrics(model): yield metrics.all_metrics(actual, expected) -def train_ablation(params): +def train_ablation(params, group): def ablation(model): return metrics.metricwise_mean(calculate_metrics(model)) - model = train(params, ablation) + model = train(params, group, ablation) model.save(f'./models/ablations/{params.name}.h5') if __name__ == '__main__': - FullTraining(600) + train_ablation(Crf(), group="crf") + # import random + # for _ in range(10): + # n = random.choice([3, 4, 5]) + # lrs = [random.choice([1e-4, 5e-4, 10e-4, 20e-4, 30e-4]) for _ in range(n)] + # maxlens = [random.choice([70, 75, 80, 85, 90, 95]) for _ in range(n)] + # train_ablation(MultiMaxlen(maxlens, lrs)) + # FullTraining(600) + # from pretrain import Pretrained # for _ in range(5): + # train_ablation(Pretrained()) # # train_ablation(ModernOnly()) # # train_ablation(FullTraining(400)) # # train_ablation(Chunk(72)) diff --git a/experiments/metrics.py b/experiments/metrics.py index efdf42ee..5d2c93f7 100644 --- a/experiments/metrics.py +++ b/experiments/metrics.py @@ -131,8 +131,8 @@ def all_diffs(system1, system2): def all_metrics(actual, expected): return { - 'cha': metric_cha(actual, expected), 'dec': metric_dec(actual, expected), + 'cha': metric_cha(actual, expected), 'wor': metric_wor(actual, expected), 'voc': metric_wor(actual, expected, vocalize=True) } @@ -201,6 +201,7 @@ def format_latex(sysname, results): print('{sysname} & {cha:.2%} & {dec:.2%} & {wor:.2%} & {voc:.2%} \\\\'.format(sysname=sysname, **results) .replace('%', '')) + def all_stats(): SYSTEMS = [ # "Nakdimon", diff --git a/experiments/partial-modern.png b/experiments/partial-modern.png new file mode 100644 index 0000000000000000000000000000000000000000..2735d3af1747133b14d82251c89c0781fa75676a GIT binary patch literal 24875 zcmeFZbyQZ{+b+BS0VPEVX$9$$MnY6Tq(Qp7yBi)wL_k2gq)Qs1d zMg)JUE|DmKAf0_N;TH-nTAM%axZF3qLbX|zqp9GrV$ZsJD{6B zF`4#ER0Q`Htx*KPJX0|g=b3buwQHM58KA1`@XdA?KwH< z9ieeG4l`Pj@|vsdGRO=>0FSjO2rPyFHxC;D13Vt}whs6Q-*X5Ff+i3UA3#te*4;+% zCnp$$0zpMcx1K-{jPxN4+_MS&|F`_V-Z2eLgqfKc4HXp?EiEmud3LWNzpOu1bTmzg z4#{PJWmi;#O=kI;BY|G$0j`zYiBa(rn-g~wwkz1zkzJ4dmB%lV?LY=jByq~VbjU9yFmQPX2ttDvOi^|U~u#V^lhqbX0;bx_Mdmc+4fkZjkUPfVX z?gKTUHxdH$P(IKs%slVGuN{yux_l`~ca#$g*n|I9)vcWSkH8nLc`HGyI(r3OW z-!7fKdh82rKM{jQmcsRB9_8j>@fjd!Xs&N;9PIDEY3w$k9LbEbxcK|J11OkfP)@?a zOVyLosS~kMjf0yzDK8Fb&FCx+IIg=*s3H|5$@09t{4mpizVRv=dU-s&zCQeC^a#+E zEdStqUNun-Ooi0aU~06qsVxY;D-tLO{Pruop6AC@R8(zkZGExqt5ND!**Q6SCA2W( zzPJpFjPGVFwb=?;4k~i0jyA7M?QDG?`%v1KCW>`?^!@HYfjsmP7{!o_Py*wt<=3k` zyB@;YM$I@a_h2x{DH1WJy_JA@IFVteDDJHNtldaa#L`p${-u=idK)20n5r9BnQ(pw zmMU~n5`(6LiO+o$Ic{GQt_m3&84cshNJ<(g)4@n4}Gv@y3VxHqNqv!ikP#DIsgd(zW)VeuwIyVYbuOf-;8p~xlwZGhxk#E#0 zlps+nwpv`7U09UJhxioPMna-TKSY4bup2jJdG}{amP}G6|JCKm9F0^Qhi>!xYTmV9 z;cE11-}L>owETu9qVGkgD)Y;s2l=!iSj;?P5F<=bRvMppT!#+Le{54IA}AO76c`vd zXt%#4W`@J0*F=?>%>ye_%hxJwu&Mv9Ix~7(C*Ex_i2GGA$>dlS3oe>S2$?`IxgfKt zvy&6ob~2TnE2+}lL~ncjg1_22as1;s)Oe@PH;);*!1v_-SZFd4&uPe+{XZTZD2rjopsX zVhTGn3|G^p`HdTw`Bh_4j^18^A>#xQsx68^BQU&@!Y_MGu+G!nS!KT(Ow4Vy+#Q)G za$6_p)j==&00uGF=9GD10*~!NlMkw|xX=}5xv%L=TtlYzp{sg`be9Des(wc*Q zKKJ8IG49E7%lJWQWx?M|;n7-d+17LQl&?OWf}4C*_3*hl5hw7~)YPJw)SV9cxQz`A zUa$HOhWB=L`3tyD*-?=?tx@W5`5Kveypd4ABXze$glW8f^OhbD(|9_G-kg}xD_mr2 z1wD`^(C^(neV@_$G}41SF`kmNw`}C(hQ{*M_4V~jN=h0|7Xqz*RQvk*m8jn_b^A4# z;&pi@_paDUSy@?JTU<)5Nm;%uXEM6blC@9bz?dt_u}bs$O5 z>%!%@C@nR$+Gz{^w({rop69;+@pnu3cDvK z1YTAO^72Akk&%%MN_n?2G5h-Zjt&maH?*(Ct@oy`?H0s&dl#G=r|X6v&mjcTTyr?e zeyFZR0rAG>br>3Mc-Yy)SZeQ0kscKfa+xiT@eAr^1c*xB`Jp9Ib8&G2)2Ypeytb=WCk%6X4|l-6bQFB>^&L_) zOYgZppS!*~D)j#5^is%-2(zST!9BT@jNfIKd(MNBk?uJ(9kk{?jb7v|gWJMx)cb(L zL>Ab%Rn;=S+Z7%z?rUq06o$%iX$2aPP)~H5Z3M6$*fj(cExuafw9GNSypcp)R^Q=0=M1%QNyBA z1p5&%9MK#pZXzuF-bQ(Od0A7q;rKDf7fS}1;N_vbQ|v;xl}L-W6DJLgUqO3E$H;P5 zI8C8;Lt@!|BI6P7bWy{~U9!#}&U1CTEhs;8AcS5Y5V*{a`E=^`oEyLWHoTr#@!rT8 znL9O%6=jK^z*6-WnRtW_iw66U50eEO@ zTN?=P{juz#Uo!1i`|EzM#78r0(q9B(A@3R|c~UdJsz~+jW?o%c8U8*op;GNJAaL?g zKQamjwdj)7M}8h(lTM7}h8FF7b&jV$CP0KEOfT&p!lN3Tei7&UZ+{ zWJc{|QO+0s;9ZL#XFLDmS|k<)$i@a={%2}l1wrX)^vgheruFZ+1nBb>b6)ZHI37QK zPqwwt76`(Hbz&w+p+X-paKaqHlgBlkfB4_0c_S|$&8RxuFRc=NuW@K-s99P?@ai

zr8A%45yYoMvF@>fRclD7W zaW~X}bjuTh+B`t;wdiCgGoGLS;f8(YgsXn>@yO35c_S25}nQ7&u6b!G$ z)nBJvY!ft9j~UfRRQaAG!nRZ;{|f~?g0Xanu%9x_^~z-?O%y&$)Op7J^TN~zk3-n1 z5v2QlgTLOsKWEkHJyb8r=<{Q`UniQKZu1E@vsLzry=g!fgS~`oI>S*Pk2nQ^;^*p2b>@(b3%)!;UeiM6_Kceq z!R>hlz25l+E_NQ^H|+yMlv1dHSFDf6JrLsFv5Q~6QOt;R9IiouI*k54l)}|=KdAl< zWT{+CIgz0t3Fs9aW_Q zf$>vBM>Fkr8xxW*F*4iXXM-Yh(YfeQ>+2&z?A$4?(7qJ?om&swISI1v@k>Inw!XGD z5fdC~(Fv?uMV^}J?mYT+So(yUq3NYDO*oTkPUTzfg2($}%&gk51TG#CxnOWkOU-)p z6X=Epg<5qeSR^V+O6_WQ8-sT!ezXwgI;m{%OxU5I8b?S^>hY@_wlLSfh{D46WLDk? z{n|OAAhtsQk%Yk^%l|;C7uuWrZ1PQ%%=LX=9J{f!k`lhTah4#j*TpGF70zd(&U2p5 zCMLA;j?0WzX1Qbb)6XI2%wzTmZjmG4P!UZd73Y4&=Dh^B`$kVQqrM(u?YW`Aj>Y6t zGTLOd4fEIZL_>`K%GH%cMS~B~=2KHs^VKWb+uMzdjlK8U2!2#MqGMwlSKLE99ZV6r zJUcqys+d!$$<1XD*#tq(1X81Qsf!&{by8Ozf%0PynY)BO_RSENEC2y3&7iCL!Spx zSZ)UVQ%S`LL|p*Pbff8n{jkZc_9#ut)TFK{SXqY#i5lm;uLa^GzkHEYmY0`jscc#f zvogC+JZf)VGAKY&BIQ?>)YOrR){*a+g`st-sbW#@i=^mAO_#P`T87DE`2;2up7~Dy zmd*3$P>?O%^?S(5%xtRE^ggCyVZ+tgGtV=JomtPOcgUMcxQ)040&a&NKYqNsol{%; z#p?A5Wod7cAmCDYd3nC*1hZ}%`5cUy?rBs5U!LllV9s}TzAjf$RgGPeE_m-CA3u?q zsE&s?LBCEs-7}Ar;<~9bRCtAJ=v8D=E~I2<>v|tB92=cs1smA^b6!*lLv-1msRA(K zU^qj<*FL?Rym7H3B#Bl3Lr4gAzWbXu_n5@4JHg(3`}WO(H9=Zb`SRjIZD%n6eo%eu zRYwnD`5BiOSM4l;q>OIMLoNqZa#e;7EhLT^htdd*x_}J08>732)xK`FXd{5_$qc{HDmhM2bit z{?YpPK;+$N>jtSiMECB!3~}09XuCW=iN5tj^Re8Mxr6?>%N(YJ)_`$U_c zk>@`$X1=R?jmX2oG5}yPa*2* zRO4h}qS!q7D?pbk3+xEr)*&{f@yT9p(Dsu>nfo1jzVodutct_x`1KouYITd_VTk;TQnN{$uE2c>dpN)r2ly8H5u_=|tdUv0VPsPEdyq&K~SWb(wP82F4 z56p(fbZq6PD-xa2ZWT!-4CqKK&JRlVDU6pyG+|pF-+8Spv4WDBdpV+LsCYE&ehcCQ z@f^-Y%YJP9jue8X)%XZEqjRdg!}oNCLadI^d&0;V<4~G(C*$JW;5&YoF%wb6r%h_> z_R$5}BiI&ybnX*Xv`M;)jn zj6O{{CY15so2w>A7iR0^6B4TZ%cV6}sDq z^_DGWcvKIi!ILM7l<}+vW?@^iR@&69AR@vpO^g94ll$ebMq?pt(oS{j5N+1ivV^g@ z9Qql1-5DSl2YJ+^y2s+Dv#G(G))B?)I3zt>og3VpD9!nc5ew z2T%b;u=W588p5W{RQG%lW8YlGnH-AXu78F1R1#_+*Y@Yg^iTZ^x#p#{g%2Uf+dyt$ z+7EIPj%Z?yV&i!zbvbB6-aT)cl!)1PaiBAlNm`8C6<`Mp2vWXNHG1Im5Jsw?PPiGV zU^jpQtl8L;>_KcA-c*`hkZ5+V!Wt_U3M}{Sr*&6#kmYkT6Age~LKLZq{uSaB*zM;9U#+*ObmF9h*bm$RMvC6hQ6{Q?eObf4=nfieJMhumzU5L6q1bR!0RkbYM@XIUVvoEF(5k}hyy@*&r*SuIH2t%h7KXa&`p{P` ztgeQ$T~uD#IgE-_PW^~bVrNAn`rXFj7qQ|196j#SAzHLD7qAX{B8M`<6T@^j$~U0| zf^ioYtYRRlNcXbyR8^+dfAbIWY`b43%h!kZ^2{Qq8=cP>0b*Iw8o<1*Gl|0%vvu*M zU^Mrr?2j@cYTu^6Y+tqN+)vM#+upAVJv1j7wZBz#A&XK%n}K0vd`n|`Fhb$b_^g<& z{y9N}oEdMuC?R}tzZa4dDPe|UUz+GWx3CHw#31Vf$pgtZ{j-EFvi5m1@ZwKva2fy}EI?OAKT3J$RSr^*mo!Z|e}= zH!jL~cZpRQe2K-QnVp5OZ7wMCqx=jb@Tc*~`x+Vz8?u&LBmJsg7OQ!Bn_rM2yXF@H zqEJu#=cdhb9jC{Pwync?tKJ_U7~rXnU#Qzs`2^>j(25DDhbQrit~1^T!#`#_0L?N@-DL#wg4u&~k_3)i(26cxXz0ywueUO2b6*m3{< z{Uon5Q*pgIE$^TCER8SxG*eer@GTWRi&JL=r9>VP#S}z+X?_u7GQVL*Z8(e~gWn>r zIp9U`xlj7ZpIyfKc=IfL`P}cNZAgfS7@0qJK0gwo^_k48uC7KvLfM?HJAX&4@zDos z<7;}_MkeXSa{ajs2Uv=Pd2GA##j9*)%9!H;u%vN{DG^IV`xm=jNU6_90 zH^@fj&T!+J0*_#pt9{&aB+ootq@{xg?fl7K=--;b@0O0N=c<==<|=Cyki!X6HM^=h zjCE;i-f5nY1p;^BHOW?z)74$I=6+Q}MflcQU?NzcgWL(EjT)fzu(Po}-U4EWt)ru% z^2Wx-xBFv;$$ctXBcC&JFuK#9mt7z&04+sk6U~T~-zaFhMs$HT|334~)7)|dc*xiS zyJ>j05kPlds@#T#hGIo2Da`mX!yf!N$Awn^*49>l(-(|iE5Ex}s-kQUNM5AI%Wq~^ zNDJD{2gc=z3%mV}+vd{LuMCI-s>|C4W)?G7Rn80?A}u7NV_OhvC!ariVmy^)tco#LH)}t#Wk`vwb`O~S}!#_Eij^Y;0S_M zBZlK!eE2Im-Z=bVA}A3Ux>{nKBiM$qNrK15)E4KHnVxuUw&|Vf63i7|7SqO0_N%hO zcRA1B$1NE85NK{?mBJyVsBkMiSKwYGLMoE_@Yj>Yxo{d@_o4@04)JsK^E1CRm@pFp zyL0FoBp5b+v1o?P(m{=Q^YECLfiI zZmU^i7~Lc#tQs+-=c_l7GNok#`{N>-LwzcoW*7{S#2h8`uOJ$0aN2DA-xrFhwA0p{htBnk{{?Yo0;>GFBioEe5ji)zP z7bsW?$%htD51%6GExu*CkJkPDOik>p*ni50UBXmMk~-h&dt;FQp;2PZ<3Os=3?08s>*|#z&>;e&|!csHn`A%I!&e z5ry0S-7pSE3$ICdG%r}irWG))TWuE%qbx3ai4m>v5UiNuCUX?wbdTz3O2w>=ioGDr z;2^qcrV-*vKCB~mud2>^wPG?5pjVkqxT=H*_58U$yShD-UrJH-!;}k((1_o#y1?*N z@lbr*E>e<<_jZC;KW@^7Ge9_hu1@+zRh*S1E?bP2 z@kESSwf1@@KewNJc;9=@Hz8PA7UMiULKdTADLVTQek<0WTfdg4&`Bt+C0X%nIS7|W z=}l0|{l37)Ns}a$)V;U~r(q`Y817U_AwP{;pEN4^+-z8fqUv}|&R3pzGm^E$9?te% zK<)-Sh}!|Q$bWVKWNe_R!XI+?&qJg^B2ZD-PS98gw;-H-AAe3j=v97oh5n9QQbh;1NB1;3agq?aWn?w~LOQH9T zKG5eO1eXi`xg72?>8V_7Kn_b!eq`)QkPdX=}p4Ge#7 z$V|zuD$3(+t&A@XyBd~_(L1eyZ2~MWST`(e9CDNu*1v3q(*lv)nE8t+GvYWq2?nGf z8b{gT&(V531b)9*TrJRcH5J^P$p-x?2#peOa=fB>fur~7%2+w?NE^koF<3SDd z=Z-)=xLB7JGnmdTJg^(RCuf5wUg2!&tbHXQ4&{uYjt}~bbn`|ogTuow7u5n(3JsPg z$9|-nO4eQ}3=6MJs*iqQEd?p&By zS;;AmCiZ-WUF{EFEl2UQSFIBbF zRmVRLd_^Gn6G3WCEzFq5zP~rTbbW&w@bjaEl%>#p;`0C+vutY%apr=L{S0Zobzu_x zVYtTf5)^hEsyX73_2rrM{vXDNuMwc*b~w^SN?Iq1-{(+#(f%VH54U|yyiQGZp2Xv)ea-Z5A4FI)18_hzwOLRKFET5m21vL=p5oV-(8QgaQ{d0Ni3W7S zo^QlypD_i)Ok=yq-Gew17Jnv9U(N~Kt6|}ZPC$6~YQUkXG`g&rL6tB@_d=znC{ImU zb8hf`T?Mpxla^4f6i2zzX+NC{?(B*KnWrqBjzQ~(_jXF(vxVbnrKH&3s*T-gd#EoD zCR`taWXK_WEXA6b+{IV*n5Tq8RZc4}H-)cbRvjt_$W|0;QB29m$f&JOM5&3WB-ljO zw7#HXJzt7#KC<6A1fT{@kkV*mD|euGmauYN<$~qs?@eQA5Rf0mX3Vy;e2N5V3g{J| z?sT)ISI^{*WAxnw|}^ItLq2Alnfw%gS@e{LZ0Y4scTTMB_VATwISiQN(f? zA=TMOk|6K6Zg&6n;((IuoIy!U;>CcLQ-|8-Wyx-U)u_E3nyxe!$EyT?OLca5KfcD7 z$5l>FqvKcCpK)i96B{CnP>qvKXD8cEt4_gjw?=S4x!Ixl+Nh|g^|SWjZeF7DH}6nl zQ92?zB9h-4`RJ?7D_9JHtObaY+bXZK?E`r(!2`<)nk;!xn#qhG zrc~7BpYf|NM*z#Ow76M5%+SInkG3$S>6j z_QC3W%&St|jIbtFq2YW>zLP;hhl$ zh2+Y-Ou&fp3v;b4m~(To&?CC1a@^cdWfjL&rUwl!vhUv`dj9?*8(tQ9?YvmUx)!bY-FHrZTD8-nef>19*?nYf2}r%}KK8*PcyA7e zV`4J+kOeVoK-IjZgfk_U2<-~SaRkJiioG5qeQEax*YrU`>r3~viQZ&o{|b8f+K7%x z2Ai&`l0Qv2DJvUj8vqv;%QVnsnWlbE!Qr~h=%Cb8glYi(m}Z87YDm4^vUv4ig|#+* zBcN`SqQMc2Mv7pkJ8xsM=B|J8S)3Ub@2)!VSrpbsd-LjzRwe_3g5?*Acf~gC!O($I zmjkGiI~SgK7zFO*4It6=YtIr?UC3@$hx-G$+;cJdNJ3>6;DCtB=B8J_bO#OHP@fM; zAdtiq9Lva`XMUB;ddo)55x72&!6{(ZRTXYd6zA~$h{LUYkPyyWXPgjb8oaOb@)&(+ zVW7wA%aLj74++ZNT4K~ zHkaOv&XeloITP}f8Iizt3i>&-} z>@aK#8;gsgisJqSI*Neq9FQ!5+3Bhh5fu#=MT(@k1avc{;byTi2RWG5N~CEKoe$rR zX>Z~nN<30+1faqBsrw-9B!9?soXOAZf_h8OCt7n8xca<@nIWdMGWv##UdFL;?}&QlkvVer15|T^Nt16yOKh>M0ZDufb($(8@=Gu7lneX`bu2W*L?PzLs|=X z?w$C`pkNNB*_9nWT+@7K9kBN)Xmb=7$g79s_M?$hcxB6+?be|g3>bri0Lqu6>r6?8 z@b7mvzD25W+~Cq{{YWO@CYiunmXjl1{e{zZP`8na$o7q_X|I#B8N4CSER#Polma@JE@O!kD)A7CfNCa0g= z!WV69t_q(NZdnG}(sc!m$I;&VA6Kf|cf`+FaE~+DX^H1(U9utui9dp9kSQ(`qM@v* zJw5oIoI@xJut?xDgxcD3pePBX88>%#iB(l{V2lrkMrU4T17N43_2W5CLC#mE;qOc4 zIiUT}U!=XoQ=|QfT1H{nX%i1}Lr-i(25RcnrEBsM^u>Y*kTSvY-*u(jc9+?Yi7efz z=!DoME=NYSGEZ(D;y1apqc6@HySaUyqF-dz?9{L1|6_kwm**Sn>#a;jM@OTyR^Pa= zcrlXGrpRB(FT`P|F>NC=J#)5uw>FB4J1dB?YIQ2Gmf;yXQt?|N_@XUyZ)3*T6Mq-r z@%VxC3*0zSG0E%K6Kio>!y?#16B+0=DlJrh!b6SIR?eI3RwkmL4@7NDW9={e5LA4A zKdDY{J~f{Q3?6N>x}BP+=z~!-k@twjnN2)O8HdR;V_vrCyIm2>Kti%cDjxsx+GYEE zx1lPrPRgMZM>npt>0pn=XO{QDLAL1nfUawM+Q-rr*6EBul zvof`rdjtf|KOD#X)$@XUu#k!bh=k-(lt;N_9}q+l3eE)gj}Iru?uT7E#9SqEdPb0j z=Hu<4<0mK|yIbr<*T?Q$5}k~0526b0u@30RZO<)TE^$BFA@lzMAdiF!8DA%XZ(Wgu zvXyB@j$?E|99KPyUoIvHJNO9j#vI8+{#cY6_~Vz9;Kl6*uS~mP$rCswe_Z54}ywOV0uP77(guSId>AM?qgw7uK6)+TOuK~Xone$W)}%0k7){6 zED#32m6t0`nW=+J0ITKT;8Y&te8Y>sVOFdVxhdLpWxC1py*_VyVs+#>hQ4CfSr{Gq#m!;#yHv?Hk!5MvuiU8B3J z_7L49qgv1|2>R#Vpq#9KH(zWIDy-e_#^jzUvR{>6vm0v#FQY+4OO(4#WIB;UTi7M1 z+)l7=a?4_G_hI)w5z)%o#&lC1r0muV|0x5{{fnj-km#z&VnB+`pPF{FdaN zmGp@&pr4^fNtwRK+%}on6@DnpFzcGF@{ueVVo4f(oh*^mL(NK?ls{<)Dp0eaVd=6h z$=x8PwC8LODO4ZG_!X|lKdWRC^?~9yfJQPuMy_^rdhge~1U_Jy*ufE`y2#b`-MNW7 zr3bqW;6pw#%@~#R7A?)-kQ;yl8tt6Dyq;-3U^95_;J~5T+26lC*U;eblpO&90W`K{ z46$2ogwx2NZL!quz9a}mCHIhZpzHyD0n7ceR5{s~DfCfclq*7sC}cM>D_fkIIi#1< z^PT0?_nsdPW7LI@7hiA#qx^U=P&=$l7NmZvIe|M~QlZYP(3zhrEm+ib4uXW1;N4|A z>LU{cgdfJ|?v6jlyWk`V1cwk(J}!R#FU*p$Y@nOjml-bP4ko`CUhb`|s>)X_SvlQX zTtVPBA7eCR$X8;h4mAXJ3~8Vpvm0tU)`@FNv`k7yKC2%*l~@LLR=8#?*KQqsKXThY zpathQpXabLxs1zCzmMKlC*&1l%{$4j;;`m9Hx6KJ=ak%6p&o}CFJ7idtE+Yd8(Mlj zW4Pd*M7c5NdF*GW^WiejQqV<1$hLrL5AY$p*)?c*dCG{a@nEIzeqRF*@LGtOczQZ_ z`0#si@k}MQQ(Q3xuQstBzyAAfm)oJjzV3Q^KRYj2JlMt(B~>mKt0wu7b|&Xd6FVrS z{Hm@*ryHc(c_ya*cn~WQ8t9$99L<&=Cfs7fo(J`U!3S==F|2_EEKwvTQxx#W@sjf4 zI=Kefa&PJ-*d`UWQfLYXQ+JlqlqvLv{%N_rYG_;XOyof#A-eGT~Xk>&L1P-;W83KKqp!Kl8Nn0&f!u>!DFA_xH|GE5hU>ldcU z$M5scnf7c9yzFeB@@Xg|VT~IC=>h;^{_Ze!>pI2iFR|U--75E|#`v5zlhJ?w{ykc$ z*A_51I0!~E1O*!!K=*FT2h5R;qhVrgP&-0`P#^zxEOsbg^Vx_yqXrc!AUjK`;w0wp zi79SlcB@{y4=-G<8hMiG>~Y8fC7#n;d6!RAuVHThqhEm_QgK9}NLx_iy6HRI7;2z; zi-^Mfx!J@MFqbeoItnaKAGjdc%*y2`^{*Mp${kP*T*+%k82a0wfS9 zuzr;Lnm^vmT~MtEf{0Y&FjD{rC7ZWhjC$+=^+t<=c&dR0y}ynEDj88ZY`}$q<}ayq zMH02K5VgO;GK5=w5_h@O#h_Z%w|Q3+R2b|ea!o+-lbX7Lv0gS>K_f9UFF7Xb*h9+k zZFx12-I}!D$mI@?4-Z!5tstEQvzzfqn(isS;%=<-8Q?uv0N8H#H&A6gb}*K)M3v zoC?V1IoYf8Z z^HS*70H1{4?ggK4fM;1LkCmBw$BCK*UM4}j~=(mk{u zpP#OPF*+sm&8&nTgBFF{;I;bJ$y<8+elBzG?#}yTEGK#04@64Sx4Lvqk>nr=Um;8nVRHEd$}+SBYy$nku2L zD1)?&D}o#p*-{)t^JP~?`FCd2VCQ$rE|+bCkjIO(W85OkNtsxZ%I6t5Gxu;i($KxJ zdd|lL%|SFelNSSkw&L*lXo-9CswW=5XZ*EK(y7V?X^D&3YAc>)f?*ZBItK_U$lrb# zgvB5r0JxZVFO=CBE}uY+RK+eInY-9PtMBIE7?1ZRqhNxIHKYUp@G89%-E!sJhldA8HRR_M zNX~P=OXwH`!g0yk#lR%kOQBVm1E&ZZm4kwn z>PuZ$Fk|$pQ4dha8z2es1cTZ;{?;-lrrL-rF=$_LjmvQ`6alIZn=%j`bk0cPDPr5l zPO(%l;A4JcFYjoMPbl2bw&dKL!m9T&Z7euo!I=n~J3E5T+cTh%&lAkmga{m~=;$On z++kQuZh;_!Ei^$qf^s3;8t4MHU(!N&=VUuIQ7@|zU0Jw2WAY;Mq~s@d2Hs(#BM< zsM@&3LWAf`%% zT7AK9Bje(DLCgbFPyJt>%2$7dOE-TV6a@^lnsk^KRIPj-F%U}RFbxX~MAzi8Te@wyTX!0>yTx{ybRIC|yDwt|b21{~*Ll zrkfCDc@rrAY~gL44$vX-$4x*98G?TQaRl$nqj8!<7ro~9ci>IDaJV1Itn2x&1LL$(?6s{HjWnm%^9L9tBvg^eV!PcNG z>4enwfoJEEGbWb*{#npAyx~Xg^|_n~1GJ7`w))@C`1gqfLfhYSmA$;`OTX)M_VlEw z4lDh)XJ`X4V{dP7Sjh*7y>E@g)bsuSi4yn-{D7=vNJNXPV}2Odau52?|4nl5-f)^| z2r@dpu(0r8J3R{vHafaHXr<3)<>67OF8y!k-38`+tZMh#3n;i}Y+68bK49L2|Jv;s z1!ZMVmy|t{Q%*SIJN2$}QaEalMw+BM?D{)cgS^8=ZD$FIR`7{wttSz>tuTxUFV*JkJhi zfz54AIgWmnV`;e9e~W|~5D<_aSS~@*I5|1#z)4O0D&`!Vwcz!|L9-%)q5VHx2zmg9 zjK2>K-o1ObNVnyKD7ok1lzCys6EZ%hJcpEWeV#A(bGJzb>%^s+> zu(~?F{1ONZ?^2GQlW<@;+1aJO;A$&6j57MKu#3~E$MM?H-d?@RZW)a9%INazHs4yi z_pdOu0a;mDVPRuC+uFvm>Z6>i7CGGcZ(bq)>y@x@Wp#Cb9A~a}gBNZ!7)5ot`n`Ix zH60xtP4k!D_b21xG%`B)VGS6!xHN0Q%-K?y^$oxx49|SDeFq1$2sO52*Cdn&<86TM09jGB4cc9ED`-AJ2yAiL{UYB z_T&Z^B2LtU{0T5s)aTFhTU!F4Ul#tvw4v7oZ09*D+c$|381C1coK$GG-m}I5y%9!U7z`eF=v$b+c2Mt{F34SxEuZ7 zZ%_Z5*ztecOGkS`J7DMuq-3DCoy@vh!^zsQ?>|XmnG~*7ZUfT-6L}!1MX_kl0%73d z!ei|v#XlJkcoHg^i_Y3`M)P+-*7DWL#5uGB{>Dcd?FbrwN`D|f-e$T!@xJDETvr&P zc;)Emh?kFshW51u4ISMVAA&wp<}4`;^z`&lacaLZAZY}Ar}PL=hWAl2&wd+hPJBxh z`Sc|rp$3AW5ZVtn7m+unq@=v;0Z?bVauEys=mHLCgXrn+&vcedb(Sk{oXxgOmL^vN`w&LvZrTt2xCdDmbD(c(xofc^}^vI@(II zVd?mo(&UY6duvw0r!|kAn~Mpo5*2L8XWEV3@~QWosC%EklSX_%R%YhjCdouHEN`&U|s7a$u~S66^zWYA9n?1{$C zrThvsy9RWA75#-!dpuu&v`@fw|WE=sG!U z!-Ul6_>9TAvAb#2wY9aN$X*M|Vg$^FO{1sE)OQ6;}E8cLnp4+^iz%4?Q0Qrtg;kr{s9n4J8Oi-CzL$OzzGiP{VIojOr#GIMkDwxDCfM0EnsrCg=_ROFnToL2Nf z6P=&Uz6)Mu!GYCYY7=EHU|P8S-QSkF~Q% z6pWG5S$zjzzPScTc4xM3`r!cCw@L~W!sCy5k6?WAl9IiEv&$s%6FzuQ3BVv25|O{> z?&30-_xbZ(w*@j@hu5_ifDdMe&b|yF-s<76U_R>vcz*#1KZz11ABA=V9|x5i)KeJv z`A2ufahkiGYzqM9@l(PLEN5lXI4Og^%Sk)T^zQ|k#7wuP0|S67LFN(qAu1X;J{a!8 zV|N>s(nr<++wlSupu}q0sUQU6vova@if6o;G0|SBavTzsF6`2(bMs#%k}T{3+bd3F z+2tY9OUmc;DJsRTrGR%O;B%Sz`0(Pzeh;ML2 zur7d4hhXk+;si42_3zz0dQ3#d%I$1ybX4O|h#pnnkfxU=#$9Yu^61Jhuht-Ig}%cHbeR2gel+ z!~sj1v_25L*cI~jz6MPyyicD_?p6G@%zv$=HQRzMc+p0nJ&3ryy$yGQH*A0??J*|= zmrzWXTc&jgf^zBJBb@h1-XO3{sp7aE-M*XyOjGvZ5am;+%?Y3{N^O*ZcCnlA(;v{8 zE6M`{>pnF4++Kr(g7IrCe;N#9CAuv5gVFIlL^(M*kU@4pA^H015_~BqDKe5O(Oes} zm~HA&b8|=c)%~KJ1K%H-_@=+6dCgQY^Kwy|k1U}eACHcKA&rcX(`*!InGzxPAY8x) zqFr69mR7KeDL@PjFYxg%f@!#=C*XxDhYYG$!y_>FtP5pbE*X!_!~Eza*$BHe_>OA{ zG$X(P(Er!ZyX__&Z%vP9NDzQo8Gu!7&r3~506ltWwI;lY+YHxx`ioKkeg@(qTvIuG z4Il+`f~6-QtIeIElo|P4;7cyOSeQ-W3k!N`h2T>!Yo8?SeW%}(TDE|1K+JS%=>a5s zpSwIm2%We$QACs3>EW9V-4Kl_w>wT-Qz8|Sa#iU~&WfgpH9=a(|_T3;4$KArS`0NyWUPB3YjOWkVg743I=W%B;?& zNNx=_qm&rq2)L;E3K~xj8lpb=9xL6|jFsWljh-2Q(xah_pKZA$|t2m*)R zJ4!Fon}SAqOK1X$bQBbjE&(<4s?t?@AP|a(QX&MTg|0M_Duzf^Q9(MPNWClPe&6$4 ze(gUSLiSo~&%879Ha1?uI2l45t*yt$#zcJaIbdy)wLn=7#F0!Y#j09d#dl!8NIlun zI{iBgPQ!MvYIWTnE7ybnsd~k3qN$R*gOQOjB_(C-hL97&dSl@}TD}~W>9n*ovYG)@Od~l3c*)_)*LMn^X@Qn$p+PnZ zjInDc%^=3j@W`D!&ER{crpef-P8$<2gQnDc0X(o!pfmyxp-Xl*2MUF{AeYKO$nZ94 zY2tA)3Y}e8cnXsDG33t5RM2!|AXxh#h*HO$JWyJgAHF@&<}+0bum(9X5?XLM+2Tef znVFmdo|MX`_8L&&w5tc6ckjO3mjK(W8%)rfG3N-Wr(mtEta^f$Zq>Q4N4A$D z04wpIZrlX1NfEVY833IIn>wpr{a~a}Kr13&oqFaJoF*NU-#{-R2)Y(Px^}qp=OU=5 z6X(0kJ{DXJM~;=6AWlj9;N2!e=SoZO6n9QomN)A7B{zsr| zm89TV2fosc_S0kkkdOw&Y^aK;sA#!qE;$?_TOs7-<$ZT#Lleh_$fJ%#7;X*JlumlJ zPDUp^CZ79tskgMW9D;_!rB2IrOYS7D?B(0YPvp4d+d4ul7sma&a|AjrI5r2)VZ$Rx zviq%g-9P}1$58K+lNoMBhoB9*Wa`A1Uf0yD18WeNenTP?WbJVv@q|^_$OGL|!qS8w zuw=2E8_zkrMxKt_xeBEL-aP9!)z#)XgOh6j>=s9wIiS(Y05d;U%s#%nAam@pEPc?9 z%Oh*Wh9gfzH45cbwz}MM>Vj__8=pRz?H!420yGe~h?W+$(|nKiH%T1D+-mp($}?x) zKe*GjW;2;3?*y|@MDkc%+0+jVe2c~IOAjpb<3$?$=RP?TNmsJNShm0lNa0mUP`?bN zKq|jGZ7VMb_WyOqFVH%G4ss;9K-~_iLh>)OFL<{nQ@r%xguZLV<0~e{J_G;73+GoZ zVVl(TmR45ui!=^se8=w!4gc9INJfAThiz_4HN|VfYXp!nC#(LECDS__8%5<6`w(pJ z6(0iD)+IbM@jKwEd3kvRSqQYs3xdR~&EipO(LxL}an2p8Z;Fd!oSkgKg&zH24N2t- z%w)xW=>RW*F!25WojX>aOKdstZbiMOCWZ>S$UoX@#ynvF-{<<(1}t=eP|I?)akb&= zs}E2ANDOh-C{(p=@Rc0psCjeYw=Vh+v{6a5DMCjD4`!v45nd((%*_@^J#B}yZ;yTT zUUQQqVo$z^198R1AR5OOsXn@my7(z$rpL&h6Xa-tSB6i{8Oh5uin@AUm5I zZ-=4Ls}S|!y;mMzkul=y6%8U0OL~;r78_+w7RJ0ZwP>yAqPq?B6o|<{RUi~$OZt>7 zL;&p&;?wytxs*}0(OKmWcu2CvW@fo^(;AsJ_SVf8h&M=v(QAdpeXh%P4FzDu4gNzkOz{g}BtwTP+gVpxWQYWA8wB8eWMWWXWr2!aLa z)te<7Y;Lw*XCg?WMx^B8LyfR~#VRH!9Q+J%o12=JO7^>)9@zFzH=qs9!iDuUV1F+7 zcpxLJiv5YWr|N?EloD#GV8RxODH%OHu|9E#yRxGIJ$H8)E!P%=Y=Z&}3cM|EP^hnD zrH;M!(X6pHXKd}`=6#^n+TPnM=ce~%>_5~N zhU*N!hs5y_UP+^ZR)ZI6TF2_WZ3DIPYA(;HKx{)^ICKY50T!6Gm&%QdvEod zk}0^H`{T~uo6xOCJC%(pUgwpWTagKkz9!5vmUI?k6SP#N*#c9`MnCyBT;NI2r|UVZ zxuWzh6?Y6$%_+GfA|evUh1?l)SR#=SA;G0y0iqw(j4qa^bD4t`DR~dG!zQ#1#*@wE z=;-JWxz4Vt5-hfFVqz%y&K|Evo~qhK0v}QGD;gJ<89a_s6J$90a$(;8r^v zLamr!{=`$!OoMxL-rp~L1cz<9Q0&Gv2&4AWv3~X>7W1lPP_>ae%cESs?BWSR1IXe| zE{o$8ZC9))_klQw{b$?p?$?Q}NRaC54j%p8Uh2dSlJ;91_Z^E_%6#GL%Y@Hfr*pGS z@$U8dfnn@-g$)vNNo=nSprrG5EH-CD-fG7d$s(UfwM3Ix3~sU-upVP; z9w$uaipNd8M+MLwfI8u`?RQQuSlJXiNnpv|Pu%5^q^@3_7fz(hi3P{AE^|Rafy8-N zZg)5j4{I*T*51d1QdL@-!M$9zgfkZ*_T#Sl?_93ul3!z>L3$bzm<5>|?loOgIPDMj z2b6NtkN>on>i7L#y~Li*K0Xy*mKhZ=bKU1!*b`?d=CTK@70CJjaB+#&-vFQHCAv`tJP_>KlJL z=tnP%a#2%L>-SQpO6J(*Sn>plR;--+B(6_JfLbSmueAeAj?mr|P+Y-n*`G`Jy0yK% z=#)+E=i=aSLGZVU;q6H8gJ9zkoLmNN$iWqs#iyB|i`q}Ll6U7timZKO(yacYC%1rG z>Sa6S!_9B;_+=4(%AZWTucsMTr?_yiq)J^qr<>;f{xxqBK`BfTGnRX~g!j+^pwF|) z>KaURJUZA5Fh=8o*Shx`J6YHi(?Y8G&82Tob8%nOf>e2Qu3wu6(ezSJsOfrk*F|%O@U9AUV+mf?{4c2icpyW)Ok*29Nm`(Baw73Q;@-Z#i~Ar--jyLV2YZ{FGtuCTmVVSa7l)g149rgW8i#9=oRbsauDqNO9Ku-W2Rp1G;iYLAVMF7!KCTffN8Hd-Le z&ZaVsJKwzN$thJJ>_6b*&#Ys3L3zp8QxQt`X}XbwaSTt)|YU9j?ZE!NF=5 zvzy>IqiOAod_KTEFdb5#Y(T6_lP1Fb*KD_T5PG`oQXBG13E}&FT8_Uyl*e>8*-z;8 zZ5v&>hZN|ttH8N64&vSBTh-;|k}$ajcf0jSwt$?$q*J)KQNxL5RL{^zi!Q?rh=X(& zVOf~(9|-de2)rumZ3qTi#IjiyeoXXx{f`kpn*$}p8e)Besf02sQL>2=t>&F&%hMDn zCg5=?y!zPkiK5MOMGMZ_3I7Cx4la&J;JPXy3Hb2-Wo8R-ad8O@+yqA$3Wcg_1(~E? zGf2Y-g_b@_LTRAI>926liXYdYr-1}Jo%o@2Facn!u73zw3Dt-gDoU=SgB|+X(ygQe zTZCR*B|DL!Xt!();lL&Dn1f&GmYBzJV#{}Nw0DMo|HQ}`if2U8u-GpH2m9810eu;rX6V*t<|S4v8AfyU;Pr{e%w7nTI*C3i-vb>Abzykz>{ig zm46TUr)<*D5qb_kcnLGtZAF}(Kw-y)n#Lhso3#7E`QA5dYyj~1nti}q)#omPao5UJ z9bEfLar3gA*y8St`IMp_L+D7g^|J^K&ZWXNUh_kBC+-};^pYnl*ZkUIV*S0 zsY94|1hF)>|Fb+WWuUQ#sknHQF$h5907Wm?(9lDW+XKVJ_q@#4^SzF6u5)>4~egPGgYG*u9?T(`^`frVZWNqMcz(-+sZdG*9+O!7Y&feM#?cz_C zTRuDU@u(3vUQxKE=$7>Bg~`bqV)q~DxHJX`C+^jV_%}voy5}k~iUJx)0 zL82#5syc&JKh@F3f*$20kosEwOyR%idj2d+*;kUig?1Q6gY)yG(rgSnZ+{5h?Nbh>N&n~p-8 zOXHFvJ0R%iRZI*&+JEblMh32-G)Y?U4o^-&L#4{r5gB%~yRm3sYm2LPol1`~JF4Fs z$h?qCn{nRu#s5DJr;$eCM?7vND!cQmF`u4C0WJ%O)MG}K`_quRr&e=@E%Qv5Kz8-T z1&{ax4P_b0+mm`)5zZXT(;8)pcV@9@*#%~|NGuLkF&vo zqJUIpRR8n>5EF__*S~z3F!5Dw#J!g*j*0=^nFJIk1OeeK;Oq8_{C3yZSFZljH42Y6 zm6fu^uuYwa zikpU;ps%ODx2A`T8P*AJyCXaVQq{wd%8*z9-}*La&wHY<&9!7J&{>_`g zG9?QQpsPa|zKx6UGHDQDwM(rVt@~B((iS$=6eL+!_~J#kFFzk&*Lu`y1#f;-1@TLz zUH5fIej0r@0(!3E92L8jFOIb=f@XzP@in34fr6bHvB;`5_|Igz`>7U{&$sm`lLwW= zO+SC@QiIx+{{K5zfeJ>7qZ{&6`c8L)@7E*x#FZi{&;`MfSUEC2ZVjq5lassPQ<6)R znVHwxLxO|<9&=3T(EWR#-NK(QHJ?9!HZU-7>rE|C^|HEt{d$+7jDevP1W@RBT_s9O z${SMP>NeGXf9*f-X|?}*Zz5|H8&;S6*F`DBUE|4Dh7_<4a1+KdnIa>KlYGSkIilyi Xp50+M?@k3uJQOAd=K9sTt}*`sKAo2$ literal 0 HcmV?d00001 diff --git a/experiments/partial_modern.py b/experiments/partial_modern.py new file mode 100644 index 00000000..676fea89 --- /dev/null +++ b/experiments/partial_modern.py @@ -0,0 +1,130 @@ +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + + +results = [ + # r , cha , dec , wor , voc + [0.1, 0.898022, 0.941791, 0.730089, 0.790349], + [0.1, 0.907983, 0.947288, 0.746686, 0.805641], + [0.1, 0.909339, 0.948013, 0.756906, 0.808338], + [0.1, 0.910473, 0.948737, 0.757871, 0.81078], + [0.1, 0.910393, 0.94853, 0.761774, 0.8147], + [0.2, 0.921098, 0.955005, 0.782228, 0.833589], + [0.2, 0.921384, 0.955211, 0.783685, 0.832101], + [0.2, 0.921859, 0.95526, 0.788762, 0.836588], + [0.2, 0.928058, 0.958799, 0.796022, 0.844013], + [0.2, 0.927519, 0.958676, 0.797943, 0.846828], + [0.3, 0.929646, 0.959993, 0.804771, 0.851261], + [0.3, 0.929712, 0.959984, 0.80654, 0.853802], + [0.3, 0.931107, 0.960884, 0.807965, 0.854223], + [0.3, 0.932764, 0.961357, 0.808793, 0.855598], + [0.3, 0.932484, 0.961512, 0.812267, 0.856918], + [0.4, 0.936546, 0.96386, 0.822286, 0.866843], + [0.4, 0.935857, 0.963207, 0.824132, 0.865466], + [0.4, 0.937018, 0.964095, 0.824572, 0.867156], + [0.4, 0.93854, 0.965087, 0.82832, 0.867609], + [0.4, 0.94117, 0.966382, 0.831514, 0.877483], + [0.5, 0.940108, 0.965908, 0.830636, 0.874788], + [0.5, 0.942086, 0.96687, 0.834623, 0.876376], + [0.5, 0.9431, 0.967516, 0.835859, 0.87881], + [0.5, 0.942016, 0.967019, 0.836506, 0.881143], + [0.5, 0.942875, 0.96751, 0.837655, 0.881893], + [0.6, 0.942889, 0.967467, 0.834149, 0.879263], + [0.6, 0.94241, 0.967134, 0.835282, 0.879975], + [0.6, 0.942999, 0.967409, 0.837874, 0.879422], + [0.6, 0.944271, 0.96841, 0.84033, 0.88526], + [0.6, 0.944287, 0.968416, 0.840918, 0.88462], + [0.7, 0.945244, 0.968931, 0.840901, 0.885507], + [0.7, 0.944691, 0.968439, 0.84286, 0.886661], + [0.7, 0.945565, 0.968863, 0.844009, 0.88787,], + [0.7, 0.94545, 0.969033, 0.84419, 0.888327], + [0.7, 0.947342, 0.970026, 0.846682, 0.890258], + [0.8, 0.94755, 0.970202, 0.847723, 0.891049], + [0.8, 0.948123, 0.970388, 0.849016, 0.891005], + [0.8, 0.947261, 0.970096, 0.849468, 0.891794], + [0.8, 0.948495, 0.970793, 0.850103, 0.892667], + [0.8, 0.947547, 0.970207, 0.850174, 0.892045], + [0.9, 0.95029, 0.971653, 0.853899, 0.898095], + [0.9, 0.949553, 0.971291, 0.855798, 0.89668,], + [0.9, 0.950043, 0.971714, 0.855912, 0.897163], + [0.9, 0.950754, 0.972088, 0.857229, 0.89916,], + [0.9, 0.95033, 0.971725, 0.857454, 0.900025], + [ 1, 0.95109, 0.972268, 0.85568, 0.898633], + [ 1, 0.950829, 0.972174, 0.856407, 0.89867,], + [ 1, 0.95127, 0.972382, 0.858379, 0.90111,], + [ 1, 0.952334, 0.972927, 0.860708, 0.903797], + [ 1, 0.951377, 0.972402, 0.861084, 0.904752], +] + +results = np.array([ + # r , cha , dec , wor , voc + [0.1, 0.898022, 0.941791, 0.730089, 0.790349], + [0.1, 0.907983, 0.947288, 0.746686, 0.805641], + [0.1, 0.909339, 0.948013, 0.756906, 0.808338], + [0.1, 0.910473, 0.948737, 0.757871, 0.810780], + [0.1, 0.910393, 0.94853, 0.761774, 0.814700], + [0.2, 0.921098, 0.955005, 0.782228, 0.833589], + [0.2, 0.921384, 0.955211, 0.783685, 0.832101], + [0.2, 0.921859, 0.95526, 0.788762, 0.836588], + [0.2, 0.928058, 0.958799, 0.796022, 0.844013], + [0.2, 0.927519, 0.958676, 0.797943, 0.846828], + [0.3, 0.929646, 0.959993, 0.804771, 0.851261], + [0.3, 0.929712, 0.959984, 0.806540, 0.853802], + [0.3, 0.931107, 0.960884, 0.807965, 0.854223], + [0.3, 0.932764, 0.961357, 0.808793, 0.855598], + [0.3, 0.932484, 0.961512, 0.812267, 0.856918], + [0.4, 0.936546, 0.96386, 0.822286, 0.866843], + [0.4, 0.935857, 0.963207, 0.824132, 0.865466], + [0.4, 0.937018, 0.964095, 0.824572, 0.867156], + [0.4, 0.93854, 0.965087, 0.828320, 0.867609], + [0.4, 0.94117, 0.966382, 0.831514, 0.877483], + [0.5, 0.940108, 0.965908, 0.830636, 0.874788], + [0.5, 0.942086, 0.96687, 0.834623, 0.876376], + [0.5, 0.9431, 0.967516, 0.835859, 0.878810], + [0.5, 0.942016, 0.967019, 0.836506, 0.881143], + [0.5, 0.942875, 0.96751, 0.837655, 0.881893], + [0.6, 0.942889, 0.967467, 0.834149, 0.879263], + [0.6, 0.94241, 0.967134, 0.835282, 0.879975], + [0.6, 0.942999, 0.967409, 0.837874, 0.879422], + [0.6, 0.944271, 0.96841, 0.84033, 0.885260], + [0.6, 0.944287, 0.968416, 0.840918, 0.884620], + [0.7, 0.945244, 0.968931, 0.840901, 0.885507], + [0.7, 0.944691, 0.968439, 0.84286, 0.886661], + [0.7, 0.945565, 0.968863, 0.844009, 0.887870], + [0.7, 0.94545, 0.969033, 0.84419, 0.888327], + [0.7, 0.947342, 0.970026, 0.846682, 0.890258], + [0.8, 0.94755, 0.970202, 0.847723, 0.891049], + [0.8, 0.948123, 0.970388, 0.849016, 0.891005], + [0.8, 0.947261, 0.970096, 0.849468, 0.891794], + [0.8, 0.948495, 0.970793, 0.850103, 0.892667], + [0.8, 0.947547, 0.970207, 0.850174, 0.892045], + [0.9, 0.95029, 0.971653, 0.853899, 0.898095], + [0.9, 0.949553, 0.971291, 0.855798, 0.896680], + [0.9, 0.950043, 0.971714, 0.855912, 0.897163], + [0.9, 0.950754, 0.972088, 0.857229, 0.899160], + [0.9, 0.95033, 0.971725, 0.857454, 0.900025], + [ 1, 0.95109, 0.972268, 0.85568, 0.898633], + [ 1, 0.950829, 0.972174, 0.856407, 0.898670], + [ 1, 0.95127, 0.972382, 0.858379, 0.901110], + [ 1, 0.952334, 0.972927, 0.860708, 0.903797], + [ 1, 0.951377, 0.972402, 0.861084, 0.904752], +]) +# rs = pd.DataFrame([[results[i, 0], results[i, 3]] for i in range(0, 50, 5)]) +#print(rs) +print(results[:, 3]) +x = np.round(results[:, 0] * (413 - 40)) +ax = sns.lineplot(x=x, y=100 * (1-results[:, 3]), label="Nakdimon", marker='o') +sns.lineplot(x=x, y=([100-91.56] * 50), label="Nakdan") + + +ax.set(xlabel="Number of modern documents in Nakdimon's training set", + ylabel='WOR error rate') +ax.set(ylim=(0, 26)) + +ax.xaxis.label.set_fontsize(12) +for l in ax.get_xticklabels(): + l.set_fontsize(12) + +plt.show() diff --git a/experiments/pretrain.py b/experiments/pretrain.py index 2571b75d..de395b70 100644 --- a/experiments/pretrain.py +++ b/experiments/pretrain.py @@ -146,8 +146,8 @@ def train_ablation(params): if mode == 'pretrain': pretrain() - elif mode == 'train_ablation': - train_ablation(PretrainedModernOnly()) + # elif mode == 'train_ablation': + # train_ablation(PretrainedModernOnly()) else: import ablations tf.config.set_visible_devices([], 'GPU') diff --git a/experiments/train.py b/experiments/train.py index ccc35c03..954e21af 100644 --- a/experiments/train.py +++ b/experiments/train.py @@ -26,24 +26,25 @@ class NakdimonParams: def name(self): return type(self).__name__ - maxlen = 80 batch_size = 64 units = 400 corpus = { - 'mix': [ + 'mix': (80, tuple([ 'hebrew_diacritized/poetry', 'hebrew_diacritized/rabanit', 'hebrew_diacritized/pre_modern' - ], - 'modern': [ + ])), + 'modern': (80, tuple([ 'hebrew_diacritized/modern', 'hebrew_diacritized/dictaTestCorpus' - ] + ])) } validation_rate = 0 + subtraining_rate = {'mix': 1, 'modern': 1} + def loss(self, y_true, y_pred): return masked_metric(tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True), y_true) @@ -74,6 +75,9 @@ def build_model(self): ] return tf.keras.Model(inputs=inp, outputs=outputs) + def initialize_weights(self, model): + return + class TrainingParams(NakdimonParams): validation_rate = 0.1 @@ -90,14 +94,17 @@ def get_xy(d): def load_data(params: NakdimonParams): data = {} - for stage_name, stage_dataset_filenames in params.corpus.items(): + for stage_name, (maxlen, stage_dataset_filenames) in params.corpus.items(): np.random.seed(2) - data[stage_name] = dataset.load_data(dataset.read_corpora(stage_dataset_filenames), - validation_rate=params.validation_rate, maxlen=params.maxlen) + data[stage_name] = dataset.load_data(tuple(dataset.read_corpora(tuple(stage_dataset_filenames))), + validation_rate=params.validation_rate, + maxlen=maxlen + # ,subtraining_rate=params.subtraining_rate[stage_name] + ) return data -def train(params: NakdimonParams, ablation=None): +def train(params: NakdimonParams, group, ablation=None): data = load_data(params) @@ -108,20 +115,24 @@ def train(params: NakdimonParams, ablation=None): config = { 'batch_size': params.batch_size, - 'maxlen': params.maxlen, 'units': params.units, 'model': model, + # 'rate_modern': params.subtraining_rate['modern'] } run = wandb.init(project="dotter", - group="ablations_final", + group=group, name=params.name, tags=[], config=config) + + params.initialize_weights(model) + with run: last_epoch = 0 for (stage, n_epochs, scheduler) in params.epoch_params(data): (train, validation) = data[stage] + if validation: with open(f'validation_files_{stage}.txt', 'w') as f: for p in validation.filenames: @@ -152,5 +163,5 @@ class Full(NakdimonParams): if __name__ == '__main__': - model = train(Full()) + model = train(Full(), 'Full') model.save(f'./final_model/final.h5') diff --git a/index.html b/index.html index 7422215c..53d3a94f 100644 --- a/index.html +++ b/index.html @@ -14,12 +14,12 @@

נקדן

-

+

- + diff --git a/main.ipynb b/main.ipynb index 10913e05..d353d2fb 100644 --- a/main.ipynb +++ b/main.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -36,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -61,7 +61,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -106,7 +106,7 @@ " embed = layers.Embedding(LETTERS_SIZE, units, mask_zero=True)(inp)\n", " \n", " layer = layers.Bidirectional(layers.LSTM(units, return_sequences=True, dropout=0.1), merge_mode='sum')(embed)\n", - " # layer = layers.Bidirectional(layers.LSTM(units, return_sequences=True, dropout=0.1), merge_mode='sum')(layer)\n", + " layer = layers.Bidirectional(layers.LSTM(units, return_sequences=True, dropout=0.1), merge_mode='sum')(layer)\n", " layer = layers.Dense(units)(layer)\n", "\n", " outputs = [\n", @@ -168,13 +168,122 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: WANDB_MODE=dryrun\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "wandb: Offline run mode, not syncing to the cloud.\n", + "wandb: W&B is disabled in this directory. Run `wandb on` to enable cloud syncing.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1636/1636 [==============================] - 278s 148ms/step - loss: 1.1109 - N_loss: 0.6972 - D_loss: 0.2073 - S_loss: 0.2064 - N_accuracy: 0.7610 - D_accuracy: 0.9198 - S_accuracy: 0.9426 - val_loss: 0.5110 - val_N_loss: 0.2298 - val_D_loss: 0.1655 - val_S_loss: 0.1157 - val_N_accuracy: 0.9263 - val_D_accuracy: 0.9420 - val_S_accuracy: 0.9722\n", + "letters: 85.48%, decisions: 91.57%, words: 64.75%\n", + "Epoch 2/6\n", + "255/255 [==============================] - 38s 148ms/step - loss: 0.2697 - N_loss: 0.1695 - D_loss: 0.0662 - S_loss: 0.0340 - N_accuracy: 0.9425 - D_accuracy: 0.9753 - S_accuracy: 0.9914 - val_loss: 0.2176 - val_N_loss: 0.1308 - val_D_loss: 0.0517 - val_S_loss: 0.0351 - val_N_accuracy: 0.9566 - val_D_accuracy: 0.9818 - val_S_accuracy: 0.9914\n", + "Epoch 3/6\n", + "255/255 [==============================] - 38s 148ms/step - loss: 0.1543 - N_loss: 0.0970 - D_loss: 0.0408 - S_loss: 0.0165 - N_accuracy: 0.9670 - D_accuracy: 0.9850 - S_accuracy: 0.9957 - val_loss: 0.2053 - val_N_loss: 0.1199 - val_D_loss: 0.0473 - val_S_loss: 0.0381 - val_N_accuracy: 0.9623 - val_D_accuracy: 0.9838 - val_S_accuracy: 0.9915\n", + "Epoch 4/6\n", + "255/255 [==============================] - 38s 148ms/step - loss: 0.1159 - N_loss: 0.0728 - D_loss: 0.0325 - S_loss: 0.0106 - N_accuracy: 0.9751 - D_accuracy: 0.9881 - S_accuracy: 0.9969 - val_loss: 0.2012 - val_N_loss: 0.1177 - val_D_loss: 0.0459 - val_S_loss: 0.0376 - val_N_accuracy: 0.9631 - val_D_accuracy: 0.9836 - val_S_accuracy: 0.9915\n", + "Epoch 5/6\n", + "255/255 [==============================] - 38s 148ms/step - loss: 0.0675 - N_loss: 0.0420 - D_loss: 0.0212 - S_loss: 0.0044 - N_accuracy: 0.9856 - D_accuracy: 0.9924 - S_accuracy: 0.9987 - val_loss: 0.2025 - val_N_loss: 0.1141 - val_D_loss: 0.0470 - val_S_loss: 0.0414 - val_N_accuracy: 0.9682 - val_D_accuracy: 0.9857 - val_S_accuracy: 0.9935\n", + "Epoch 6/6\n", + "255/255 [==============================] - 38s 148ms/step - loss: 0.0460 - N_loss: 0.0288 - D_loss: 0.0156 - S_loss: 0.0016 - N_accuracy: 0.9903 - D_accuracy: 0.9944 - S_accuracy: 0.9995 - val_loss: 0.2055 - val_N_loss: 0.1148 - val_D_loss: 0.0478 - val_S_loss: 0.0429 - val_N_accuracy: 0.9686 - val_D_accuracy: 0.9858 - val_S_accuracy: 0.9935\n", + "letters: 95.92%, decisions: 97.68%, words: 88.22%\n" + ] + }, + { + "data": { + "text/html": [ + "
Waiting for W&B process to finish, PID 21716
Program ended successfully." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find user logs for this run at: wandb\\offline-run-20210121_210719-2z1mn0qk\\logs\\debug.log" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find internal logs for this run at: wandb\\offline-run-20210121_210719-2z1mn0qk\\logs\\debug-internal.log" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Run summary:


\n", + "
loss0.04595
N_loss0.02882
D_loss0.01556
S_loss0.00158
N_accuracy0.99031
D_accuracy0.99442
S_accuracy0.99954
_step69
_runtime479
_timestamp1611256519
epoch5
val_loss0.20546
val_N_loss0.1148
val_D_loss0.0478
val_S_loss0.04286
val_N_accuracy0.96859
val_D_accuracy0.98577
val_S_accuracy0.99345
best_val_loss0.20123
best_epoch3
index0
letters0.95923
decisions0.9768
words0.88224
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "

Run history:


\n", + "
loss█▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
N_loss█▆▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
D_loss█▅▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
S_loss█▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
N_accuracy▁▃▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████████████████
D_accuracy▁▆▇▇▇▇▇▇▇▇██████████████████████████████
S_accuracy▁▇▇▇▇███████████████████████████████████
_step▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
_timestamp▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
epoch▁▂▄▅▇█
val_loss█▁▁▁▁▁
val_N_loss█▂▁▁▁▁
val_D_loss█▁▁▁▁▁
val_S_loss█▁▁▁▂▂
val_N_accuracy▁▆▇▇██
val_D_accuracy▁▇████
val_S_accuracy▁▇▇▇██
index
letters
decisions
words

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "wandb: You can sync this run to the cloud by running:\n", + "wandb: wandb sync wandb\\offline-run-20210121_210719-2z1mn0qk\n" + ] + } + ], "source": [ - "%env WANDB_MODE run\n", + "%env WANDB_MODE dryrun\n", "\n", "def experiment(n):\n", " BATCH_SIZE = 64\n", @@ -230,22 +339,10 @@ " run.log({'index': 0, 'letters': letters, 'decisions': decisions, 'words': words})\n", " return model\n", "\n", - "for n in range(5):\n", + "for n in range(1):\n", " model = experiment(n) # 20-30-20-5-1: 88.08-88.16" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "scrolled": true - }, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -491,4 +588,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +}