diff --git a/modules/initialize.py b/modules/initialize.py index 0365bbb3093..0f1a2407f8d 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -14,6 +14,7 @@ def imports(): import torch # noqa: F401 startup_timer.record("import torch") + from modules import patch_hf_hub_download # noqa: F401 import pytorch_lightning # noqa: F401 startup_timer.record("import torch") warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") diff --git a/modules/patch_hf_hub_download.py b/modules/patch_hf_hub_download.py new file mode 100644 index 00000000000..b4d5e3da4fa --- /dev/null +++ b/modules/patch_hf_hub_download.py @@ -0,0 +1,41 @@ +from modules.patches import patch +from modules.errors import report +from inspect import signature +from functools import wraps + +try: + from huggingface_hub.utils import LocalEntryNotFoundError + from huggingface_hub import file_download + + def try_local_files_only(func): + if (param := signature(func).parameters.get('local_files_only', None)) and not param.kind == param.KEYWORD_ONLY: + raise ValueError(f'{func.__name__} does not have keyword-only parameter "local_files_only"') + + @wraps(func) + def wrapper(*args, **kwargs): + try: + from modules.shared import opts + try_offline_mode = not kwargs.get('local_files_only') and opts.hd_dl_local_first + except Exception: + report('Error in try_local_files_only - skip try_local_files_only', exc_info=True) + try_offline_mode = False + + if try_offline_mode: + try: + return func(*args, **{**kwargs, 'local_files_only': True}) + except LocalEntryNotFoundError: + pass + except Exception: + report('Unexpected exception in try_local_files_only - retry without patch', exc_info=True) + + return func(*args, **kwargs) + + return wrapper + + try: + patch(__name__, file_download, 'hf_hub_download', try_local_files_only(file_download.hf_hub_download)) + except RuntimeError: + pass # already patched + +except Exception: + report('Error patching hf_hub_download', exc_info=True) diff --git a/modules/shared_options.py b/modules/shared_options.py index 9f4520274b1..c370f880d2f 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -128,6 +128,7 @@ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), + "hd_dl_local_first": OptionInfo(False, "Prevent connecting to huggingface for assets if cache is available").info('this will also prevent assets from being updated'), })) options_templates.update(options_section(('profiler', "Profiler", "system"), {