diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 7de4bae..417dc97 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -740,6 +740,7 @@ def train( weight_decay_ti: float = 0.00, weight_decay_lora: float = 0.001, use_8bit_adam: bool = False, + use_autocast_train_inversion: bool = False, device="cuda:0", extra_args: Optional[dict] = None, log_wandb: bool = False, @@ -924,7 +925,7 @@ def train( wandb_log_prompt_cnt=wandb_log_prompt_cnt, class_token=class_token, train_inpainting=train_inpainting, - mixed_precision=False, + mixed_precision=use_autocast_train_inversion, tokenizer=tokenizer, clip_ti_decay=clip_ti_decay, )