-
Notifications
You must be signed in to change notification settings - Fork 106
/
Copy pathia_threading.py
186 lines (142 loc) · 5.36 KB
/
ia_threading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import gc
import inspect
import threading
from contextlib import ContextDecorator
from functools import wraps
import torch
from modules import devices, safe, shared
from modules.sd_models import load_model, reload_model_weights
# Module-level state shared by the offload/reload helpers in this file:
#   backup_sd_model / backup_device -- model moved to CPU by pre_offload_model_weights()
#                                      and the device it was on before the move.
#   backup_ckpt_info                -- checkpoint info saved by backup_reload_ckpt_info()
#                                      so the previous checkpoint can be reloaded later.
backup_sd_model, backup_device, backup_ckpt_info = None, None, None
# Single-permit semaphore; the helpers that read or write the state above do so inside `with sem:`.
model_access_sem = threading.Semaphore(1)
def clear_cache():
    """Run a Python garbage-collection pass, then call ``devices.torch_gc()``."""
    # Collect Python garbage first, then call the torch-side cleanup.
    gc.collect()
    devices.torch_gc()
def is_sdxl_lowvram(sd_model):
    """Return a truthy value when a low/medium-VRAM command-line option applies.

    The result is truthy when ``--lowvram`` or ``--medvram`` is set, or when the
    optional ``medvram_sdxl`` option is set and ``sd_model`` has a ``conditioner``
    attribute. The value is the raw operand of the boolean expression, not
    necessarily a ``bool``.
    """
    opts = shared.cmd_opts
    low_memory_mode = opts.lowvram or opts.medvram
    if low_memory_mode:
        return low_memory_mode
    # `medvram_sdxl` may be missing from cmd_opts, hence the getattr default.
    return getattr(opts, "medvram_sdxl", False) and hasattr(sd_model, "conditioner")
def webui_reload_model_weights(sd_model=None, info=None):
    """Reload model weights, falling back to a full model load on any error.

    Args:
        sd_model: Model passed through to ``reload_model_weights``.
        info: Checkpoint info passed to ``reload_model_weights`` and, on
            failure, to ``load_model`` as ``checkpoint_info``.
    """
    try:
        reload_model_weights(sd_model=sd_model, info=info)
    except Exception:
        # Best-effort fallback: if reloading the weights fails, load the
        # checkpoint with load_model() instead.
        load_model(checkpoint_info=info)
def pre_offload_model_weights(sem):
    """Move ``shared.sd_model`` to the CPU and remember where it came from.

    While holding ``sem``: if a model is loaded, no low/medium-VRAM option
    applies to it, and it is not already on the CPU, the model and its device
    are stored in the module-level backup variables and the model is moved to
    ``devices.cpu``. ``clear_cache()`` is called afterwards in every case.

    Args:
        sem: Context-manager lock guarding the shared model state.
    """
    global backup_sd_model, backup_device
    with sem:
        should_offload = (
            shared.sd_model is not None
            and not is_sdxl_lowvram(shared.sd_model)
            and getattr(shared.sd_model, "device", devices.cpu) != devices.cpu
        )
        if should_offload:
            backup_sd_model = shared.sd_model
            backup_device = backup_sd_model.device
            backup_sd_model.to(devices.cpu)
        clear_cache()
def await_pre_offload_model_weights():
    """Run ``pre_offload_model_weights`` on a worker thread and wait for it to finish."""
    worker = threading.Thread(
        target=pre_offload_model_weights,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
def pre_reload_model_weights(sem):
    """Undo any pending offload and checkpoint swap while holding ``sem``.

    If a model was backed up together with its device, it is moved back to that
    device and the backup is cleared. If a model is loaded and checkpoint info
    was backed up, that checkpoint is reloaded and the backup is cleared.

    Args:
        sem: Context-manager lock guarding the shared model state.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        model_was_offloaded = backup_sd_model is not None and backup_device is not None
        if model_was_offloaded:
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None

        if not (shared.sd_model is None or backup_ckpt_info is None):
            webui_reload_model_weights(sd_model=shared.sd_model, info=backup_ckpt_info)
            backup_ckpt_info = None
def await_pre_reload_model_weights():
    """Run ``pre_reload_model_weights`` on a worker thread and wait for it to finish."""
    worker = threading.Thread(
        target=pre_reload_model_weights,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
def backup_reload_ckpt_info(sem, info):
    """Remember the current checkpoint and switch the loaded model to ``info``.

    While holding ``sem``: a previously offloaded model (if any) is first moved
    back to its original device. Then, if a model is loaded, its
    ``sd_checkpoint_info`` is stored in ``backup_ckpt_info`` and the weights for
    ``info`` are loaded.

    Args:
        sem: Context-manager lock guarding the shared model state.
        info: Checkpoint info of the weights to load.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        model_was_offloaded = backup_sd_model is not None and backup_device is not None
        if model_was_offloaded:
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None

        if shared.sd_model is not None:
            # Save the current checkpoint before replacing it so it can be restored later.
            backup_ckpt_info = shared.sd_model.sd_checkpoint_info
            webui_reload_model_weights(sd_model=shared.sd_model, info=info)
def await_backup_reload_ckpt_info(info):
    """Run ``backup_reload_ckpt_info(info)`` on a worker thread and wait for it to finish.

    Args:
        info: Checkpoint info forwarded to ``backup_reload_ckpt_info``.
    """
    worker = threading.Thread(
        target=backup_reload_ckpt_info,
        args=(model_access_sem, info),
    )
    worker.start()
    worker.join()
def post_reload_model_weights(sem):
    """Restore the offloaded model and the backed-up checkpoint while holding ``sem``.

    A model stored in ``backup_sd_model`` is moved back to ``backup_device``;
    a checkpoint stored in ``backup_ckpt_info`` is reloaded into the currently
    loaded model. Each backup is cleared once it has been applied.

    Args:
        sem: Context-manager lock guarding the shared model state.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        if None not in (backup_sd_model, backup_device):
            backup_sd_model.to(backup_device)
            backup_sd_model = None
            backup_device = None

        if shared.sd_model is not None and backup_ckpt_info is not None:
            webui_reload_model_weights(sd_model=shared.sd_model, info=backup_ckpt_info)
            backup_ckpt_info = None
def async_post_reload_model_weights():
    """Start ``post_reload_model_weights`` on a worker thread without waiting for it."""
    # Deliberately not joined: the caller returns while the restore runs in the background.
    threading.Thread(
        target=post_reload_model_weights,
        args=(model_access_sem,),
    ).start()
def acquire_release_semaphore(sem):
    """Acquire ``sem`` and release it again immediately.

    The call blocks until ``sem`` can be acquired, so returning from it means
    whoever held ``sem`` before has released it.

    Args:
        sem: Context-manager lock to acquire and release.
    """
    with sem:
        # Nothing to do while holding the lock; getting it is the whole point.
        pass
def await_acquire_release_semaphore():
    """Run ``acquire_release_semaphore`` on a worker thread and wait for it to finish.

    Returns only after ``model_access_sem`` has been acquired and released once.
    """
    worker = threading.Thread(
        target=acquire_release_semaphore,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
def clear_cache_decorator(func):
    """Decorate ``func`` so ``clear_cache()`` runs before and after it.

    Works for both plain functions and generator functions; the matching
    wrapper is chosen once, at decoration time. The trailing ``clear_cache()``
    runs in a ``finally`` clause, so it also happens when ``func`` raises or
    when a wrapped generator is closed before it is exhausted.

    Args:
        func: Function or generator function to wrap.

    Returns:
        The wrapped callable, with ``func``'s metadata preserved via ``wraps``.
    """
    @wraps(func)
    def yield_wrapper(*args, **kwargs):
        clear_cache()
        try:
            yield from func(*args, **kwargs)
        finally:
            # Also reached on exceptions and on early generator close.
            clear_cache()

    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        try:
            return func(*args, **kwargs)
        finally:
            clear_cache()

    if inspect.isgeneratorfunction(func):
        return yield_wrapper
    return wrapper
def post_reload_decorator(func):
    """Decorate ``func`` to wait for the model lock first and schedule a restore afterwards.

    Before ``func`` runs, ``await_acquire_release_semaphore()`` blocks until the
    model semaphore has been acquired and released once. After ``func``
    finishes, ``async_post_reload_model_weights()`` starts the restore in the
    background. The restore is scheduled from a ``finally`` clause so that it is
    not skipped when ``func`` raises or a wrapped generator is closed early.

    Args:
        func: Function or generator function to wrap.

    Returns:
        The wrapped callable, with ``func``'s metadata preserved via ``wraps``.
    """
    @wraps(func)
    def yield_wrapper(*args, **kwargs):
        await_acquire_release_semaphore()
        try:
            yield from func(*args, **kwargs)
        finally:
            # Also reached on exceptions and on early generator close.
            async_post_reload_model_weights()

    @wraps(func)
    def wrapper(*args, **kwargs):
        await_acquire_release_semaphore()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()

    if inspect.isgeneratorfunction(func):
        return yield_wrapper
    return wrapper
def offload_reload_decorator(func):
    """Decorate ``func`` to offload the model before it runs and restore it afterwards.

    Before ``func`` runs, ``await_pre_offload_model_weights()`` moves the loaded
    model to the CPU (when applicable) and waits for that to complete. After
    ``func`` finishes, ``async_post_reload_model_weights()`` starts the restore
    in the background. The restore is scheduled from a ``finally`` clause so the
    model is not left offloaded when ``func`` raises or a wrapped generator is
    closed early.

    Args:
        func: Function or generator function to wrap.

    Returns:
        The wrapped callable, with ``func``'s metadata preserved via ``wraps``.
    """
    @wraps(func)
    def yield_wrapper(*args, **kwargs):
        await_pre_offload_model_weights()
        try:
            yield from func(*args, **kwargs)
        finally:
            # Also reached on exceptions and on early generator close.
            async_post_reload_model_weights()

    @wraps(func)
    def wrapper(*args, **kwargs):
        await_pre_offload_model_weights()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()

    if inspect.isgeneratorfunction(func):
        return yield_wrapper
    return wrapper
class torch_default_load_cd(ContextDecorator):
    """Context manager / decorator that temporarily replaces ``torch.load``.

    On entry, the current ``torch.load`` is saved and ``torch.load`` is set to
    ``safe.unsafe_torch_load``; on exit, the saved loader is put back.
    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self):
        # Placeholder until __enter__ records the real torch.load.
        self.backup_load = safe.load

    def __enter__(self):
        self.backup_load = torch.load
        # NOTE(review): presumably this bypasses the webui's restricted loader;
        # confirm that only trusted checkpoints are loaded inside this context.
        torch.load = safe.unsafe_torch_load
        return self

    def __exit__(self, *exc_info):
        torch.load = self.backup_load
        # Returning False lets any exception from the block propagate.
        return False