diff --git a/dataset.py b/dataset.py
index 63607e07..6d0055e2 100644
--- a/dataset.py
+++ b/dataset.py
@@ -1,7 +1,9 @@
 from typing import Tuple, List
-
+import random
 import numpy as np
+from cachier import cachier
+
 import hebrew
 import utils
 
 
@@ -122,11 +124,12 @@ def print_stats(self):
         print(self.shapes())
 
 
+@cachier()
 def read_corpora(base_paths):
-    return [(filename, list(hebrew.iterate_file(filename))) for filename in utils.iterate_files(base_paths)]
+    return tuple([(filename, list(hebrew.iterate_file(filename))) for filename in utils.iterate_files(base_paths)])
 
 
-def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True) -> Tuple[Data, Data]:
+def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True, subtraining_rate=1) -> Tuple[Data, Data]:
     corpus = [(filename, Data.from_text(heb_items, maxlen)) for (filename, heb_items) in corpora]
 
     validation_data = None
@@ -147,7 +150,9 @@ def load_data(corpora, validation_rate: float, maxlen: int, shuffle=True) -> Tup
         validation_data = Data.concatenate(validation)
         validation_data.filenames = tuple(validation_filenames)
 
-    train = Data.concatenate([c for (_, c) in corpus])
+    cs = [c for (_, c) in corpus]
+    random.shuffle(cs)
+    train = Data.concatenate(cs[:int(subtraining_rate * len(corpus))])
     if shuffle:
         train.shuffle()
     return train, validation_data
diff --git a/experiments/ablations.py b/experiments/ablations.py
index 66cedde8..40e8d079 100644
--- a/experiments/ablations.py
+++ b/experiments/ablations.py
@@ -85,6 +85,13 @@ def epoch_params(self, data):
 
 
 class ModernOnly(TrainingParams):
+    corpus = {
+        'modern': (80, [
+            'hebrew_diacritized/modern',
+            'hebrew_diacritized/dictaTestCorpus'
+        ])
+    }
+
     def epoch_params(self, data):
         lrs = [30e-4, 30e-4, 30e-4, 8e-4, 1e-4]
         yield ('modern', len(lrs), tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: lrs[epoch]))
@@ -113,6 +120,67 @@ def name(self):
         return f'Batch({self.batch_size})'
 
 
+class Subtraining(ModernOnly):
+    def __init__(self, subtraining_rate):
+        self.subtraining_rate = {'modern': subtraining_rate}
+
+    def initialize_weights(self, model):
+        model.load_weights('./checkpoints/mix')
+
+    @property
+    def name(self):
+        return f'Subtraining({self.subtraining_rate["modern"]})'
+
+
+class MultiMaxlen(ModernOnly):
+    def __init__(self, maxlens, lrs):
+        self.maxlens = maxlens
+        self.lrs = lrs
+        files = [
+            'hebrew_diacritized/modern',
+            'hebrew_diacritized/dictaTestCorpus'
+        ]
+        self.corpus = {f'modern_{maxlen}': (maxlen, files) for maxlen in maxlens}
+
+    def initialize_weights(self, model):
+        model.load_weights('./checkpoints/mix')
+
+    def epoch_params(self, data):
+        for maxlen, lr in zip(self.maxlens, self.lrs):
+            yield (f'modern_{maxlen}', 1, tf.keras.callbacks.LearningRateScheduler(lambda epoch, _lr: lr))
+
+    @property
+    def name(self):
+        maxlens = ", ".join(str(x) for x in self.maxlens)
+        lrs = ", ".join(str(x) for x in self.lrs)
+        return f'MultiMaxlen({maxlens}; {lrs})'
+
+
+class Crf(TrainingParams):
+    def build_model(self):
+        from tf2crf import CRF, ModelWithCRFLoss
+        from train import LETTERS_SIZE, NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE
+        layers = tf.keras.layers
+
+        inp = tf.keras.Input(shape=(None,), batch_size=None)
+        embed = layers.Embedding(LETTERS_SIZE, self.units, mask_zero=True)(inp)
+
+        layer = layers.Bidirectional(layers.LSTM(self.units, return_sequences=True, dropout=0.1), merge_mode='sum')(embed)
+        layer = layers.Bidirectional(layers.LSTM(self.units, return_sequences=True, dropout=0.1), merge_mode='sum')(layer)
+        layer = layers.Dense(self.units)(layer)
+
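+        # A linear-chain CRF (here the tf2crf layer) scores whole output
+        # sequences, adding learned transition costs between adjacent
+        # positions on top of the per-letter BiLSTM features.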
+        layer = CRF()(layer)
+
+        outputs = [
+            layers.Dense(NIQQUD_SIZE, name='N')(layer),
+            layers.Dense(DAGESH_SIZE, name='D')(layer),
+            layers.Dense(SIN_SIZE, name='S')(layer),
+        ]
+        base_model = tf.keras.Model(inputs=inp, outputs=outputs)
+        model = ModelWithCRFLoss(base_model, sparse_target=True)
+        return model
+
+
 def calculate_metrics(model):
     import nakdimon
     for file in Path('tests/validation/expected/modern/').glob('*'):
@@ -123,16 +191,25 @@ def calculate_metrics(model):
         yield metrics.all_metrics(actual, expected)
 
 
-def train_ablation(params):
+def train_ablation(params, group):
    def ablation(model):
        return metrics.metricwise_mean(calculate_metrics(model))
 
-    model = train(params, ablation)
+    model = train(params, group, ablation)
    model.save(f'./models/ablations/{params.name}.h5')
 
 
 if __name__ == '__main__':
-    FullTraining(600)
+    train_ablation(Crf(), group="crf")
+    # import random
+    # for _ in range(10):
+    #     n = random.choice([3, 4, 5])
+    #     lrs = [random.choice([1e-4, 5e-4, 10e-4, 20e-4, 30e-4]) for _ in range(n)]
+    #     maxlens = [random.choice([70, 75, 80, 85, 90, 95]) for _ in range(n)]
+    #     train_ablation(MultiMaxlen(maxlens, lrs))
+    # FullTraining(600)
+    # from pretrain import Pretrained # for _ in range(5):
+    #     train_ablation(Pretrained())
    # # train_ablation(ModernOnly())
    # # train_ablation(FullTraining(400))
    # # train_ablation(Chunk(72))
diff --git a/experiments/metrics.py b/experiments/metrics.py
index efdf42ee..5d2c93f7 100644
--- a/experiments/metrics.py
+++ b/experiments/metrics.py
@@ -131,8 +131,8 @@ def all_diffs(system1, system2):
 
 def all_metrics(actual, expected):
     return {
-        'cha': metric_cha(actual, expected),
         'dec': metric_dec(actual, expected),
+        'cha': metric_cha(actual, expected),
         'wor': metric_wor(actual, expected),
         'voc': metric_wor(actual, expected, vocalize=True)
     }
@@ -201,6 +201,7 @@ def format_latex(sysname, results):
     print('{sysname} & {cha:.2%} & {dec:.2%} & {wor:.2%} & {voc:.2%} \\\\'.format(sysname=sysname, **results)
           .replace('%', ''))
 
+
 def all_stats():
     SYSTEMS = [
         # "Nakdimon",
diff --git a/experiments/partial-modern.png b/experiments/partial-modern.png
new file mode 100644
index 00000000..2735d3af
Binary files /dev/null and b/experiments/partial-modern.png differ
diff --git a/experiments/partial_modern.py b/experiments/partial_modern.py
new file mode 100644
index 00000000..676fea89
--- /dev/null
+++ b/experiments/partial_modern.py
@@ -0,0 +1,78 @@
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+
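+# Each row is one training run: subtraining rate r, then the cha/dec/wor/voc
+# metrics for that run; five runs per rate.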
+results = np.array([
+    # r ,   cha ,      dec ,      wor ,      voc
+    [0.1, 0.898022, 0.941791, 0.730089, 0.790349],
+    [0.1, 0.907983, 0.947288, 0.746686, 0.805641],
+    [0.1, 0.909339, 0.948013, 0.756906, 0.808338],
+    [0.1, 0.910473, 0.948737, 0.757871, 0.810780],
+    [0.1, 0.910393, 0.948530, 0.761774, 0.814700],
+    [0.2, 0.921098, 0.955005, 0.782228, 0.833589],
+    [0.2, 0.921384, 0.955211, 0.783685, 0.832101],
+    [0.2, 0.921859, 0.955260, 0.788762, 0.836588],
+    [0.2, 0.928058, 0.958799, 0.796022, 0.844013],
+    [0.2, 0.927519, 0.958676, 0.797943, 0.846828],
+    [0.3, 0.929646, 0.959993, 0.804771, 0.851261],
+    [0.3, 0.929712, 0.959984, 0.806540, 0.853802],
+    [0.3, 0.931107, 0.960884, 0.807965, 0.854223],
+    [0.3, 0.932764, 0.961357, 0.808793, 0.855598],
+    [0.3, 0.932484, 0.961512, 0.812267, 0.856918],
+    [0.4, 0.936546, 0.963860, 0.822286, 0.866843],
+    [0.4, 0.935857, 0.963207, 0.824132, 0.865466],
+    [0.4, 0.937018, 0.964095, 0.824572, 0.867156],
+    [0.4, 0.938540, 0.965087, 0.828320, 0.867609],
+    [0.4, 0.941170, 0.966382, 0.831514, 0.877483],
+    [0.5, 0.940108, 0.965908, 0.830636, 0.874788],
+    [0.5, 0.942086, 0.966870, 0.834623, 0.876376],
+    [0.5, 0.943100, 0.967516, 0.835859, 0.878810],
+    [0.5, 0.942016, 0.967019, 0.836506, 0.881143],
+    [0.5, 0.942875, 0.967510, 0.837655, 0.881893],
+    [0.6, 0.942889, 0.967467, 0.834149, 0.879263],
+    [0.6, 0.942410, 0.967134, 0.835282, 0.879975],
+    [0.6, 0.942999, 0.967409, 0.837874, 0.879422],
+    [0.6, 0.944271, 0.968410, 0.840330, 0.885260],
+    [0.6, 0.944287, 0.968416, 0.840918, 0.884620],
+    [0.7, 0.945244, 0.968931, 0.840901, 0.885507],
+    [0.7, 0.944691, 0.968439, 0.842860, 0.886661],
+    [0.7, 0.945565, 0.968863, 0.844009, 0.887870],
+    [0.7, 0.945450, 0.969033, 0.844190, 0.888327],
+    [0.7, 0.947342, 0.970026, 0.846682, 0.890258],
+    [0.8, 0.947550, 0.970202, 0.847723, 0.891049],
+    [0.8, 0.948123, 0.970388, 0.849016, 0.891005],
+    [0.8, 0.947261, 0.970096, 0.849468, 0.891794],
+    [0.8, 0.948495, 0.970793, 0.850103, 0.892667],
+    [0.8, 0.947547, 0.970207, 0.850174, 0.892045],
+    [0.9, 0.950290, 0.971653, 0.853899, 0.898095],
+    [0.9, 0.949553, 0.971291, 0.855798, 0.896680],
+    [0.9, 0.950043, 0.971714, 0.855912, 0.897163],
+    [0.9, 0.950754, 0.972088, 0.857229, 0.899160],
+    [0.9, 0.950330, 0.971725, 0.857454, 0.900025],
+    [1.0, 0.951090, 0.972268, 0.855680, 0.898633],
+    [1.0, 0.950829, 0.972174, 0.856407, 0.898670],
+    [1.0, 0.951270, 0.972382, 0.858379, 0.901110],
+    [1.0, 0.952334, 0.972927, 0.860708, 0.903797],
+    [1.0, 0.951377, 0.972402, 0.861084, 0.904752],
+])
+# rs = pd.DataFrame([[results[i, 0], results[i, 3]] for i in range(0, 50, 5)])
+# print(rs)
+print(results[:, 3])
+x = np.round(results[:, 0] * (413 - 40))
+ax = sns.lineplot(x=x, y=100 * (1 - results[:, 3]), label="Nakdimon", marker='o')
+sns.lineplot(x=x, y=[100 - 91.56] * 50, label="Nakdan")
+
+
+ax.set(xlabel="Number of modern documents in Nakdimon's training set",
+       ylabel='WOR error rate (%)')
+ax.set(ylim=(0, 26))
+
+ax.xaxis.label.set_fontsize(12)
+for l in ax.get_xticklabels():
+    l.set_fontsize(12)
+
+plt.show()
diff --git a/experiments/pretrain.py b/experiments/pretrain.py
index 2571b75d..de395b70 100644
--- a/experiments/pretrain.py
+++ b/experiments/pretrain.py
@@ -146,8 +146,8 @@ def train_ablation(params):
 
     if mode == 'pretrain':
         pretrain()
-    elif mode == 'train_ablation':
-        train_ablation(PretrainedModernOnly())
+    # elif mode == 'train_ablation':
+    #     train_ablation(PretrainedModernOnly())
     else:
         import ablations
         tf.config.set_visible_devices([], 'GPU')
diff --git a/experiments/train.py b/experiments/train.py
index ccc35c03..954e21af 100644
--- a/experiments/train.py
+++ b/experiments/train.py
@@ -26,24 +26,25 @@ class NakdimonParams:
     def name(self):
         return type(self).__name__
 
-    maxlen = 80
     batch_size = 64
     units = 400
 
     corpus = {
-        'mix': [
+        'mix': (80, tuple([
             'hebrew_diacritized/poetry',
             'hebrew_diacritized/rabanit',
             'hebrew_diacritized/pre_modern'
-        ],
-        'modern': [
+        ])),
+        'modern': (80, tuple([
             'hebrew_diacritized/modern',
             'hebrew_diacritized/dictaTestCorpus'
-        ]
+        ]))
     }
 
     validation_rate = 0
 
+    subtraining_rate = {'mix': 1, 'modern': 1}
+
     def loss(self, y_true, y_pred):
         return masked_metric(tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True), y_true)
@@ -74,6 +75,9 @@ def build_model(self):
         ]
         return tf.keras.Model(inputs=inp, outputs=outputs)
 
+    def initialize_weights(self, model):
+        return
+
 
 class TrainingParams(NakdimonParams):
     validation_rate = 0.1
@@ -90,14 +94,17 @@ def get_xy(d):
 
 def load_data(params: NakdimonParams):
     data = {}
-    for stage_name, stage_dataset_filenames in params.corpus.items():
+    for stage_name, (maxlen, stage_dataset_filenames) in params.corpus.items():
         np.random.seed(2)
-        data[stage_name] = dataset.load_data(dataset.read_corpora(stage_dataset_filenames),
-                                             validation_rate=params.validation_rate, maxlen=params.maxlen)
+        data[stage_name] = dataset.load_data(tuple(dataset.read_corpora(tuple(stage_dataset_filenames))),
+                                             validation_rate=params.validation_rate,
+                                             maxlen=maxlen
+                                             # ,subtraining_rate=params.subtraining_rate[stage_name]
+                                             )
     return data
 
 
-def train(params: NakdimonParams, ablation=None):
+def train(params: NakdimonParams, group, ablation=None):
 
     data = load_data(params)
@@ -108,20 +115,24 @@ def train(params: NakdimonParams, ablation=None):
 
     config = {
         'batch_size': params.batch_size,
-        'maxlen': params.maxlen,
         'units': params.units,
         'model': model,
+        # 'rate_modern': params.subtraining_rate['modern']
     }
 
     run = wandb.init(project="dotter",
-                     group="ablations_final",
+                     group=group,
                      name=params.name,
                      tags=[],
                      config=config)
+
+    params.initialize_weights(model)
+
     with run:
         last_epoch = 0
         for (stage, n_epochs, scheduler) in params.epoch_params(data):
             (train, validation) = data[stage]
+
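+            # Record which files formed this stage's validation split, for reproducibility.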
             if validation:
                 with open(f'validation_files_{stage}.txt', 'w') as f:
                     for p in validation.filenames:
@@ -152,5 +163,5 @@ class Full(NakdimonParams):
 
 
 if __name__ == '__main__':
-    model = train(Full())
+    model = train(Full(), 'Full')
     model.save(f'./final_model/final.h5')
diff --git a/index.html b/index.html
index 7422215c..53d3a94f 100644
--- a/index.html
+++ b/index.html
@@ -14,12 +14,12 @@
[hunk body lost to HTML extraction; the surviving fragments are the page title "נקדן" (Nakdan) and a single removed/added line pair]
diff --git a/main.ipynb b/main.ipynb
index 10913e05..d353d2fb 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -12,7 +12,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -36,7 +36,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -61,7 +61,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -77,7 +77,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -106,7 +106,7 @@
     "    embed = layers.Embedding(LETTERS_SIZE, units, mask_zero=True)(inp)\n",
     "    \n",
     "    layer = layers.Bidirectional(layers.LSTM(units, return_sequences=True, dropout=0.1), merge_mode='sum')(embed)\n",
-    "    # layer = layers.Bidirectional(layers.LSTM(units, return_sequences=True, dropout=0.1), merge_mode='sum')(layer)\n",
+    "    layer = layers.Bidirectional(layers.LSTM(units, return_sequences=True, dropout=0.1), merge_mode='sum')(layer)\n",
     "    layer = layers.Dense(units)(layer)\n",
     "\n",
     "    outputs = [\n",
@@ -168,13 +168,122 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "env: WANDB_MODE=dryrun\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "wandb: Offline run mode, not syncing to the cloud.\n",
+      "wandb: W&B is disabled in this directory. Run `wandb on` to enable cloud syncing.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1636/1636 [==============================] - 278s 148ms/step - loss: 1.1109 - N_loss: 0.6972 - D_loss: 0.2073 - S_loss: 0.2064 - N_accuracy: 0.7610 - D_accuracy: 0.9198 - S_accuracy: 0.9426 - val_loss: 0.5110 - val_N_loss: 0.2298 - val_D_loss: 0.1655 - val_S_loss: 0.1157 - val_N_accuracy: 0.9263 - val_D_accuracy: 0.9420 - val_S_accuracy: 0.9722\n",
+      "letters: 85.48%, decisions: 91.57%, words: 64.75%\n",
+      "Epoch 2/6\n",
+      "255/255 [==============================] - 38s 148ms/step - loss: 0.2697 - N_loss: 0.1695 - D_loss: 0.0662 - S_loss: 0.0340 - N_accuracy: 0.9425 - D_accuracy: 0.9753 - S_accuracy: 0.9914 - val_loss: 0.2176 - val_N_loss: 0.1308 - val_D_loss: 0.0517 - val_S_loss: 0.0351 - val_N_accuracy: 0.9566 - val_D_accuracy: 0.9818 - val_S_accuracy: 0.9914\n",
+      "Epoch 3/6\n",
+      "255/255 [==============================] - 38s 148ms/step - loss: 0.1543 - N_loss: 0.0970 - D_loss: 0.0408 - S_loss: 0.0165 - N_accuracy: 0.9670 - D_accuracy: 0.9850 - S_accuracy: 0.9957 - val_loss: 0.2053 - val_N_loss: 0.1199 - val_D_loss: 0.0473 - val_S_loss: 0.0381 - val_N_accuracy: 0.9623 - val_D_accuracy: 0.9838 - val_S_accuracy: 0.9915\n",
+      "Epoch 4/6\n",
+      "255/255 [==============================] - 38s 148ms/step - loss: 0.1159 - N_loss: 0.0728 - D_loss: 0.0325 - S_loss: 0.0106 - N_accuracy: 0.9751 - D_accuracy: 0.9881 - S_accuracy: 0.9969 - val_loss: 0.2012 - val_N_loss: 0.1177 - val_D_loss: 0.0459 - val_S_loss: 0.0376 - val_N_accuracy: 0.9631 - val_D_accuracy: 0.9836 - val_S_accuracy: 0.9915\n",
+      "Epoch 5/6\n",
+      "255/255 [==============================] - 38s 148ms/step - loss: 0.0675 - N_loss: 0.0420 - D_loss: 0.0212 - S_loss: 0.0044 - N_accuracy: 0.9856 - D_accuracy: 0.9924 - S_accuracy: 0.9987 - val_loss: 0.2025 - val_N_loss: 0.1141 - val_D_loss: 0.0470 - val_S_loss: 0.0414 - val_N_accuracy: 0.9682 - val_D_accuracy: 0.9857 - val_S_accuracy: 0.9935\n",
+      "Epoch 6/6\n",
+      "255/255 [==============================] - 38s 148ms/step - loss: 0.0460 - N_loss: 0.0288 - D_loss: 0.0156 - S_loss: 0.0016 - N_accuracy: 0.9903 - D_accuracy: 0.9944 - S_accuracy: 0.9995 - val_loss: 0.2055 - val_N_loss: 0.1148 - val_D_loss: 0.0478 - val_S_loss: 0.0429 - val_N_accuracy: 0.9686 - val_D_accuracy: 0.9858 - val_S_accuracy: 0.9935\n",
+      "letters: 95.92%, decisions: 97.68%, words: 88.22%\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<br/>Waiting for W&B process to finish, PID 21716<br/>Program ended successfully."
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Find user logs for this run at: wandb\\offline-run-20210121_210719-2z1mn0qk\\logs\\debug.log"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Find internal logs for this run at: wandb\\offline-run-20210121_210719-2z1mn0qk\\logs\\debug-internal.log"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<h3>Run summary:</h3>loss 0.04595<br/>N_loss 0.02882<br/>D_loss 0.01556<br/>S_loss 0.00158<br/>N_accuracy 0.99031<br/>D_accuracy 0.99442<br/>S_accuracy 0.99954<br/>_step 69<br/>_runtime 479<br/>_timestamp 1611256519<br/>epoch 5<br/>val_loss 0.20546<br/>val_N_loss 0.1148<br/>val_D_loss 0.0478<br/>val_S_loss 0.04286<br/>val_N_accuracy 0.96859<br/>val_D_accuracy 0.98577<br/>val_S_accuracy 0.99345<br/>best_val_loss 0.20123<br/>best_epoch 3<br/>index 0<br/>letters 0.95923<br/>decisions 0.9768<br/>words 0.88224"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "<h3>Run history:</h3>loss █▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁<br/>N_loss █▆▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁<br/>D_loss █▅▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁<br/>S_loss █▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁<br/>N_accuracy ▁▃▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████████████████<br/>D_accuracy ▁▆▇▇▇▇▇▇▇▇██████████████████████████████<br/>S_accuracy ▁▇▇▇▇███████████████████████████████████<br/>_step ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███<br/>_runtime ▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███<br/>_timestamp ▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███<br/>epoch ▁▂▄▅▇█<br/>val_loss █▁▁▁▁▁<br/>val_N_loss █▂▁▁▁▁<br/>val_D_loss █▁▁▁▁▁<br/>val_S_loss █▁▁▁▂▂<br/>val_N_accuracy ▁▆▇▇██<br/>val_D_accuracy ▁▇████<br/>val_S_accuracy ▁▇▇▇██<br/>index<br/>letters<br/>decisions<br/>words"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "wandb: You can sync this run to the cloud by running:\n",
+      "wandb: wandb sync wandb\\offline-run-20210121_210719-2z1mn0qk\n"
+     ]
+    }
+   ],
    "source": [
-    "%env WANDB_MODE run\n",
+    "%env WANDB_MODE dryrun\n",
     "\n",
     "def experiment(n):\n",
     "    BATCH_SIZE = 64\n",
@@ -230,22 +339,10 @@
     "    run.log({'index': 0, 'letters': letters, 'decisions': decisions, 'words': words})\n",
     "    return model\n",
     "\n",
-    "for n in range(5):\n",
+    "for n in range(1):\n",
     "    model = experiment(n)  # 20-30-20-5-1: 88.08-88.16"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    },
-    "scrolled": true
-   },
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -491,4 +588,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
\ No newline at end of file
+}