diff --git a/modules/config.py b/modules/config.py index 29a16d6dc..e3c427d2c 100644 --- a/modules/config.py +++ b/modules/config.py @@ -2,13 +2,14 @@ import json import math import numbers + import args_manager import tempfile import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.extra_utils import makedirs_with_log, get_files_from_folder +from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_eval_env_var from modules.flags import OutputFormat, Performance, MetadataScheme @@ -200,7 +201,7 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa path_outputs = get_path_output() -def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): +def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, expected_type=None): global config_dict, visited_keys if key not in visited_keys: @@ -208,6 +209,7 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ v = os.getenv(key) if v is not None: + v = try_eval_env_var(v, expected_type) print(f"Environment: {key} = {v}") config_dict[key] = v @@ -252,41 +254,49 @@ def init_temp_path(path: str | None, default_path: str) -> str: key='temp_path', default_value=default_temp_path, validator=lambda x: isinstance(x, str), + expected_type=str ), default_temp_path) temp_path_cleanup_on_launch = get_config_item_or_set_default( key='temp_path_cleanup_on_launch', default_value=True, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_base_model_name = default_model = get_config_item_or_set_default( key='default_model', default_value='model.safetensors', - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) previous_default_models = get_config_item_or_set_default( key='previous_default_models', default_value=[], - validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x) + validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x), + expected_type=list ) default_refiner_model_name = default_refiner = get_config_item_or_set_default( key='default_refiner', default_value='None', - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) default_refiner_switch = get_config_item_or_set_default( key='default_refiner_switch', default_value=0.8, - validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1 + validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1, + expected_type=numbers.Number ) default_loras_min_weight = get_config_item_or_set_default( key='default_loras_min_weight', default_value=-2, - validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 + validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10, + expected_type=numbers.Number ) default_loras_max_weight = get_config_item_or_set_default( key='default_loras_max_weight', default_value=2, - validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 + validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10, + expected_type=numbers.Number ) default_loras = get_config_item_or_set_default( key='default_loras', @@ -320,38 +330,45 @@ def init_temp_path(path: str | None, default_path: str) -> str: validator=lambda x: isinstance(x, list) and all( len(y) == 3 and isinstance(y[0], bool) and isinstance(y[1], str) and isinstance(y[2], numbers.Number) or len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) - for y in x) + for y in x), + expected_type=list ) default_loras = [(y[0], y[1], y[2]) if len(y) == 3 else (True, y[0], y[1]) for y in default_loras] default_max_lora_number = get_config_item_or_set_default( key='default_max_lora_number', default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5, - validator=lambda x: isinstance(x, int) and x >= 1 + validator=lambda x: isinstance(x, int) and x >= 1, + expected_type=int ) default_cfg_scale = get_config_item_or_set_default( key='default_cfg_scale', default_value=7.0, - validator=lambda x: isinstance(x, numbers.Number) + validator=lambda x: isinstance(x, numbers.Number), + expected_type=numbers.Number ) default_sample_sharpness = get_config_item_or_set_default( key='default_sample_sharpness', default_value=2.0, - validator=lambda x: isinstance(x, numbers.Number) + validator=lambda x: isinstance(x, numbers.Number), + expected_type=numbers.Number ) default_sampler = get_config_item_or_set_default( key='default_sampler', default_value='dpmpp_2m_sde_gpu', - validator=lambda x: x in modules.flags.sampler_list + validator=lambda x: x in modules.flags.sampler_list, + expected_type=str ) default_scheduler = get_config_item_or_set_default( key='default_scheduler', default_value='karras', - validator=lambda x: x in modules.flags.scheduler_list + validator=lambda x: x in modules.flags.scheduler_list, + expected_type=str ) default_vae = get_config_item_or_set_default( key='default_vae', default_value=modules.flags.default_vae, - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) default_styles = get_config_item_or_set_default( key='default_styles', @@ -360,121 +377,144 @@ def init_temp_path(path: str | None, default_path: str) -> str: "Fooocus Enhance", "Fooocus Sharp" ], - validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x) + validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x), + expected_type=list ) default_prompt_negative = get_config_item_or_set_default( key='default_prompt_negative', default_value='', validator=lambda x: isinstance(x, str), - disable_empty_as_none=True + disable_empty_as_none=True, + expected_type=str ) default_prompt = get_config_item_or_set_default( key='default_prompt', default_value='', validator=lambda x: isinstance(x, str), - disable_empty_as_none=True + disable_empty_as_none=True, + expected_type=str ) default_performance = get_config_item_or_set_default( key='default_performance', default_value=Performance.SPEED.value, - validator=lambda x: x in Performance.list() + validator=lambda x: x in Performance.list(), + expected_type=str ) default_advanced_checkbox = get_config_item_or_set_default( key='default_advanced_checkbox', default_value=False, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_max_image_number = get_config_item_or_set_default( key='default_max_image_number', default_value=32, - validator=lambda x: isinstance(x, int) and x >= 1 + validator=lambda x: isinstance(x, int) and x >= 1, + expected_type=int ) default_output_format = get_config_item_or_set_default( key='default_output_format', default_value='png', - validator=lambda x: x in OutputFormat.list() + validator=lambda x: x in OutputFormat.list(), + expected_type=str ) default_image_number = get_config_item_or_set_default( key='default_image_number', default_value=2, - validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number + validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number, + expected_type=int ) checkpoint_downloads = get_config_item_or_set_default( key='checkpoint_downloads', default_value={}, - validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()), + expected_type=dict ) lora_downloads = get_config_item_or_set_default( key='lora_downloads', default_value={}, - validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()), + expected_type=dict ) embeddings_downloads = get_config_item_or_set_default( key='embeddings_downloads', default_value={}, - validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()), + expected_type=dict ) available_aspect_ratios = get_config_item_or_set_default( key='available_aspect_ratios', default_value=modules.flags.sdxl_aspect_ratios, - validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1 + validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1, + expected_type=list ) default_aspect_ratio = get_config_item_or_set_default( key='default_aspect_ratio', default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0], - validator=lambda x: x in available_aspect_ratios + validator=lambda x: x in available_aspect_ratios, + expected_type=str ) default_inpaint_engine_version = get_config_item_or_set_default( key='default_inpaint_engine_version', default_value='v2.6', - validator=lambda x: x in modules.flags.inpaint_engine_versions + validator=lambda x: x in modules.flags.inpaint_engine_versions, + expected_type=str ) default_cfg_tsnr = get_config_item_or_set_default( key='default_cfg_tsnr', default_value=7.0, - validator=lambda x: isinstance(x, numbers.Number) + validator=lambda x: isinstance(x, numbers.Number), + expected_type=numbers.Number ) default_clip_skip = get_config_item_or_set_default( key='default_clip_skip', default_value=2, - validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max + validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max, + expected_type=int ) default_overwrite_step = get_config_item_or_set_default( key='default_overwrite_step', default_value=-1, - validator=lambda x: isinstance(x, int) + validator=lambda x: isinstance(x, int), + expected_type=int ) default_overwrite_switch = get_config_item_or_set_default( key='default_overwrite_switch', default_value=-1, - validator=lambda x: isinstance(x, int) + validator=lambda x: isinstance(x, int), + expected_type=int ) example_inpaint_prompts = get_config_item_or_set_default( key='example_inpaint_prompts', default_value=[ 'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes' ], - validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) + validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x), + expected_type=list ) default_black_out_nsfw = get_config_item_or_set_default( key='default_black_out_nsfw', default_value=False, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_save_metadata_to_images = get_config_item_or_set_default( key='default_save_metadata_to_images', default_value=False, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_metadata_scheme = get_config_item_or_set_default( key='default_metadata_scheme', default_value=MetadataScheme.FOOOCUS.value, - validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x] + validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x], + expected_type=str ) metadata_created_by = get_config_item_or_set_default( key='metadata_created_by', default_value='', - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) example_inpaint_prompts = [[x] for x in example_inpaint_prompts] diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 9906c8202..c2dfa8104 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -1,4 +1,6 @@ import os +from ast import literal_eval + def makedirs_with_log(path): try: @@ -24,3 +26,16 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): filenames.append(path) return filenames + + +def try_eval_env_var(value: str, expected_type=None): + try: + value_eval = value + if expected_type is bool: + value_eval = value.title() + value_eval = literal_eval(value_eval) + if expected_type is not None and not isinstance(value_eval, expected_type): + return value + return value_eval + except: + return value diff --git a/tests/test_extra_utils.py b/tests/test_extra_utils.py new file mode 100644 index 000000000..a849aa16d --- /dev/null +++ b/tests/test_extra_utils.py @@ -0,0 +1,74 @@ +import numbers +import os +import unittest + +import modules.flags +from modules import extra_utils + + +class TestUtils(unittest.TestCase): + def test_try_eval_env_var(self): + test_cases = [ + { + "input": ("foo", str), + "output": "foo" + }, + { + "input": ("1", int), + "output": 1 + }, + { + "input": ("1.0", float), + "output": 1.0 + }, + { + "input": ("1", numbers.Number), + "output": 1 + }, + { + "input": ("1.0", numbers.Number), + "output": 1.0 + }, + { + "input": ("true", bool), + "output": True + }, + { + "input": ("True", bool), + "output": True + }, + { + "input": ("false", bool), + "output": False + }, + { + "input": ("False", bool), + "output": False + }, + { + "input": ("True", str), + "output": "True" + }, + { + "input": ("False", str), + "output": "False" + }, + { + "input": ("['a', 'b', 'c']", list), + "output": ['a', 'b', 'c'] + }, + { + "input": ("{'a':1}", dict), + "output": {'a': 1} + }, + { + "input": ("('foo', 1)", tuple), + "output": ('foo', 1) + } + ] + + for test in test_cases: + value, expected_type = test["input"] + expected = test["output"] + actual = extra_utils.try_eval_env_var(value, expected_type) + self.assertEqual(expected, actual)