From c2cbbeefd779e6ae75f9e7c43d4546c53f8c355b Mon Sep 17 00:00:00 2001
From: sariel <2788787973@qq.com>
Date: Sat, 24 Jun 2023 01:47:35 +0800
Subject: [PATCH] fix(dist): fix distributed & amp training

---
 configs/common/common.py                        | 4 ++--
 configs/train/lora.py                           | 2 +-
 unidiffusion/peft/lora.py                       | 7 ++-----
 unidiffusion/pipelines/unidiffusion_pipeline.py | 7 +++++--
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/configs/common/common.py b/configs/common/common.py
index 48f33f2..5a311bb 100644
--- a/configs/common/common.py
+++ b/configs/common/common.py
@@ -32,7 +32,7 @@
 
 inference = {
     'inference_iter': 5000,
-    'batch_size': 1,  # not used
+    # 'batch_size': 1,  # not used
     'prompts': None,  # using dataset prompt if None
     'total_num': 10,
     'scheduler': 'DPMSolverMultistepScheduler',
@@ -43,7 +43,7 @@
 evaluation = {
     'evaluation_iter': 10000,
     'total_num': 1000,  # synthesis images num
-    'batch_size': 4,
+    # 'batch_size': 1,  # not used
     'prompts': None,  # using dataset prompt if None
     'scheduler': 'DPMSolverMultistepScheduler',
     'num_inference_steps': 25,
diff --git a/configs/train/lora.py b/configs/train/lora.py
index 34d7742..8c43b76 100644
--- a/configs/train/lora.py
+++ b/configs/train/lora.py
@@ -4,7 +4,7 @@
 dataset = get_config("common/data/huggingface_dataset.py").dataset
 
 train.output_dir = 'experiments/pokemon/lora'
-dataset.path = "lambdalabs/pokemon-blip-captions",
+dataset.path = "lambdalabs/pokemon-blip-captions"
 train.pretrained_model_name_or_path = 'runwayml/stable-diffusion-v1-5'
 
 unet.training_args = {
diff --git a/unidiffusion/peft/lora.py b/unidiffusion/peft/lora.py
index 391ede0..4b2e524 100644
--- a/unidiffusion/peft/lora.py
+++ b/unidiffusion/peft/lora.py
@@ -67,13 +67,10 @@ def __init__(self, org_module, org_name, rank=4, scale=1.0):
         self.apply_to()
 
     def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
+        down_hidden_states = self.down(hidden_states)
         up_hidden_states = self.up(down_hidden_states)
 
-        return up_hidden_states.to(orig_dtype) * self.scale + self.org_forward(hidden_states)
+        return up_hidden_states * self.scale + self.org_forward(hidden_states)
 
 
 class LoRAConvLayer(BaseLoRAModule):
diff --git a/unidiffusion/pipelines/unidiffusion_pipeline.py b/unidiffusion/pipelines/unidiffusion_pipeline.py
index d3cf911..d947c3a 100644
--- a/unidiffusion/pipelines/unidiffusion_pipeline.py
+++ b/unidiffusion/pipelines/unidiffusion_pipeline.py
@@ -65,7 +65,8 @@ def __init__(self, cfg, training):
     def default_setup(self):
         # setup log tracker and accelerator
         log_tracker = [platform for platform in ['wandb', 'tensorboard', 'comet_ml'] if self.cfg.train[platform]['enabled']]
-        self.cfg.accelerator.log_with = log_tracker[0]  # todo: support multiple loggers
+        if len(log_tracker) >= 1:
+            self.cfg.accelerator.log_with = log_tracker[0]  # todo: support multiple loggers
         self.accelerator = instantiate(self.cfg.accelerator)
 
         if self.accelerator.is_main_process:
@@ -169,6 +170,7 @@ def build_optimizer(self):
         self.logger.info("Building optimizer ... ")
         self.cfg.optimizer.params = self.proxy_model.params_group
         self.optimizer = instantiate(OmegaConf.to_container(self.cfg.optimizer), convert=False)  # not convert list to ListConfig
+        self.optimizer = self.accelerator.prepare(self.optimizer)
 
         # print num of trainable parameters
         num_params = sum([p.numel() for params_group in self.optimizer.param_groups for p in params_group['params']])
@@ -181,6 +183,7 @@ def build_scheduler(self):
         self.cfg.lr_scheduler.num_training_steps = self.cfg.train.max_iter * self.cfg.train.gradient_accumulation_iter
 
         self.lr_scheduler = instantiate(self.cfg.lr_scheduler)
+        self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)
 
     def build_evaluator(self):
         self.logger.info("Building evaluator ... ")
@@ -290,7 +293,7 @@ def train(self):
         optimizer, lr_scheduler = self.optimizer, self.lr_scheduler
         while self.current_iter < self.cfg.train.max_iter:
             batch = next(iter(self.dataloader))
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(self.proxy_model):
                 # ------------------------------------------------------------
                 # 1. Inference
                 # ------------------------------------------------------------