Skip to content

Commit

Permalink
encapsulates model download into a function
Browse files Browse the repository at this point in the history
  • Loading branch information
j3mdamas committed Jan 4, 2024
1 parent f1f3e68 commit 2198900
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
13 changes: 1 addition & 12 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,8 @@
# Set path
default_path = pystow.join("DECIMER-V2")

# model download location
model_url = "https://zenodo.org/record/8300489/files/models.zip"
model_path = os.path.join(default_path.as_posix(), "DECIMER_model")

# download models to a default location
if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)
utils.ensure_model(default_path=default_path)

# Load important pickle files which consists the tokenizers and the maxlength setting
tokenizer = pickle.load(
Expand Down
26 changes: 26 additions & 0 deletions DECIMER/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import os
import DECIMER.config as config

pattern = "R([0-9]*)|X([0-9]*)|Y([0-9]*)|Z([0-9]*)"
add_space_re = "^(\W+)|(\W+)$"
Expand Down Expand Up @@ -57,3 +59,27 @@ def decoder(predictions):
.replace("§", "0")
)
return modified

def ensure_model(
default_path: str,
model_url: str = "https://zenodo.org/record/8300489/files/models.zip"
):
"""Function to ensure model is present locally
Convenient function to ensure model download before usage
Args:
default path (str): default path for DECIMER data
model_url (str): trained model url for downloading
"""

model_path = os.path.join(default_path.as_posix(), "DECIMER_model")

if (
os.path.exists(model_path)
and os.stat(os.path.join(model_path, "saved_model.pb")).st_size != 28080309
):
shutil.rmtree(model_path)
config.download_trained_weights(model_url, default_path)
elif not os.path.exists(model_path):
config.download_trained_weights(model_url, default_path)

0 comments on commit 2198900

Please sign in to comment.