From e67e7e0b248539f5e37480a95bb576331026d5e7 Mon Sep 17 00:00:00 2001 From: derolol <837456540@qq.com> Date: Thu, 12 Sep 2024 15:43:05 +0000 Subject: [PATCH] Optimize the Startup Configuration Process for the Hunyuan DiT Model --- .gitignore | 7 + hydit/config_engine.py | 763 ++++++++++++++++++ hydit/configs/README.md | 32 + hydit/configs/base/model/controlnet_canny.py | 8 + hydit/configs/base/model/diffusion_v_pred.py | 13 + hydit/configs/base/model/dit_g2_1024_p.py | 12 + hydit/configs/base/model/lora_r64.py | 8 + hydit/configs/base/schedule/default.py | 60 ++ hydit/configs/base/schedule/inference.py | 28 + hydit/configs/base/schedule/train_full.py | 60 ++ hydit/configs/base/schedule/train_lora.py | 60 ++ .../train/train_full_dit_g2_1024p_multi.py | 59 ++ .../train/train_full_dit_g2_1024p_single.py | 59 ++ ...train_full_dit_g2_1024p_single_no_flash.py | 59 ++ .../train/train_lora_dit_g2_1024p_single.py | 57 ++ hydit/modules/models.py | 2 + hydit/train_deepspeed.py | 5 +- hydit/train_deepspeed.sh | 1 + 18 files changed, 1292 insertions(+), 1 deletion(-) create mode 100644 hydit/config_engine.py create mode 100644 hydit/configs/README.md create mode 100644 hydit/configs/base/model/controlnet_canny.py create mode 100644 hydit/configs/base/model/diffusion_v_pred.py create mode 100644 hydit/configs/base/model/dit_g2_1024_p.py create mode 100644 hydit/configs/base/model/lora_r64.py create mode 100644 hydit/configs/base/schedule/default.py create mode 100644 hydit/configs/base/schedule/inference.py create mode 100644 hydit/configs/base/schedule/train_full.py create mode 100644 hydit/configs/base/schedule/train_lora.py create mode 100644 hydit/configs/train/train_full_dit_g2_1024p_multi.py create mode 100644 hydit/configs/train/train_full_dit_g2_1024p_single.py create mode 100644 hydit/configs/train/train_full_dit_g2_1024p_single_no_flash.py create mode 100644 hydit/configs/train/train_lora_dit_g2_1024p_single.py create mode 100644 hydit/train_deepspeed.sh diff --git a/.gitignore b/.gitignore index 1a80114..c9292b5 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,10 @@ trt/activate.sh trt/deactivate.sh *.onnx ckpts/ + +# Data +dataset/ +*.pkl + +# Log +log_EXP \ No newline at end of file diff --git a/hydit/config_engine.py b/hydit/config_engine.py new file mode 100644 index 0000000..df60892 --- /dev/null +++ b/hydit/config_engine.py @@ -0,0 +1,763 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast +import copy +import os +import os.path as osp +import platform +import shutil +import tempfile +import types +import warnings +from packaging.version import parse + +from collections import OrderedDict +from pathlib import Path +from typing import Any, Optional, Sequence, Tuple, Union + +from addict import Dict + +import argparse +import deepspeed + +from constants import * +from modules.models import HUNYUAN_DIT_CONFIG + +BASE_KEY = '_base_' +DELETE_KEY = '_delete_' +DEPRECATION_KEY = '_deprecation_' +RESERVED_KEYS = [] + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Defaults to 4. + + Returns: + tuple[int]: The version info in digits (integers). 
+ """ + assert 'parrots' not in version_str + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) # type: ignore + else: + release.extend([0, 0]) + return tuple(release) + + +def _configdict2string(cfg_dict, dict_type=None): + if isinstance(cfg_dict, dict): + dict_type = dict_type or type(cfg_dict) + return dict_type( + {k: _configdict2string(v, dict_type) + for k, v in dict.items(cfg_dict)}) + elif isinstance(cfg_dict, (tuple, list)): + return type(cfg_dict)(_configdict2string(v, dict_type) for v in cfg_dict) + else: + return cfg_dict + + +class ConfigDict(Dict): + """A dictionary for config which has the same interface as python's built- + in dictionary and can be used as a normal dictionary. + + The Config class would transform the nested fields (dictionary-like fields) + in config file into ``ConfigDict``. + """ + + def __init__(__self, *args, **kwargs): + object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) + object.__setattr__(__self, '__key', kwargs.pop('__key', None)) + object.__setattr__(__self, '__frozen', False) + for arg in args: + if not arg: + continue + if isinstance(arg, ConfigDict): + for key, val in dict.items(arg): + __self[key] = __self._hook(val) + elif isinstance(arg, dict): + for key, val in arg.items(): + __self[key] = __self._hook(val) + elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): + __self[arg[0]] = __self._hook(arg[1]) + else: + for key, val in iter(arg): + __self[key] = __self._hook(val) + + for key, val in dict.items(kwargs): + __self[key] = __self._hook(val) + + def __missing__(self, name): + raise KeyError(name) + + def __getattr__(self, name): + try: + value = super().__getattr__(name) + except KeyError: + raise AttributeError(f"'{self.__class__.__name__}' object has no " + f"attribute '{name}'") + except Exception as e: + raise e + else: + return value + + @classmethod + def _hook(cls, item): + # avoid to convert user defined dict to ConfigDict. 
+ if type(item) in (dict, OrderedDict): + return cls(item) + elif isinstance(item, (list, tuple)): + return type(item)(cls._hook(elem) for elem in item) + return item + + def __setattr__(self, name, value): + value = self._hook(value) + return super().__setattr__(name, value) + + def __setitem__(self, name, value): + value = self._hook(value) + return super().__setitem__(name, value) + + def __getitem__(self, key): + return super().__getitem__(key) + + def __deepcopy__(self, memo): + other = self.__class__() + memo[id(self)] = other + for key, value in super().items(): + other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) + return other + + def __copy__(self): + other = self.__class__() + for key, value in super().items(): + other[key] = value + return other + + copy = __copy__ + + def __iter__(self): + # Implement `__iter__` to overwrite the unpacking operator `**cfg_dict` + # to get the built object + return iter(self.keys()) + + def get(self, key: str, default: Optional[Any] = None) -> Any: + """Get the value of the key. + + Args: + key (str): The key. + default (any, optional): The default value. Defaults to None. + + Returns: + Any: The value of the key. + """ + return super().get(key, default) + + def pop(self, key, default=None): + """Pop the value of the key. + + Args: + key (str): The key. + default (any, optional): The default value. Defaults to None. + + Returns: + Any: The value of the key. + """ + return super().pop(key, default) + + def update(self, *args, **kwargs) -> None: + other = {} + if args: + if len(args) > 1: + raise TypeError('update only accept one positional argument') + for key, value in dict.items(args[0]): + other[key] = value + + for key, value in dict(kwargs).items(): + other[key] = value + for k, v in other.items(): + if ((k not in self) or (not isinstance(self[k], dict)) + or (not isinstance(v, dict))): + self[k] = self._hook(v) + else: + self[k].update(v) + + def values(self): + """Yield the values of the dictionary. + """ + values = [] + for value in super().values(): + values.append(value) + return values + + def items(self): + """Yield the keys and values of the dictionary. + """ + items = [] + for key, value in super().items(): + items.append((key, value)) + return items + + def merge(self, other: dict): + """Merge another dictionary into current dictionary. + + Args: + other (dict): Another dictionary. + """ + default = object() + + def _merge_a_into_b(a, b): + if isinstance(a, dict): + if not isinstance(b, dict): + a.pop(DELETE_KEY, None) + return a + if a.pop(DELETE_KEY, False): + b.clear() + all_keys = list(b.keys()) + list(a.keys()) + return { + key: + _merge_a_into_b(a.get(key, default), b.get(key, default)) + for key in all_keys if key != DELETE_KEY + } + else: + return a if a is not default else b + + merged = _merge_a_into_b(copy.deepcopy(other), copy.deepcopy(self)) + self.clear() + for key, value in merged.items(): + self[key] = value + + def __reduce_ex__(self, proto): + # Override __reduce_ex__ to avoid `self.items` will be + # called by CPython interpreter during pickling. 
See more details in + # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501 + if digit_version(platform.python_version()) < digit_version('3.8'): + return (self.__class__, ({k: v + for k, v in super().items()}, ), None, + None, None) + else: + return (self.__class__, ({k: v + for k, v in super().items()}, ), None, + None, None, None) + + def __eq__(self, other): + if isinstance(other, ConfigDict): + return other.to_dict() == self.to_dict() + elif isinstance(other, dict): + return {k: v for k, v in self.items()} == other + else: + return False + + def to_dict(self): + """Convert the ConfigDict to a normal dictionary recursively.""" + return _configdict2string(self, dict_type=dict) + + +class RemoveAssignFromAST(ast.NodeTransformer): + """Remove Assign node if the target's name match the key. + + Args: + key (str): The target name of the Assign node. + """ + + def __init__(self, key): + self.key = key + + def visit_Assign(self, node): + if (isinstance(node.targets[0], ast.Name) + and node.targets[0].id == self.key): + return None + else: + return node + + +class ConfigParsingError(RuntimeError): + """Raise error when failed to parse pure Python style config files.""" + + +class Config: + + def __init__(self, + cfg_dict: dict = None, + cfg_text: Optional[str] = None, + filename: Optional[Union[str, Path]] = None): + filename = str(filename) if isinstance(filename, Path) else filename + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError('cfg_dict must be a dict, but ' + f'got {type(cfg_dict)}') + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f'{key} is reserved for config file') + + if not isinstance(cfg_dict, ConfigDict): + cfg_dict = ConfigDict(cfg_dict) + + super().__setattr__('_cfg_dict', cfg_dict) + super().__setattr__('_filename', filename) + + if cfg_text: + text = cfg_text + elif filename: + with open(filename, encoding='utf-8') as f: + text = f.read() + else: + text = '' + super().__setattr__('_text', text) + + @staticmethod + def _validate_py_syntax(filename: str): + """Validate syntax of python config. + + Args: + filename (str): Filename of python config file. + """ + with open(filename, encoding='utf-8') as f: + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError('There are syntax errors in config ' + f'file {filename}: {e}') + + + @staticmethod + def _get_base_files(filename: str) -> list: + """Get the base config file. + + Args: + filename (str): The config file. + + Raises: + TypeError: Name of config file. + + Returns: + list: A list of base config. 
+ """ + file_format = osp.splitext(filename)[1] + if file_format == '.py': + Config._validate_py_syntax(filename) + with open(filename, encoding='utf-8') as f: + parsed_codes = ast.parse(f.read()).body + + def is_base_line(c): + return (isinstance(c, ast.Assign) + and isinstance(c.targets[0], ast.Name) + and c.targets[0].id == BASE_KEY) + + base_code = next((c for c in parsed_codes + if is_base_line(c)), + None) + if base_code is not None: + base_code = ast.Expression( # type: ignore + body=base_code.value) # type: ignore + base_files = eval(compile(base_code, '', mode='eval')) + else: + base_files = [] + else: + raise ConfigParsingError( + 'The config type should be py, ' + 'but got {file_format}') + base_files = base_files if isinstance(base_files, + list) else [base_files] + return base_files + + @staticmethod + def _get_cfg_path(cfg_path: str, + filename: str) -> Tuple[str, Optional[str]]: + """Get the config path from the current or external package. + + Args: + cfg_path (str): Relative path of config. + filename (str): The config file being parsed. + + Returns: + Tuple[str, str or None]: Path and scope of config. If the config + is not an external config, the scope will be `None`. + """ + # Get local config path. + cfg_dir = osp.dirname(filename) + cfg_path = osp.join(cfg_dir, cfg_path) + return cfg_path, None + + @staticmethod + def _dict_to_config_dict(cfg: dict, + scope: Optional[str] = None, + has_scope=True): + """Recursively converts ``dict`` to :obj:`ConfigDict`. + + Args: + cfg (dict): Config dict. + scope (str, optional): Scope of instance. + has_scope (bool): Whether to add `_scope_` key to config dict. + + Returns: + ConfigDict: Converted dict. + """ + # Only the outer dict with key `type` should have the key `_scope_`. + if isinstance(cfg, dict): + if has_scope and 'type' in cfg: + has_scope = False + if scope is not None and cfg.get('_scope_', None) is None: + cfg._scope_ = scope # type: ignore + cfg = ConfigDict(cfg) + dict.__setattr__(cfg, 'scope', scope) + for key, value in cfg.items(): + cfg[key] = Config._dict_to_config_dict( + value, scope=scope, has_scope=has_scope) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) + for _cfg in cfg) + elif isinstance(cfg, list): + cfg = [ + Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) + for _cfg in cfg + ] + return cfg + + + @staticmethod + def _merge_a_into_b(a: dict, + b: dict, + allow_list_keys: bool = False) -> dict: + """merge dict ``a`` into dict ``b`` (non-inplace). + + Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid + in-place modifications. + + Args: + a (dict): The source dict to be merged into ``b``. + b (dict): The origin dict to be fetch keys from ``a``. + allow_list_keys (bool): If True, int string keys (e.g. '0', '1') + are allowed in source ``a`` and will replace the element of the + corresponding index in b if b is a list. Defaults to False. + + Returns: + dict: The modified dict of ``b`` using ``a``. + + Examples: + # Normally merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # Delete b first and merge a into b. + >>> Config._merge_a_into_b( + ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) + {'obj': {'a': 2}} + + # b is a list + >>> Config._merge_a_into_b( + ... 
{'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) + [{'a': 2}, {'b': 2}] + """ + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f'Index {k} exceeds the length of list {b}') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, dict): + if k in b and not v.pop(DELETE_KEY, False): + allowed_types: Union[Tuple, type] = ( + dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f'{k}={v} in child config cannot inherit from ' + f'base because {k} is a dict in the child config ' + f'but is of type {type(b[k])} in base config. ' + f'You may set `{DELETE_KEY}=True` to ignore the ' + f'base config.') + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = ConfigDict(v) + else: + b[k] = v + return b + + @staticmethod + def _file2dict(filename: str) -> Tuple[dict, str, dict]: + """Transform file to variables dictionary. + Args: + filename (str): Name of config file. + Returns: + Tuple[dict, str]: Variables dictionary and text of Config. + """ + + filename = osp.abspath(osp.expanduser(filename)) + if not os.path.exists(filename): + raise FileNotFoundError(f'{filename} is not exist!') + fileExtname = osp.splitext(filename)[1] + if fileExtname not in ['.py']: + raise OSError('Only py type are supported now!') + try: + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname, delete=False) + if platform.system() == 'Windows': + temp_config_file.close() + + shutil.copyfile(filename, temp_config_file.name) + + # Handle base files + base_cfg_dict = ConfigDict() + cfg_text_list = list() + + for base_cfg_path in Config._get_base_files( + temp_config_file.name): + base_cfg_path, scope = Config._get_cfg_path( + base_cfg_path, filename) + # Generate base config + _cfg_dict, _cfg_text = Config._file2dict( + filename=base_cfg_path) + cfg_text_list.append(_cfg_text) + # Check duplicate + duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys() + if len(duplicate_keys) > 0: + raise KeyError( + 'Duplicate key is not allowed among bases. ' + f'Duplicate keys: {duplicate_keys}') + + # _dict_to_config_dict will do the following things: + # 1. Recursively converts ``dict`` to :obj:`ConfigDict`. + # 2. Set `_scope_` for the outer dict variable for the base + # config. + # 3. Set `scope` attribute for each base variable. + # Different from `_scope_`, `scope` is not a key of base + # dict, `scope` attribute will be parsed to key `_scope_` + # by function `_parse_scope` only if the base variable is + # accessed by the current config. + _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope) + base_cfg_dict.update(_cfg_dict) + + if filename.endswith('.py'): + with open(temp_config_file.name, encoding='utf-8') as f: + parsed_codes = ast.parse(f.read()) + parsed_codes = RemoveAssignFromAST(BASE_KEY).visit(parsed_codes) + codeobj = compile(parsed_codes, filename, mode='exec') + # Support load global variable in nested function of the config. 
+ global_locals_var = {BASE_KEY: base_cfg_dict} + ori_keys = set(global_locals_var.keys()) + eval(codeobj, global_locals_var, global_locals_var) + cfg_dict = { + key: value + for key, value in global_locals_var.items() + if (key not in ori_keys and not key.startswith('__')) + } + # close temp file + for key, value in list(cfg_dict.items()): + if isinstance(value, + (types.FunctionType, types.ModuleType)): + cfg_dict.pop(key) + temp_config_file.close() + + except Exception as e: + if osp.exists(temp_config_dir): + shutil.rmtree(temp_config_dir) + raise e + + # check deprecation information + if DEPRECATION_KEY in cfg_dict: + deprecation_info = cfg_dict.pop(DEPRECATION_KEY) + warning_msg = f'The config file {filename} will be deprecated ' \ + 'in the future.' + if 'expected' in deprecation_info: + warning_msg += f' Please use {deprecation_info["expected"]} ' \ + 'instead.' + if 'reference' in deprecation_info: + warning_msg += ' More information can be found at ' \ + f'{deprecation_info["reference"]}' + warnings.warn(warning_msg, DeprecationWarning) + + cfg_text = filename + '\n' + with open(filename, encoding='utf-8') as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + cfg_dict.pop(BASE_KEY, None) + + cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = { + k: v + for k, v in cfg_dict.items() if not k.startswith('__') + } + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = '\n'.join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def fromfile(filename: Union[str, Path]) -> 'Config': + """Build a Config instance from config file. + + Args: + filename (str or Path): Name of config file. + format_python_code (bool): Whether to format Python code by yapf. + Defaults to True. + + Returns: + Config: Config instance built from config file. 
+ """ + filename = str(filename) if isinstance(filename, Path) else filename + cfg_dict, cfg_text = Config._file2dict(filename) + + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + def __repr__(self): + return f'Config (path: {self._filename}): {self._cfg_dict.__repr__()}' + + def __len__(self): + return len(self._cfg_dict) + + def __getattr__(self, name: str) -> Any: + return getattr(self._cfg_dict, name) + + def __getitem__(self, name): + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name, value): + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self): + return iter(self._cfg_dict) + + def __getstate__( + self + ) -> Tuple[dict, Optional[str], Optional[str], dict, bool, set]: + state = (self._cfg_dict, self._filename, self._text) + return state + + def __deepcopy__(self, memo): + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + + for key, value in self.__dict__.items(): + super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) + + return other + + def __copy__(self): + cls = self.__class__ + other = cls.__new__(cls) + other.__dict__.update(self.__dict__) + super(Config, other).__setattr__('_cfg_dict', self._cfg_dict.copy()) + + return other + + copy = __copy__ + + def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str], + dict, bool, set]): + super().__setattr__('_cfg_dict', state[0]) + super().__setattr__('_filename', state[1]) + super().__setattr__('_text', state[2]) + + def to_dict(self): + """ + Convert all data in the config to a builtin ``dict``. + """ + cfg_dict = self._cfg_dict.to_dict() + return cfg_dict + +config_rules = { + 'model': { + 'choices': list(HUNYUAN_DIT_CONFIG.keys()) + }, + 'norm': { + 'choices': ["rms", "layer"] + }, + 'training_parts': { + 'choices': ['all', 'lora'] + }, + 'control_type': { + 'choices': ['canny', 'depth', 'pose'] + }, + 'predict_type': { + 'choices': list(PREDICT_TYPE) + }, + 'noise_schedule': { + 'choices': list(NOISE_SCHEDULES) + }, + 'load_key': { + 'choices': ["ema", "module", "distill", 'merge'] + }, + 'infer_mode': { + 'choices': ["fa", "torch", "trt"] + }, + 'sampler': { + 'choices': SAMPLER_FACTORY + }, + 'lang': { + 'choices': ["zh", "en"] + }, + 'rope_img': { + 'choices': ['extend', 'base512', 'base1024'] + }, + 'ema_dtype': { + 'choices': ['fp16', 'fp32', 'none'] + }, + 'remote_device': { + 'choices': ['none', 'cpu', 'nvme'] + } +} + +def get_args(default_args=None): + + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str) + parser.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher.') + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args(default_args) + + config = Config.fromfile(args.config) + # Check config values + for key, value in config.items(): + if key in config_rules: + rule = config_rules[key] + if 'choices' in rule and value not in rule['choices']: + raise ValueError(f"Invalid value '{value}' for '{key}'. 
Choices are: {rule['choices']}")
+    # Merge parsed command-line arguments into the config
+    for key, value in args.__dict__.items():
+        config[key] = value
+
+    return config
+
+if __name__ == '__main__':
+    print(get_args())
diff --git a/hydit/configs/README.md b/hydit/configs/README.md
new file mode 100644
index 0000000..9bd069d
--- /dev/null
+++ b/hydit/configs/README.md
@@ -0,0 +1,32 @@
+# Optimize the Startup Configuration Process for the Hunyuan DiT Model
+
+## Configuration directory structure
+
+Following the configuration style and code of [MMEngine/config](https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/config.py), the startup configuration of the Hunyuan DiT model is reorganized: parameters are grouped by dataset, model, and launch schedule, and model parameters are defined in Python (.py) files. A new configuration file can simply inherit the default configs.
+
+A new `hydit/configs` directory stores the startup configuration files, laid out as follows:
+```bash
+- configs
+  - base       # default configs
+    - dataset  # dataset configs
+    - model    # model configs
+    - schedule # launch/schedule configs
+  - train      # training configs built on top of the default configs
+```
+
+## Launch procedure
+
+To keep the original code structure when loading configurations, a new config loader `hydit/config_engine.py` is added; in `train_deepspeed.py` only the module that `get_args` is imported from changes:
+
+```python
+# before: from hydit.config import get_args
+from hydit.config_engine import get_args
+```
+
+Since both full-parameter training and LoRA-only training use DeepSpeed, a new `train_deepspeed.sh` script launches training with the following command:
+
+```bash
+PYTHONPATH=./ sh hydit/train_deepspeed.sh --config hydit/configs/train/train_lora_dit_g2_1024p_single.py
+```
+
+Here, the `config` argument is the relative path of the training configuration file.
\ No newline at end of file
diff --git a/hydit/configs/base/model/controlnet_canny.py b/hydit/configs/base/model/controlnet_canny.py
new file mode 100644
index 0000000..31bd97b
--- /dev/null
+++ b/hydit/configs/base/model/controlnet_canny.py
@@ -0,0 +1,8 @@
+# =========================
+# Controlnet Config
+# =========================
+
+control_type = 'canny' # Controlnet condition type, choices=['canny', 'depth', 'pose']
+control_weight = '1.0' # Controlnet weight. You can use a float to specify the weight for all layers, or use a list to separately specify the weight for each layer, for example, '[1.0 * (0.825 ** float(19 - i)) for i in range(19)]'
+# Inference condition image path
+condition_image_path = None
diff --git a/hydit/configs/base/model/diffusion_v_pred.py b/hydit/configs/base/model/diffusion_v_pred.py
new file mode 100644
index 0000000..ccf6075
--- /dev/null
+++ b/hydit/configs/base/model/diffusion_v_pred.py
@@ -0,0 +1,13 @@
+# =========================
+# Diffusion Config
+# =========================
+
+learn_sigma = True # Learn extra channels for sigma.
+predict_type = 'v_prediction' # Diffusion predict type, choices=list(PREDICT_TYPE)
+noise_schedule = 'scaled_linear' # Noise schedule, choices=list(NOISE_SCHEDULES)
+beta_start = 0.00085 # Beta start value
+beta_end = 0.02 # Beta end value
+sigma_small = False
+mse_loss_weight_type = 'constant' # Min-SNR-gamma. Can be 'constant' or 'min_snr_<gamma>' where gamma is an integer; 5 is recommended in the paper.
+model_var_type = None # Specify the model variance type.
+noise_offset = 0.0 # Add extra noise to the input image.
\ No newline at end of file
diff --git a/hydit/configs/base/model/dit_g2_1024_p.py b/hydit/configs/base/model/dit_g2_1024_p.py
new file mode 100644
index 0000000..ede5503
--- /dev/null
+++ b/hydit/configs/base/model/dit_g2_1024_p.py
@@ -0,0 +1,12 @@
+# =========================
+# HunYuan_DiT Config
+# =========================
+
+model = 'DiT-g/2' # choices=list(HUNYUAN_DIT_CONFIG.keys())
+image_size = (1024, 1024) # Image size (h, w)
+qk_norm = True # Query Key normalization. See http://arxiv.org/abs/2302.05442 for details.
+norm = 'layer' # Normalization layer type, choices=["rms", "laryer"] +text_states_dim = 1024 # Hidden size of CLIP text encoder +text_len = 77 # Token length of CLIP text encoder output +text_states_dim_t5 = 2048 # Hidden size of CLIP text encoder +text_len_t5 = 256 # Token length of T5 text encoder output diff --git a/hydit/configs/base/model/lora_r64.py b/hydit/configs/base/model/lora_r64.py new file mode 100644 index 0000000..5025026 --- /dev/null +++ b/hydit/configs/base/model/lora_r64.py @@ -0,0 +1,8 @@ +# ========================= +# LoRA Config +# ========================= + +rank = 64 # Rank of LoRA +lora_ckpt = None # LoRA checkpoint +target_modules = ['Wqkv', 'q_proj', 'kv_proj', 'out_proj'] # Target modules for LoRA fine tune +output_merge_path = None # Output path for merged model \ No newline at end of file diff --git a/hydit/configs/base/schedule/default.py b/hydit/configs/base/schedule/default.py new file mode 100644 index 0000000..caa71b0 --- /dev/null +++ b/hydit/configs/base/schedule/default.py @@ -0,0 +1,60 @@ +task_flag = '' +training_parts = 'all' # Training parts, choices=['all', 'lora'] + +# General Setting +seed = 42 # A seed for all the prompts +batch_size = 1 # Per GPU batch size +use_fp16 = True # Use FP16 precision +extra_fp16 = False # Use extra fp16 for vae and text_encoder +lr = 1e-4 +epochs = 100 +max_training_steps = 10_000_000 +gc_interval = 40 # To address the memory bottleneck encountered during the preprocessing of the dataset, memory fragments are reclaimed here by invoking the gc.collect() function. +log_every = 100 +ckpt_every = 100_000 # Create a ckpt every a few steps. +ckpt_latest_every = 10_000 # Create a ckpt named `latest.pt` every a few steps. +ckpt_every_n_epoch = 0 # Create a ckpt every a few epochs. If 0, do not create ckpt based on epoch. Default is 0. +num_workers = 4 +global_seed = 1234 +warmup_min_lr = 1e-6 +warmup_num_steps = 0 +weight_decay = 0 # Weight_decay in optimizer +rope_img = None # Extend or interpolate the positional embedding of the image, choices=['extend', 'base512', 'base1024'] +rope_real = False # Use real part and imaginary part separately for RoPE. + +# Classifier_free +uncond_p = 0.2 # The probability of dropping training text used for CLIP feature extraction +uncond_p_t5 = 0.2 # The probability of dropping training text used for mT5 feature extraction + +# Directory +results_dir = 'results' +resume = False +resume_module_root = None # Resume model states. +resume_ema_root = None # Resume ema states. +strict = True # Strict loading of checkpoint + +# Additional condition +random_shrink_size_cond = False # Randomly shrink the original size condition. +merge_src_cond = False # Merge the source condition into a single value. + +# EMA Model +use_ema = False # Use EMA model +ema_dtype = 'none' # EMA data type. If none, use the same data type as the model, choices=['fp16', 'fp32', 'none'] +ema_decay = None # EMA decay rate. If None, use the default value of the model. +ema_warmup = False # EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay. +ema_warmup_power = None # EMA power. If None, use the default value of the model. +ema_reset_decay = False # Reset EMA decay to 0 and restart increasing the EMA decay. Only works when ema_warmup is enabled. + +# Acceleration +use_flash_attn = False # During training, flash attention is used to accelerate training. +use_zero_stage = 1 # Use AngelPTM zero stage. Support 2 and 3 +grad_accu_steps = 1 # Gradient accumulation steps. 
+gradient_checkpointing = False # Use gradient checkpointing. +cpu_offloading = False # Use cpu offloading for parameters and optimizer states. +save_optimizer_state = False # Save optimizer state in the checkpoint. + +# DeepSpeed +local_rank = None +deepspeed_optimizer = False # Switching to the optimizers in DeepSpeed +remote_device = 'none' # Remote device for ZeRO_3 initialized parameters, choices=['none', 'cpu', 'nvme'] +zero_stage = 1 \ No newline at end of file diff --git a/hydit/configs/base/schedule/inference.py b/hydit/configs/base/schedule/inference.py new file mode 100644 index 0000000..0d6fac1 --- /dev/null +++ b/hydit/configs/base/schedule/inference.py @@ -0,0 +1,28 @@ +# Basic Setting +prompt = '一只小猫' # The prompt for generating images. +model_root = 'ckpts' # Root path of all the models, including t2i model and dialoggen model. +dit_weight = None # Path to the HunYuan_DiT model. If None, search the model in the args.model_root. 1. If it is a file, load the model directly. In this case, the __load_key is ignored. 2. If it is a directory, search the model in the directory. Support two types of models: 1) named `pytorch_model_*.pt`, where * is specified by the __load_key. 2) named `*_model_states.pt`, where * can be `mp_rank_00`. *_model_states.pt contains both 'module' and 'ema' weights. Therefore, you still use __load_key to specify the weights to load. By default, load 'ema' weights. +controlnet_weight = None # Path to the HunYuan_DiT controlnet model. If None, search the model in the args.model_root. 1. If it is a directory, search the model in the directory. 2. If it is a file, load the model directly. In this case, the __load_key is ignored. + +# Model setting +load_key = 'ema' # Load model key for HunYuanDiT checkpoint, choices=["ema", "module", "distill", 'merge'] +use_style_cond = False # Use style condition in hydit. Only for hydit version <= 1.1" +size_cond = None # Size condition used in sampling. 2 values are required for height and width. If a single value is provided, the image will be treated to (value, value). Recommended values are [1024, 1024]. Only for hydit version <= 1.1 +target_ratios = None # Target ratios for multi_resolution training. +cfg_scale = 6.0 # Guidance scale for classifier_free. +negative = None # Negative prompt. + +# Acceleration +infer_mode = 'fa' # Inference mode, choices=["fa", "torch", "trt"], default="fa" +onnx_workdir = 'onnx_model' # Path to save ONNX model + +# Sampling +sampler = 'ddpm' # Diffusion sampler, choices=SAMPLER_FACTORY +infer_steps = 100 # Inference steps + +# Prompt enhancement +enhance = True # Enhance prompt with mllm. +load_4bit = False # Load DialogGen model with 4bit quantization. + +# App +lang='zh' # Language, choices=["zh", "en"] \ No newline at end of file diff --git a/hydit/configs/base/schedule/train_full.py b/hydit/configs/base/schedule/train_full.py new file mode 100644 index 0000000..caa71b0 --- /dev/null +++ b/hydit/configs/base/schedule/train_full.py @@ -0,0 +1,60 @@ +task_flag = '' +training_parts = 'all' # Training parts, choices=['all', 'lora'] + +# General Setting +seed = 42 # A seed for all the prompts +batch_size = 1 # Per GPU batch size +use_fp16 = True # Use FP16 precision +extra_fp16 = False # Use extra fp16 for vae and text_encoder +lr = 1e-4 +epochs = 100 +max_training_steps = 10_000_000 +gc_interval = 40 # To address the memory bottleneck encountered during the preprocessing of the dataset, memory fragments are reclaimed here by invoking the gc.collect() function. 
+log_every = 100 +ckpt_every = 100_000 # Create a ckpt every a few steps. +ckpt_latest_every = 10_000 # Create a ckpt named `latest.pt` every a few steps. +ckpt_every_n_epoch = 0 # Create a ckpt every a few epochs. If 0, do not create ckpt based on epoch. Default is 0. +num_workers = 4 +global_seed = 1234 +warmup_min_lr = 1e-6 +warmup_num_steps = 0 +weight_decay = 0 # Weight_decay in optimizer +rope_img = None # Extend or interpolate the positional embedding of the image, choices=['extend', 'base512', 'base1024'] +rope_real = False # Use real part and imaginary part separately for RoPE. + +# Classifier_free +uncond_p = 0.2 # The probability of dropping training text used for CLIP feature extraction +uncond_p_t5 = 0.2 # The probability of dropping training text used for mT5 feature extraction + +# Directory +results_dir = 'results' +resume = False +resume_module_root = None # Resume model states. +resume_ema_root = None # Resume ema states. +strict = True # Strict loading of checkpoint + +# Additional condition +random_shrink_size_cond = False # Randomly shrink the original size condition. +merge_src_cond = False # Merge the source condition into a single value. + +# EMA Model +use_ema = False # Use EMA model +ema_dtype = 'none' # EMA data type. If none, use the same data type as the model, choices=['fp16', 'fp32', 'none'] +ema_decay = None # EMA decay rate. If None, use the default value of the model. +ema_warmup = False # EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay. +ema_warmup_power = None # EMA power. If None, use the default value of the model. +ema_reset_decay = False # Reset EMA decay to 0 and restart increasing the EMA decay. Only works when ema_warmup is enabled. + +# Acceleration +use_flash_attn = False # During training, flash attention is used to accelerate training. +use_zero_stage = 1 # Use AngelPTM zero stage. Support 2 and 3 +grad_accu_steps = 1 # Gradient accumulation steps. +gradient_checkpointing = False # Use gradient checkpointing. +cpu_offloading = False # Use cpu offloading for parameters and optimizer states. +save_optimizer_state = False # Save optimizer state in the checkpoint. + +# DeepSpeed +local_rank = None +deepspeed_optimizer = False # Switching to the optimizers in DeepSpeed +remote_device = 'none' # Remote device for ZeRO_3 initialized parameters, choices=['none', 'cpu', 'nvme'] +zero_stage = 1 \ No newline at end of file diff --git a/hydit/configs/base/schedule/train_lora.py b/hydit/configs/base/schedule/train_lora.py new file mode 100644 index 0000000..c079017 --- /dev/null +++ b/hydit/configs/base/schedule/train_lora.py @@ -0,0 +1,60 @@ +task_flag = '' +training_parts = 'lora' # Training parts, choices=['all', 'lora'] + +# General Setting +seed = 42 # A seed for all the prompts +batch_size = 1 # Per GPU batch size +use_fp16 = True # Use FP16 precision +extra_fp16 = False # Use extra fp16 for vae and text_encoder +lr = 1e-4 +epochs = 100 +max_training_steps = 10_000_000 +gc_interval = 40 # To address the memory bottleneck encountered during the preprocessing of the dataset, memory fragments are reclaimed here by invoking the gc.collect() function. +log_every = 100 +ckpt_every = 100_000 # Create a ckpt every a few steps. +ckpt_latest_every = 10_000 # Create a ckpt named `latest.pt` every a few steps. +ckpt_every_n_epoch = 0 # Create a ckpt every a few epochs. If 0, do not create ckpt based on epoch. Default is 0. 
+num_workers = 4 +global_seed = 1234 +warmup_min_lr = 1e-6 +warmup_num_steps = 0 +weight_decay = 0 # Weight_decay in optimizer +rope_img = None # Extend or interpolate the positional embedding of the image, choices=['extend', 'base512', 'base1024'] +rope_real = False # Use real part and imaginary part separately for RoPE. + +# Classifier_free +uncond_p = 0.2 # The probability of dropping training text used for CLIP feature extraction +uncond_p_t5 = 0.2 # The probability of dropping training text used for mT5 feature extraction + +# Directory +results_dir = 'results' +resume = False +resume_module_root = None # Resume model states. +resume_ema_root = None # Resume ema states. +strict = True # Strict loading of checkpoint + +# Additional condition +random_shrink_size_cond = False # Randomly shrink the original size condition. +merge_src_cond = False # Merge the source condition into a single value. + +# EMA Model +use_ema = False # Use EMA model +ema_dtype = 'none' # EMA data type. If none, use the same data type as the model, choices=['fp16', 'fp32', 'none'] +ema_decay = None # EMA decay rate. If None, use the default value of the model. +ema_warmup = False # EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay. +ema_warmup_power = None # EMA power. If None, use the default value of the model. +ema_reset_decay = False # Reset EMA decay to 0 and restart increasing the EMA decay. Only works when ema_warmup is enabled. + +# Acceleration +use_flash_attn = False # During training, flash attention is used to accelerate training. +use_zero_stage = 1 # Use AngelPTM zero stage. Support 2 and 3 +grad_accu_steps = 1 # Gradient accumulation steps. +gradient_checkpointing = False # Use gradient checkpointing. +cpu_offloading = False # Use cpu offloading for parameters and optimizer states. +save_optimizer_state = False # Save optimizer state in the checkpoint. + +# DeepSpeed +local_rank = None +deepspeed_optimizer = False # Switching to the optimizers in DeepSpeed +remote_device = 'none' # Remote device for ZeRO_3 initialized parameters, choices=['none', 'cpu', 'nvme'] +zero_stage = 1 \ No newline at end of file diff --git a/hydit/configs/train/train_full_dit_g2_1024p_multi.py b/hydit/configs/train/train_full_dit_g2_1024p_multi.py new file mode 100644 index 0000000..d8d7141 --- /dev/null +++ b/hydit/configs/train/train_full_dit_g2_1024p_multi.py @@ -0,0 +1,59 @@ +_base_ = [ + '../base/dataset/multi_porcelain.py', + '../base/model/dit_g2_1024_p.py', + '../base/model/lora_r64.py', + '../base/model/controlnet_canny.py', + '../base/model/diffusion_v_pred.py', + '../base/schedule/train_full.py', + '../base/schedule/inference.py' +] + +task_flag = 'dit_g2_full_1024p' + +batch_size = 1 # training batch size +use_fp16 = True +extra_fp16 = True + +lr = 0.0001 # learning rate +epochs = 8 # total training epochs + +log_every = 10 + +ckpt_every = 9999999 # create a ckpt every a few steps. +ckpt_latest_every = 9999999 # create a ckpt named `latest.pt` every a few steps. +ckpt_every_n_epoch=2 # create a ckpt every a few epochs. 
+ +global_seed = 999 + +warmup_num_steps=0 # warm_up steps + +rope_img = 'base512' +rope_real = True + +uncond_p = 0 +uncond_p_t5 = 0 + +results_dir = './log_EXP' # save root for results +resume = True +# checkpoint root for model resume +resume_module_root = './ckpts/t2i/model/pytorch_model_distill.pt' +# checkpoint root for ema resume +resume_ema_root = './ckpts/t2i/model/pytorch_model_ema.pt' + +use_zero_stage = 2 +grad_accu_steps=1 # gradient accumulation +cpu_offloading = True +gradient_checkpointing = True + +# model + +predict_type = 'v_prediction' +noise_schedule = 'scaled_linear' +beta_start = 0.00085 +beta_end = 0.018 + +# dataset + +random_flip = True + + diff --git a/hydit/configs/train/train_full_dit_g2_1024p_single.py b/hydit/configs/train/train_full_dit_g2_1024p_single.py new file mode 100644 index 0000000..074e8a2 --- /dev/null +++ b/hydit/configs/train/train_full_dit_g2_1024p_single.py @@ -0,0 +1,59 @@ +_base_ = [ + '../base/dataset/single_porcelain.py', + '../base/model/dit_g2_1024_p.py', + '../base/model/lora_r64.py', + '../base/model/controlnet_canny.py', + '../base/model/diffusion_v_pred.py', + '../base/schedule/train_full.py', + '../base/schedule/inference.py' +] + +task_flag = 'dit_g2_full_1024p' + +batch_size = 1 # training batch size +use_fp16 = True +extra_fp16 = True + +lr = 0.0001 # learning rate +epochs = 8 # total training epochs + +log_every = 10 + +ckpt_every = 9999999 # create a ckpt every a few steps. +ckpt_latest_every = 9999999 # create a ckpt named `latest.pt` every a few steps. +ckpt_every_n_epoch=2 # create a ckpt every a few epochs. + +global_seed = 999 + +warmup_num_steps=0 # warm_up steps + +rope_img = 'base512' +rope_real = True + +uncond_p = 0 +uncond_p_t5 = 0 + +results_dir = './log_EXP' # save root for results +resume = True +# checkpoint root for model resume +resume_module_root = './ckpts/t2i/model/pytorch_model_distill.pt' +# checkpoint root for ema resume +resume_ema_root = './ckpts/t2i/model/pytorch_model_ema.pt' + +use_zero_stage = 2 +grad_accu_steps=1 # gradient accumulation +cpu_offloading = True +gradient_checkpointing = True + +# model + +predict_type = 'v_prediction' +noise_schedule = 'scaled_linear' +beta_start = 0.00085 +beta_end = 0.018 + +# dataset + +random_flip = True + + diff --git a/hydit/configs/train/train_full_dit_g2_1024p_single_no_flash.py b/hydit/configs/train/train_full_dit_g2_1024p_single_no_flash.py new file mode 100644 index 0000000..9813283 --- /dev/null +++ b/hydit/configs/train/train_full_dit_g2_1024p_single_no_flash.py @@ -0,0 +1,59 @@ +_base_ = [ + '../base/dataset/single_porcelain.py', + '../base/model/dit_g2_1024_p.py', + '../base/model/lora_r64.py', + '../base/model/controlnet_canny.py', + '../base/model/diffusion_v_pred.py', + '../base/schedule/train_full.py', + '../base/schedule/inference.py' +] + +task_flag = 'dit_g2_full_1024p' + +batch_size = 1 # training batch size +use_fp16 = True +extra_fp16 = True + +lr = 0.0001 # learning rate +epochs = 8 # total training epochs + +log_every = 10 + +ckpt_every = 9999999 # create a ckpt every a few steps. +ckpt_latest_every = 9999999 # create a ckpt named `latest.pt` every a few steps. +ckpt_every_n_epoch=2 # create a ckpt every a few epochs. 
+ +global_seed = 999 + +warmup_num_steps=0 # warm_up steps + +rope_img = 'base512' +rope_real = True + +uncond_p = 0 +uncond_p_t5 = 0 + +results_dir = './log_EXP' # save root for results +resume = True +# checkpoint root for model resume +resume_module_root = './ckpts/t2i/model/pytorch_model_module.pt' +# checkpoint root for ema resume +resume_ema_root = './ckpts/t2i/model/pytorch_model_ema.pt' + +use_zero_stage = 2 +grad_accu_steps=1 # gradient accumulation +cpu_offloading = True +gradient_checkpointing = True + +# model + +predict_type = 'v_prediction' +noise_schedule = 'scaled_linear' +beta_start = 0.00085 +beta_end = 0.018 + +# dataset + +random_flip = True + + diff --git a/hydit/configs/train/train_lora_dit_g2_1024p_single.py b/hydit/configs/train/train_lora_dit_g2_1024p_single.py new file mode 100644 index 0000000..c08862a --- /dev/null +++ b/hydit/configs/train/train_lora_dit_g2_1024p_single.py @@ -0,0 +1,57 @@ +_base_ = [ + '../base/dataset/single_porcelain.py', + '../base/model/dit_g2_1024_p.py', + '../base/model/lora_r64.py', + '../base/model/controlnet_canny.py', + '../base/model/diffusion_v_pred.py', + '../base/schedule/train_lora.py', + '../base/schedule/inference.py' +] + +task_flag = 'lora_porcelain_ema_rank64' + +batch_size = 1 # training batch size +use_fp16 = True +extra_fp16 = True + +lr = 0.0001 # learning rate +epochs = 8 # total training epochs +max_training_steps=2000 # Maximum training iteration steps +log_every = 10 + +ckpt_every = 100 # create a ckpt every a few steps. +ckpt_latest_every = 2000 # create a ckpt named `latest.pt` every a few steps. +ckpt_every_n_epoch=2 # create a ckpt every a few epochs. + +global_seed = 999 + +warmup_num_steps=0 # warm_up steps + +rope_img = 'base512' +rope_real = True + +uncond_p = 0 +uncond_p_t5 = 0 + +results_dir = './log_EXP' # save root for results +resume = True +# checkpoint root for model resume +resume_module_root = './ckpts/t2i/model/pytorch_model_distill.pt' +# checkpoint root for ema resume +resume_ema_root = './ckpts/t2i/model/pytorch_model_ema.pt' + +use_zero_stage = 2 +grad_accu_steps=2 # gradient accumulation +cpu_offloading = True +gradient_checkpointing = True + +# model + +predict_type = 'v_prediction' +noise_schedule = 'scaled_linear' +beta_start = 0.00085 +beta_end = 0.018 + +# dataset + +random_flip = True diff --git a/hydit/modules/models.py b/hydit/modules/models.py index 10b23e1..304f021 100644 --- a/hydit/modules/models.py +++ b/hydit/modules/models.py @@ -212,6 +212,8 @@ def __init__(self, use_flash_attn = args.infer_mode == 'fa' or args.use_flash_attn if use_flash_attn: log_fn(f" Enable Flash Attention.") + else: + log_fn(f" Disable Flash Attention.") qk_norm = args.qk_norm # See http://arxiv.org/abs/2302.05442 for details. 
self.mlp_t5 = nn.Sequential( diff --git a/hydit/train_deepspeed.py b/hydit/train_deepspeed.py index 91736c6..d7c5d01 100644 --- a/hydit/train_deepspeed.py +++ b/hydit/train_deepspeed.py @@ -18,7 +18,8 @@ from IndexKits.index_kits import ResolutionGroup from IndexKits.index_kits.sampler import DistributedSamplerWithStartIndex, BlockDistributedSampler -from hydit.config import get_args +# from hydit.config import get_args +from hydit.config_engine import get_args from hydit.constants import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER from hydit.data_loader.arrow_load_stream import TextImageArrowStream from hydit.diffusion import create_diffusion @@ -31,6 +32,7 @@ from hydit.modules.posemb_layers import init_image_posemb from hydit.utils.tools import create_exp_folder, model_resume, get_trainable_params +from apex.normalization import FusedRMSNorm def deepspeed_initialize(args, logger, model, opt, deepspeed_config): logger.info(f"Initialize deepspeed...") @@ -193,6 +195,7 @@ def main(args): torch.cuda.set_device(device) print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.") deepspeed_config = deepspeed_config_from_args(args, global_batch_size) + print(deepspeed_config) # Setup an experiment folder experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank) diff --git a/hydit/train_deepspeed.sh b/hydit/train_deepspeed.sh new file mode 100644 index 0000000..ae2120c --- /dev/null +++ b/hydit/train_deepspeed.sh @@ -0,0 +1 @@ +deepspeed hydit/train_deepspeed.py --deepspeed "$@" \ No newline at end of file
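
For reference, a minimal usage sketch (not part of the patch) of the new config loader, assuming the repository root and `hydit/` are both on `PYTHONPATH` so that `config_engine`'s own imports (`constants`, `modules.models`) resolve:

```python
# Hypothetical quick check of the merged configuration (run from the repo root).
from hydit.config_engine import Config

# Loading a training config resolves its `_base_` list and merges the base
# config dicts with the overrides defined in the child file.
cfg = Config.fromfile("hydit/configs/train/train_lora_dit_g2_1024p_single.py")

print(cfg.task_flag)       # 'lora_porcelain_ema_rank64' (set in the train config)
print(cfg.rank)            # 64 (inherited from base/model/lora_r64.py)
print(cfg.noise_schedule)  # 'scaled_linear' (inherited from base/model/diffusion_v_pred.py)
```

`get_args()` additionally validates selected keys against `config_rules` and merges the DeepSpeed command-line arguments into the same `Config` object before training starts.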