Skip to content

Commit

Permalink
new experiments and minor
Browse files Browse the repository at this point in the history
  • Loading branch information
elazarg committed Mar 29, 2021
1 parent 7e8d642 commit 6f1b77e
Show file tree
Hide file tree
Showing 9 changed files with 368 additions and 47 deletions.
13 changes: 9 additions & 4 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Tuple, List

import random
import numpy as np

from cachier import cachier

import hebrew
import utils

Expand Down Expand Up @@ -122,11 +124,12 @@ def print_stats(self):
print(self.shapes())


@cachier()
def read_corpora(base_paths):
return [(filename, list(hebrew.iterate_file(filename))) for filename in utils.iterate_files(base_paths)]
return tuple([(filename, list(hebrew.iterate_file(filename))) for filename in utils.iterate_files(base_paths)])


def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True) -> Tuple[Data, Data]:
def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True, subtraining_rate=1) -> Tuple[Data, Data]:
corpus = [(filename, Data.from_text(heb_items, maxlen)) for (filename, heb_items) in corpora]

validation_data = None
Expand All @@ -147,7 +150,9 @@ def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True) -> Tup
validation_data = Data.concatenate(validation)
validation_data.filenames = tuple(validation_filenames)

train = Data.concatenate([c for (_, c) in corpus])
cs = [c for (_, c) in corpus]
random.shuffle(cs)
train = Data.concatenate(cs[:int(subtraining_rate * len(corpus))])
if shuffle:
train.shuffle()
return train, validation_data
Expand Down
83 changes: 80 additions & 3 deletions experiments/ablations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def epoch_params(self, data):


class ModernOnly(TrainingParams):
corpus = {
'modern': (80, [
'hebrew_diacritized/modern',
'hebrew_diacritized/dictaTestCorpus'
])
}

def epoch_params(self, data):
lrs = [30e-4, 30e-4, 30e-4, 8e-4, 1e-4]
yield ('modern', len(lrs), tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: lrs[epoch]))
Expand Down Expand Up @@ -113,6 +120,67 @@ def name(self):
return f'Batch({self.batch_size})'


class Subtraining(ModernOnly):
def __init__(self, subtraining_rate):
self.subtraining_rate = {'modern': subtraining_rate}

def initialize_weights(self, model):
model.load_weights('./checkpoints/mix')

@property
def name(self):
return f'Subtraining({self.subtraining_rate["modern"]})'


class MultiMaxlen(ModernOnly):
def __init__(self, maxlens, lrs):
self.maxlens = maxlens
self.lrs = lrs
files = [
'hebrew_diacritized/modern',
'hebrew_diacritized/dictaTestCorpus'
]
self.corpus = {f'modern_{maxlen}': (maxlen, files) for maxlen in maxlens}

def initialize_weights(self, model):
model.load_weights('./checkpoints/mix')

def epoch_params(self, data):
for maxlen, lr in zip(self.maxlens, self.lrs):
yield (f'modern_{maxlen}', 1, tf.keras.callbacks.LearningRateScheduler(lambda epoch, _lr: lr))

@property
def name(self):
maxlens = ", ".join(str(x) for x in self.maxlens)
lrs = ", ".join(str(x) for x in self.lrs)
return f'MultiMaxlen({maxlens}; {lrs})'


class Crf(TrainingParams):
def build_model(self):
from tf2crf import CRF, ModelWithCRFLoss
from train import LETTERS_SIZE, NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE
layers = tf.keras.layers

inp = tf.keras.Input(shape=(None,), batch_size=None)
embed = layers.Embedding(LETTERS_SIZE, self.units, mask_zero=True)(inp)

layer = layers.Bidirectional(layers.LSTM(self.units, return_sequences=True, dropout=0.1), merge_mode='sum')(embed)
layer = layers.Bidirectional(layers.LSTM(self.units, return_sequences=True, dropout=0.1), merge_mode='sum')(layer)
layer = layers.Dense(self.units)(layer)

layer = CRF()(layer)

outputs = [
layers.Dense(NIQQUD_SIZE, name='N')(layer),
layers.Dense(DAGESH_SIZE, name='D')(layer),
layers.Dense(SIN_SIZE, name='S')(layer),
]
base_model = tf.keras.Model(inputs=inp, outputs=outputs)
model = ModelWithCRFLoss(base_model, sparse_target=True)
return model


def calculate_metrics(model):
import nakdimon
for file in Path('tests/validation/expected/modern/').glob('*'):
Expand All @@ -123,16 +191,25 @@ def calculate_metrics(model):
yield metrics.all_metrics(actual, expected)


def train_ablation(params):
def train_ablation(params, group):
def ablation(model):
return metrics.metricwise_mean(calculate_metrics(model))
model = train(params, ablation)
model = train(params, group, ablation)
model.save(f'./models/ablations/{params.name}.h5')


if __name__ == '__main__':
FullTraining(600)
train_ablation(Crf(), group="crf")
# import random
# for _ in range(10):
# n = random.choice([3, 4, 5])
# lrs = [random.choice([1e-4, 5e-4, 10e-4, 20e-4, 30e-4]) for _ in range(n)]
# maxlens = [random.choice([70, 75, 80, 85, 90, 95]) for _ in range(n)]
# train_ablation(MultiMaxlen(maxlens, lrs))
# FullTraining(600)
# from pretrain import Pretrained
# for _ in range(5):
# train_ablation(Pretrained())
# # train_ablation(ModernOnly())
# # train_ablation(FullTraining(400))
# # train_ablation(Chunk(72))
Expand Down
3 changes: 2 additions & 1 deletion experiments/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def all_diffs(system1, system2):

def all_metrics(actual, expected):
return {
'cha': metric_cha(actual, expected),
'dec': metric_dec(actual, expected),
'cha': metric_cha(actual, expected),
'wor': metric_wor(actual, expected),
'voc': metric_wor(actual, expected, vocalize=True)
}
Expand Down Expand Up @@ -201,6 +201,7 @@ def format_latex(sysname, results):
print('{sysname} & {cha:.2%} & {dec:.2%} & {wor:.2%} & {voc:.2%} \\\\'.format(sysname=sysname, **results)
.replace('%', ''))


def all_stats():
SYSTEMS = [
# "Nakdimon",
Expand Down
Binary file added experiments/partial-modern.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
130 changes: 130 additions & 0 deletions experiments/partial_modern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


results = [
# r , cha , dec , wor , voc
[0.1, 0.898022, 0.941791, 0.730089, 0.790349],
[0.1, 0.907983, 0.947288, 0.746686, 0.805641],
[0.1, 0.909339, 0.948013, 0.756906, 0.808338],
[0.1, 0.910473, 0.948737, 0.757871, 0.81078],
[0.1, 0.910393, 0.94853, 0.761774, 0.8147],
[0.2, 0.921098, 0.955005, 0.782228, 0.833589],
[0.2, 0.921384, 0.955211, 0.783685, 0.832101],
[0.2, 0.921859, 0.95526, 0.788762, 0.836588],
[0.2, 0.928058, 0.958799, 0.796022, 0.844013],
[0.2, 0.927519, 0.958676, 0.797943, 0.846828],
[0.3, 0.929646, 0.959993, 0.804771, 0.851261],
[0.3, 0.929712, 0.959984, 0.80654, 0.853802],
[0.3, 0.931107, 0.960884, 0.807965, 0.854223],
[0.3, 0.932764, 0.961357, 0.808793, 0.855598],
[0.3, 0.932484, 0.961512, 0.812267, 0.856918],
[0.4, 0.936546, 0.96386, 0.822286, 0.866843],
[0.4, 0.935857, 0.963207, 0.824132, 0.865466],
[0.4, 0.937018, 0.964095, 0.824572, 0.867156],
[0.4, 0.93854, 0.965087, 0.82832, 0.867609],
[0.4, 0.94117, 0.966382, 0.831514, 0.877483],
[0.5, 0.940108, 0.965908, 0.830636, 0.874788],
[0.5, 0.942086, 0.96687, 0.834623, 0.876376],
[0.5, 0.9431, 0.967516, 0.835859, 0.87881],
[0.5, 0.942016, 0.967019, 0.836506, 0.881143],
[0.5, 0.942875, 0.96751, 0.837655, 0.881893],
[0.6, 0.942889, 0.967467, 0.834149, 0.879263],
[0.6, 0.94241, 0.967134, 0.835282, 0.879975],
[0.6, 0.942999, 0.967409, 0.837874, 0.879422],
[0.6, 0.944271, 0.96841, 0.84033, 0.88526],
[0.6, 0.944287, 0.968416, 0.840918, 0.88462],
[0.7, 0.945244, 0.968931, 0.840901, 0.885507],
[0.7, 0.944691, 0.968439, 0.84286, 0.886661],
[0.7, 0.945565, 0.968863, 0.844009, 0.88787,],
[0.7, 0.94545, 0.969033, 0.84419, 0.888327],
[0.7, 0.947342, 0.970026, 0.846682, 0.890258],
[0.8, 0.94755, 0.970202, 0.847723, 0.891049],
[0.8, 0.948123, 0.970388, 0.849016, 0.891005],
[0.8, 0.947261, 0.970096, 0.849468, 0.891794],
[0.8, 0.948495, 0.970793, 0.850103, 0.892667],
[0.8, 0.947547, 0.970207, 0.850174, 0.892045],
[0.9, 0.95029, 0.971653, 0.853899, 0.898095],
[0.9, 0.949553, 0.971291, 0.855798, 0.89668,],
[0.9, 0.950043, 0.971714, 0.855912, 0.897163],
[0.9, 0.950754, 0.972088, 0.857229, 0.89916,],
[0.9, 0.95033, 0.971725, 0.857454, 0.900025],
[ 1, 0.95109, 0.972268, 0.85568, 0.898633],
[ 1, 0.950829, 0.972174, 0.856407, 0.89867,],
[ 1, 0.95127, 0.972382, 0.858379, 0.90111,],
[ 1, 0.952334, 0.972927, 0.860708, 0.903797],
[ 1, 0.951377, 0.972402, 0.861084, 0.904752],
]

results = np.array([
# r , cha , dec , wor , voc
[0.1, 0.898022, 0.941791, 0.730089, 0.790349],
[0.1, 0.907983, 0.947288, 0.746686, 0.805641],
[0.1, 0.909339, 0.948013, 0.756906, 0.808338],
[0.1, 0.910473, 0.948737, 0.757871, 0.810780],
[0.1, 0.910393, 0.94853, 0.761774, 0.814700],
[0.2, 0.921098, 0.955005, 0.782228, 0.833589],
[0.2, 0.921384, 0.955211, 0.783685, 0.832101],
[0.2, 0.921859, 0.95526, 0.788762, 0.836588],
[0.2, 0.928058, 0.958799, 0.796022, 0.844013],
[0.2, 0.927519, 0.958676, 0.797943, 0.846828],
[0.3, 0.929646, 0.959993, 0.804771, 0.851261],
[0.3, 0.929712, 0.959984, 0.806540, 0.853802],
[0.3, 0.931107, 0.960884, 0.807965, 0.854223],
[0.3, 0.932764, 0.961357, 0.808793, 0.855598],
[0.3, 0.932484, 0.961512, 0.812267, 0.856918],
[0.4, 0.936546, 0.96386, 0.822286, 0.866843],
[0.4, 0.935857, 0.963207, 0.824132, 0.865466],
[0.4, 0.937018, 0.964095, 0.824572, 0.867156],
[0.4, 0.93854, 0.965087, 0.828320, 0.867609],
[0.4, 0.94117, 0.966382, 0.831514, 0.877483],
[0.5, 0.940108, 0.965908, 0.830636, 0.874788],
[0.5, 0.942086, 0.96687, 0.834623, 0.876376],
[0.5, 0.9431, 0.967516, 0.835859, 0.878810],
[0.5, 0.942016, 0.967019, 0.836506, 0.881143],
[0.5, 0.942875, 0.96751, 0.837655, 0.881893],
[0.6, 0.942889, 0.967467, 0.834149, 0.879263],
[0.6, 0.94241, 0.967134, 0.835282, 0.879975],
[0.6, 0.942999, 0.967409, 0.837874, 0.879422],
[0.6, 0.944271, 0.96841, 0.84033, 0.885260],
[0.6, 0.944287, 0.968416, 0.840918, 0.884620],
[0.7, 0.945244, 0.968931, 0.840901, 0.885507],
[0.7, 0.944691, 0.968439, 0.84286, 0.886661],
[0.7, 0.945565, 0.968863, 0.844009, 0.887870],
[0.7, 0.94545, 0.969033, 0.84419, 0.888327],
[0.7, 0.947342, 0.970026, 0.846682, 0.890258],
[0.8, 0.94755, 0.970202, 0.847723, 0.891049],
[0.8, 0.948123, 0.970388, 0.849016, 0.891005],
[0.8, 0.947261, 0.970096, 0.849468, 0.891794],
[0.8, 0.948495, 0.970793, 0.850103, 0.892667],
[0.8, 0.947547, 0.970207, 0.850174, 0.892045],
[0.9, 0.95029, 0.971653, 0.853899, 0.898095],
[0.9, 0.949553, 0.971291, 0.855798, 0.896680],
[0.9, 0.950043, 0.971714, 0.855912, 0.897163],
[0.9, 0.950754, 0.972088, 0.857229, 0.899160],
[0.9, 0.95033, 0.971725, 0.857454, 0.900025],
[ 1, 0.95109, 0.972268, 0.85568, 0.898633],
[ 1, 0.950829, 0.972174, 0.856407, 0.898670],
[ 1, 0.95127, 0.972382, 0.858379, 0.901110],
[ 1, 0.952334, 0.972927, 0.860708, 0.903797],
[ 1, 0.951377, 0.972402, 0.861084, 0.904752],
])
# rs = pd.DataFrame([[results[i, 0], results[i, 3]] for i in range(0, 50, 5)])
#print(rs)
print(results[:, 3])
x = np.round(results[:, 0] * (413 - 40))
ax = sns.lineplot(x=x, y=100 * (1-results[:, 3]), label="Nakdimon", marker='o')
sns.lineplot(x=x, y=([100-91.56] * 50), label="Nakdan")


ax.set(xlabel="Number of modern documents in Nakdimon's training set",
ylabel='WOR error rate')
ax.set(ylim=(0, 26))

ax.xaxis.label.set_fontsize(12)
for l in ax.get_xticklabels():
l.set_fontsize(12)

plt.show()
4 changes: 2 additions & 2 deletions experiments/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def train_ablation(params):

if mode == 'pretrain':
pretrain()
elif mode == 'train_ablation':
train_ablation(PretrainedModernOnly())
# elif mode == 'train_ablation':
# train_ablation(PretrainedModernOnly())
else:
import ablations
tf.config.set_visible_devices([], 'GPU')
Expand Down
Loading

0 comments on commit 6f1b77e

Please sign in to comment.