From c120908f655d2bd2d5e657b208a8cf0114fa8e24 Mon Sep 17 00:00:00 2001 From: Tobias Bernhard Date: Tue, 26 Sep 2023 14:41:00 +0200 Subject: [PATCH 1/3] replace tensorflow_addons.GroupNormalization with tf.keras.layers.GroupNormalization --- DECIMER/efficientnetv2/utils.py | 3 +-- setup.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/DECIMER/efficientnetv2/utils.py b/DECIMER/efficientnetv2/utils.py index 2446d46..a83ff7e 100644 --- a/DECIMER/efficientnetv2/utils.py +++ b/DECIMER/efficientnetv2/utils.py @@ -19,7 +19,6 @@ from absl import logging import numpy as np import tensorflow as tf -import tensorflow_addons.layers as tfa_layers from tensorflow.python.tpu import ( tpu_function, @@ -236,7 +235,7 @@ def normalization( ): """Normalization after conv layers.""" if norm_type == "gn": - return tfa_layers.GroupNormalization(groups, axis, epsilon, name=name) + return tf.keras.layers.GroupNormalization(groups, axis, epsilon, name=name) if norm_type == "tpu_bn": return TpuBatchNormalization( diff --git a/setup.py b/setup.py index 6f2b4f2..527611b 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,6 @@ "efficientnet", "selfies", "pyyaml", - "tensorflow-addons", ], package_data={"DECIMER": ["repack/*.*", "efficientnetv2/*.*", "Utils/*.*"]}, classifiers=[ From f1f3e68d9589a7d85bf22705931ac9d9e3a06bca Mon Sep 17 00:00:00 2001 From: Joao Damas CHST Date: Wed, 13 Sep 2023 18:00:55 +0200 Subject: [PATCH 2/3] pythonifies paths for model --- DECIMER/decimer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/DECIMER/decimer.py b/DECIMER/decimer.py index f28005e..504fb78 100644 --- a/DECIMER/decimer.py +++ b/DECIMER/decimer.py @@ -30,12 +30,12 @@ # model download location model_url = "https://zenodo.org/record/8300489/files/models.zip" -model_path = str(default_path) + "/DECIMER_model/" +model_path = os.path.join(default_path.as_posix(), "DECIMER_model") # download models to a default location if ( os.path.exists(model_path) - and os.stat(model_path + "/saved_model.pb").st_size != 28080309 + and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309 ): shutil.rmtree(model_path) config.download_trained_weights(model_url, default_path) @@ -43,10 +43,9 @@ config.download_trained_weights(model_url, default_path) # Load important pickle files which consists the tokenizers and the maxlength setting - tokenizer = pickle.load( open( - default_path.as_posix() + "/DECIMER_model/assets/tokenizer_SMILES.pkl", + os.path.join(default_path.as_posix(), "DECIMER_model", "assets", "tokenizer_SMILES.pkl"), "rb", ) ) From 2198900804060884afbda9c83c49bba5a5290ef3 Mon Sep 17 00:00:00 2001 From: Joao Damas CHST Date: Thu, 4 Jan 2024 17:36:48 +0100 Subject: [PATCH 3/3] encapsulates model download into a function --- DECIMER/decimer.py | 13 +------------ DECIMER/utils.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/DECIMER/decimer.py b/DECIMER/decimer.py index 504fb78..88c0021 100644 --- a/DECIMER/decimer.py +++ b/DECIMER/decimer.py @@ -28,19 +28,8 @@ # Set path default_path = pystow.join("DECIMER-V2") -# model download location -model_url = "https://zenodo.org/record/8300489/files/models.zip" -model_path = os.path.join(default_path.as_posix(), "DECIMER_model") - # download models to a default location -if ( - os.path.exists(model_path) - and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309 -): - shutil.rmtree(model_path) - config.download_trained_weights(model_url, default_path) -elif not os.path.exists(model_path): - config.download_trained_weights(model_url, default_path) +utils.ensure_model(default_path=default_path) # Load important pickle files which consists the tokenizers and the maxlength setting tokenizer = pickle.load( diff --git a/DECIMER/utils.py b/DECIMER/utils.py index 37ad196..d9804a3 100644 --- a/DECIMER/utils.py +++ b/DECIMER/utils.py @@ -1,4 +1,6 @@ import re +import os +import DECIMER.config as config pattern = "R([0-9]*)|X([0-9]*)|Y([0-9]*)|Z([0-9]*)" add_space_re = "^(\W+)|(\W+)$" @@ -57,3 +59,27 @@ def decoder(predictions): .replace("ยง", "0") ) return modified + +def ensure_model( + default_path: str, + model_url: str = "https://zenodo.org/record/8300489/files/models.zip" +): + """Function to ensure model is present locally + + Convenient function to ensure model download before usage + + Args: + default path (str): default path for DECIMER data + model_url (str): trained model url for downloading + """ + + model_path = os.path.join(default_path.as_posix(), "DECIMER_model") + + if ( + os.path.exists(model_path) + and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309 + ): + shutil.rmtree(model_path) + config.download_trained_weights(model_url, default_path) + elif not os.path.exists(model_path): + config.download_trained_weights(model_url, default_path)