forked from bghira/SimpleTuner
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
87 lines (73 loc) · 2.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import logging
# Quiet down, you.
# Silence the noisiest third-party loggers before anything else imports them.
ds_logger1 = logging.getLogger("DeepSpeed")
ds_logger2 = logging.getLogger("torch.distributed.elastic.multiprocessing.redirects")
ds_logger1.setLevel("ERROR")
ds_logger2.setLevel("ERROR")
import logging.config
# Disable every logger configured so far; loggers created after this call
# (like the "SimpleTuner" logger below) are unaffected.
logging.config.dictConfig(
    {
        "version": 1,
        "disable_existing_loggers": True,
    }
)
from os import environ

# NOTE(review): must be set before the helpers imports below — presumably they
# import accelerate, which reads this at import time. Confirm before reordering.
environ["ACCELERATE_LOG_LEVEL"] = "WARNING"
from helpers.training.trainer import Trainer
from helpers.training.state_tracker import StateTracker
from helpers import log_format  # imported for its side effect of configuring log formatting

# Project logger; verbosity is runtime-tunable via SIMPLETUNER_LOG_LEVEL.
logger = logging.getLogger("SimpleTuner")
logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
if __name__ == "__main__":
    trainer = None
    exit_code = 0
    try:
        import multiprocessing

        # "fork" shares the parent's memory with worker processes; the default
        # start method on some platforms ("spawn") re-imports everything and
        # can dramatically increase memory use. Failure here is non-fatal.
        multiprocessing.set_start_method("fork")
    except Exception as e:
        logger.error(
            "Failed to set the multiprocessing start method to 'fork'. Unexpected behaviour such as high memory overhead or poor performance may result."
            f"\nError: {e}"
        )
    try:
        trainer = Trainer(
            exit_on_error=True,
        )
        # The init_* sequence below is order-sensitive (e.g. preprocessing
        # models are quantised before the data backend is built, and EMA is
        # quantised only after the base model is). Do not reorder casually.
        trainer.configure_webhook()
        trainer.init_noise_schedule()
        trainer.init_seed()
        trainer.init_huggingface_hub()
        trainer.init_preprocessing_models()
        trainer.init_precision(preprocessing_models_only=True)
        trainer.init_data_backend()
        trainer.init_validation_prompts()
        trainer.init_unload_text_encoder()
        trainer.init_unload_vae()
        trainer.init_load_base_model()
        trainer.init_controlnet_model()
        trainer.init_precision()
        trainer.init_freeze_models()
        trainer.init_trainable_peft_adapter()
        trainer.init_ema_model()
        # EMA must be quantised if the base model is as well.
        trainer.init_precision(ema_only=True)
        trainer.move_models(destination="accelerator")
        trainer.init_validations()
        trainer.init_benchmark_base_model()
        trainer.resume_and_prepare()
        trainer.init_trackers()
        trainer.train()
    except KeyboardInterrupt:
        # Best-effort notification; an interrupt is not treated as a failure.
        if StateTracker.get_webhook_handler() is not None:
            StateTracker.get_webhook_handler().send(
                message="Training has been interrupted by user action (lost terminal, or ctrl+C)."
            )
    except Exception as e:
        import traceback

        if StateTracker.get_webhook_handler() is not None:
            StateTracker.get_webhook_handler().send(
                message=f"Training has failed. Please check the logs for more information: {e}"
            )
        # Log via the project logger (with traceback) instead of bare print(),
        # and remember to exit non-zero: the original exited with status 0 on
        # a crashed run, which hides failures from launchers/orchestrators.
        logger.error("%s\n%s", e, traceback.format_exc())
        exit_code = 1
    finally:
        # Always stop the background batch fetcher, even on interrupt or
        # crash. getattr() guards against a Trainer that failed part-way
        # through __init__ and never created the `bf` attribute.
        if trainer is not None and getattr(trainer, "bf", None) is not None:
            trainer.bf.stop_fetching()
    if exit_code:
        import sys

        sys.exit(exit_code)