Fix device issue in load_file, reduce vram usage
kohya-ss committed Mar 31, 2023
1 parent ea1cf4a commit 8cecc67
Showing 3 changed files with 21 additions and 11 deletions.
5 changes: 5 additions & 0 deletions README.md
```diff
@@ -127,6 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+- 31 Mar. 2023, 2023/3/31:
+  - Fixed an issue where VRAM usage temporarily increases when loading a model in `train_network.py`.
+  - Fixed an issue where an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
 - 30 Mar. 2023, 2023/3/30:
   - Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
   - See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
```
2 changes: 1 addition & 1 deletion library/model_util.py
```diff
@@ -841,7 +841,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device):
 
     if is_safetensors(ckpt_path):
         checkpoint = None
-        state_dict = load_file(ckpt_path, device)
+        state_dict = load_file(ckpt_path)  # , device)  # may cause an error
     else:
         checkpoint = torch.load(ckpt_path, map_location=device)
         if "state_dict" in checkpoint:
```
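The commented-out `device` argument relates to [#354](https://github.com/kohya-ss/sd-scripts/issues/354). Below is a hedged sketch of an equivalent workaround, assuming the failure mode is that `safetensors.torch.load_file` expects a device *string* such as `"cpu"` or `"cuda:0"` and can raise in some versions when handed another device representation; the helper name is hypothetical, not this repo's API.

```python
# Hedged sketch, not this repo's code: load the state dict on CPU, then move it.
import torch
from safetensors.torch import load_file

def load_safetensors_state_dict(ckpt_path, device):
    state_dict = load_file(ckpt_path)  # load_file defaults to device="cpu"
    if device is not None and str(device) != "cpu":
        # move tensor by tensor so the GPU never holds an extra full copy
        state_dict = {k: v.to(device) for k, v in state_dict.items()}
    return state_dict
```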
25 changes: 15 additions & 10 deletions train_network.py
```diff
@@ -25,7 +25,7 @@
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import apply_snr_weight
 
 
 # TODO: unify this with the other scripts
```
```diff
@@ -131,16 +131,21 @@ def train(args):
         # TODO: modify other training scripts as well
         if pi == accelerator.state.local_process_index:
             print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
-            text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device)
+
+            text_encoder, vae, unet, _ = train_util.load_target_model(
+                args, weight_dtype, accelerator.device if args.lowram else "cpu"
+            )
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
 
             gc.collect()
             torch.cuda.empty_cache()
         accelerator.wait_for_everyone()
 
     # work on low-ram device
+    # NOTE: this may not be necessary because we already load them on gpu
     if args.lowram:
         text_encoder.to(accelerator.device)
         unet.to(accelerator.device)
 
     # integrate xformers or memory-efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
```
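The change above matters because loading a checkpoint directly onto the GPU briefly keeps both the raw state dict and the instantiated model in VRAM. A minimal sketch of the general pattern, using only standard PyTorch APIs; this is an illustration, not the repo's `load_target_model`:

```python
# Hedged sketch of the low-VRAM pattern: stage weights in CPU RAM, keep one GPU copy.
import torch

def load_model_low_vram(ckpt_path, model, device):
    state_dict = torch.load(ckpt_path, map_location="cpu")  # stage on CPU
    model.load_state_dict(state_dict)  # weights copied into the CPU-resident model
    del state_dict  # drop the staged copy before touching the GPU
    model.to(device)  # a single copy lands in VRAM
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return freed blocks to the driver
    return model
```

Peak VRAM then stays near the size of the model itself instead of checkpoint plus model.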
```diff
@@ -197,7 +202,7 @@ def train(args):
     # prepare the dataloader
     # number of DataLoader worker processes: 0 means data loading runs in the main process
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified value
 
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
```
```diff
@@ -564,9 +569,9 @@ def train(args):
 
                 loss_weights = batch["loss_weights"]  # weight for each sample
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
 
                 loss = loss.mean()  # already the mean, so no need to divide by batch_size
```
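For context on the `min_snr_gamma` branch above: `apply_snr_weight` implements Min-SNR loss weighting (Hang et al., arXiv:2303.09556), which clamps each timestep's signal-to-noise ratio at γ so that low-noise timesteps don't dominate the loss. A hedged sketch of the idea, assuming a DDPM-style scheduler exposing `alphas_cumprod`; the repo's actual implementation lives in `library/custom_train_functions.py` and may differ:

```python
# Hedged sketch of Min-SNR weighting: per-sample weight = min(SNR_t, gamma) / SNR_t.
import torch

def min_snr_weight_sketch(loss, timesteps, noise_scheduler, gamma):
    # SNR of DDPM timestep t is alpha_bar_t / (1 - alpha_bar_t)
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(loss.device)[timesteps]
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    weight = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
    # broadcast the per-sample weight over the remaining loss dimensions
    return loss * weight.view(-1, *([1] * (loss.dim() - 1)))
```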
