fix(dist): fix distributed & AMP training
caopulan committed Jun 23, 2023
1 parent 9007fd1 commit c2cbbee
Showing 4 changed files with 10 additions and 10 deletions.
configs/common/common.py (4 changes: 2 additions & 2 deletions)
@@ -32,7 +32,7 @@

 inference = {
     'inference_iter': 5000,
-    'batch_size': 1, # not used
+    # 'batch_size': 1, # not used
     'prompts': None, # using dataset prompt if None
     'total_num': 10,
     'scheduler': 'DPMSolverMultistepScheduler',
@@ -43,7 +43,7 @@
 evaluation = {
     'evaluation_iter': 10000,
     'total_num': 1000, # synthesis images num
-    'batch_size': 4,
+    # 'batch_size': 1, # not used
     'prompts': None, # using dataset prompt if None
     'scheduler': 'DPMSolverMultistepScheduler',
     'num_inference_steps': 25,
configs/train/lora.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
 dataset = get_config("common/data/huggingface_dataset.py").dataset

 train.output_dir = 'experiments/pokemon/lora'
-dataset.path = "lambdalabs/pokemon-blip-captions",
+dataset.path = "lambdalabs/pokemon-blip-captions"
 train.pretrained_model_name_or_path = 'runwayml/stable-diffusion-v1-5'

 unet.training_args = {
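Note on the configs/train/lora.py hunk: the trailing comma on the removed line is a Python gotcha, since it turns the assignment into a one-element tuple rather than a string, which code expecting a plain path would not handle. A minimal illustration (dataset_path here is just a stand-in variable):

dataset_path = "lambdalabs/pokemon-blip-captions",   # trailing comma -> ('lambdalabs/pokemon-blip-captions',)
print(type(dataset_path))                            # <class 'tuple'>

dataset_path = "lambdalabs/pokemon-blip-captions"    # no comma -> plain string
print(type(dataset_path))                            # <class 'str'>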
unidiffusion/peft/lora.py (7 changes: 2 additions & 5 deletions)
@@ -67,13 +67,10 @@ def __init__(self, org_module, org_name, rank=4, scale=1.0):
         self.apply_to()

     def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
+        down_hidden_states = self.down(hidden_states)
         up_hidden_states = self.up(down_hidden_states)

-        return up_hidden_states.to(orig_dtype) * self.scale + self.org_forward(hidden_states)
+        return up_hidden_states * self.scale + self.org_forward(hidden_states)


 class LoRAConvLayer(BaseLoRAModule):
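The unidiffusion/peft/lora.py hunk drops the manual dtype round-trip from the LoRA forward pass, presumably leaving dtype handling to the mixed-precision autocast context instead of casting activations by hand. A self-contained sketch of that pattern, assuming a simple down/up projection pair (TinyLoRA below is illustrative, not the repository's class):

import torch
import torch.nn as nn

class TinyLoRA(nn.Module):
    """Illustrative LoRA-style adapter; not the repository's implementation."""
    def __init__(self, features=16, rank=4, scale=1.0):
        super().__init__()
        self.down = nn.Linear(features, rank, bias=False)
        self.up = nn.Linear(rank, features, bias=False)
        self.scale = scale

    def forward(self, hidden_states, org_forward):
        # No explicit .to(dtype) casts: under torch.autocast the matmuls run in the
        # autocast dtype, and mixed dtypes are reconciled by normal type promotion.
        return self.up(self.down(hidden_states)) * self.scale + org_forward(hidden_states)

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = TinyLoRA().to(device)
x = torch.randn(2, 16, device=device)
with torch.autocast(device_type=device, dtype=torch.bfloat16, enabled=(device == "cuda")):
    out = layer(x, org_forward=lambda h: h)  # identity stands in for the wrapped module
print(out.dtype)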
unidiffusion/pipelines/unidiffusion_pipeline.py (7 changes: 5 additions & 2 deletions)
@@ -65,7 +65,8 @@ def __init__(self, cfg, training):
     def default_setup(self):
         # setup log tracker and accelerator
         log_tracker = [platform for platform in ['wandb', 'tensorboard', 'comet_ml'] if self.cfg.train[platform]['enabled']]
-        self.cfg.accelerator.log_with = log_tracker[0] # todo: support multiple loggers
+        if len(log_tracker) >= 1:
+            self.cfg.accelerator.log_with = log_tracker[0] # todo: support multiple loggers
         self.accelerator = instantiate(self.cfg.accelerator)

         if self.accelerator.is_main_process:
@@ -169,6 +170,7 @@ def build_optimizer(self):
         self.logger.info("Building optimizer ... ")
         self.cfg.optimizer.params = self.proxy_model.params_group
         self.optimizer = instantiate(OmegaConf.to_container(self.cfg.optimizer), convert=False) # not convert list to ListConfig
+        self.optimizer = self.accelerator.prepare(self.optimizer)

         # print num of trainable parameters
         num_params = sum([p.numel() for params_group in self.optimizer.param_groups for p in params_group['params']])
@@ -181,6 +183,7 @@ def build_scheduler(self):
         self.cfg.lr_scheduler.num_training_steps = self.cfg.train.max_iter * self.cfg.train.gradient_accumulation_iter

         self.lr_scheduler = instantiate(self.cfg.lr_scheduler)
+        self.lr_scheduler = self.accelerator.prepare(self.lr_scheduler)

     def build_evaluator(self):
         self.logger.info("Building evaluator ... ")
@@ -290,7 +293,7 @@ def train(self):
         optimizer, lr_scheduler = self.optimizer, self.lr_scheduler
         while self.current_iter < self.cfg.train.max_iter:
             batch = next(iter(self.dataloader))
-            with accelerator.accumulate(unet):
+            with accelerator.accumulate(self.proxy_model):
                 # ------------------------------------------------------------
                 # 1. Inference
                 # ------------------------------------------------------------
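Finally, accelerator.accumulate now receives self.proxy_model rather than a bare unet reference; accumulate is Accelerate's context manager for gradient accumulation and is typically given the module that went through prepare. A runnable toy version of that loop (the tiny model, data, and step count are placeholders, not the repository's train() loop):

import torch
from accelerate import Accelerator

# Toy gradient-accumulation loop standing in for the real training step.
accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
model, optimizer = accelerator.prepare(model, optimizer)

for _ in range(8):
    x = torch.randn(2, 8, device=accelerator.device)
    y = torch.randn(2, 1, device=accelerator.device)
    with accelerator.accumulate(model):  # pass the prepared module, not an unwrapped one
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()      # Accelerate defers the real step until gradients sync
        optimizer.zero_grad()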
