Skip to content

Commit

Permalink
Merge pull request #81 from j3mdamas/feature/issue-70/allow-model-loc…
Browse files Browse the repository at this point in the history
…ation-definition

Encapsulate model download in a function
  • Loading branch information
Kohulan authored Jan 5, 2024
2 parents d6bb004 + 2198900 commit 208752a
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 17 deletions.
16 changes: 2 additions & 14 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,13 @@
# Set path
default_path = pystow.join("DECIMER-V2")

# model download location
model_url = "https://zenodo.org/record/8300489/files/models.zip"
model_path = str(default_path) + "/DECIMER_model/"

# download models to a default location
if (
os.path.exists(model_path)
and os.stat(model_path + "/saved_model.pb").st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)
utils.ensure_model(default_path=default_path)

# Load the important pickle files, which contain the tokenizers and the max-length setting

tokenizer = pickle.load(
open(
default_path.as_posix() + "/DECIMER_model/assets/tokenizer_SMILES.pkl",
os.path.join(default_path.as_posix(), "DECIMER_model", "assets", "tokenizer_SMILES.pkl"),
"rb",
)
)
Expand Down
3 changes: 1 addition & 2 deletions DECIMER/efficientnetv2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from absl import logging
import numpy as np
import tensorflow as tf
import tensorflow_addons.layers as tfa_layers

from tensorflow.python.tpu import (
tpu_function,
Expand Down Expand Up @@ -236,7 +235,7 @@ def normalization(
):
"""Normalization after conv layers."""
if norm_type == "gn":
return tfa_layers.GroupNormalization(groups, axis, epsilon, name=name)
return tf.keras.layers.GroupNormalization(groups, axis, epsilon, name=name)

if norm_type == "tpu_bn":
return TpuBatchNormalization(
Expand Down
26 changes: 26 additions & 0 deletions DECIMER/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import os
import DECIMER.config as config

# Matches the placeholder labels R, X, Y and Z, each optionally followed by digits.
pattern = r"R([0-9]*)|X([0-9]*)|Y([0-9]*)|Z([0-9]*)"
# Matches a run of non-word characters at the very start or very end of a string.
# Raw string: "\W" is an invalid escape sequence in a normal string literal.
add_space_re = r"^(\W+)|(\W+)$"
Expand Down Expand Up @@ -57,3 +59,27 @@ def decoder(predictions):
.replace("§", "0")
)
return modified

def ensure_model(
    default_path: "str | os.PathLike",
    model_url: str = "https://zenodo.org/record/8300489/files/models.zip",
    expected_size: int = 28080309,
):
    """Ensure the trained model is present locally, downloading it if needed.

    The model is looked up in ``<default_path>/DECIMER_model``. It is
    downloaded when that directory does not exist, and deleted and
    downloaded again when its ``saved_model.pb`` is missing or does not
    have the expected size.

    Args:
        default_path (str | os.PathLike): default path for DECIMER data.
            It is forwarded unchanged to ``config.download_trained_weights``.
        model_url (str): trained model url for downloading.
        expected_size (int): size in bytes that ``saved_model.pb`` must have
            for the local copy to be accepted. The default is the value the
            check was originally hard-coded to; pass a different value
            together with a custom ``model_url``.
    """
    # Local import: the module head only imports re, os and DECIMER.config.
    import shutil

    # os.fspath accepts both plain strings and path objects.
    model_path = os.path.join(os.fspath(default_path), "DECIMER_model")
    weights_file = os.path.join(model_path, "saved_model.pb")

    if os.path.exists(model_path):
        try:
            is_intact = os.stat(weights_file).st_size == expected_size
        except FileNotFoundError:
            # Directory exists but the weights file is missing: treat the
            # local copy as broken instead of crashing.
            is_intact = False

        if is_intact:
            return

        # Remove the broken copy so the download starts from a clean state.
        shutil.rmtree(model_path)

    config.download_trained_weights(model_url, default_path)
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
"efficientnet",
"selfies",
"pyyaml",
"tensorflow-addons",
],
package_data={"DECIMER": ["repack/*.*", "efficientnetv2/*.*", "Utils/*.*"]},
classifiers=[
Expand Down

0 comments on commit 208752a

Please sign in to comment.