Skip to content

Commit

Permalink
fix(track): update config
Browse files Browse the repository at this point in the history
  • Loading branch information
caopulan committed Jun 24, 2023
1 parent f9a95a7 commit 55c2cf3
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion unidiffusion/pipelines/unidiffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class UniDiffusionPipeline:
models = None
current_iter = 0
evaluators = []
config = None

def __init__(self, cfg, training):
self.cfg = cfg
Expand All @@ -67,6 +68,7 @@ def default_setup(self):
log_tracker = [platform for platform in ['wandb', 'tensorboard', 'comet_ml'] if self.cfg.train[platform]['enabled']]
if len(log_tracker) >= 1:
self.cfg.accelerator.log_with = log_tracker[0] # todo: support multiple loggers
self.config = OmegaConf.to_container(self.cfg, resolve=True)
self.accelerator = instantiate(self.cfg.accelerator)

if self.accelerator.is_main_process:
Expand Down Expand Up @@ -247,7 +249,7 @@ def prepare_training(self):
os.path.split(output_dir)[-2],
}
init_kwargs['wandb'] = wandb_kwargs
self.accelerator.init_trackers(self.cfg.train.project, config=vars(self.cfg), init_kwargs=init_kwargs)
self.accelerator.init_trackers(self.cfg.train.project, config=self.config, init_kwargs=init_kwargs)

def prepare_inference(self):
self.proxy_model.set_requires_grad(False)
Expand Down

0 comments on commit 55c2cf3

Please sign in to comment.