-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: unloadModel.py
96 lines (83 loc) · 2.78 KB
/
unloadModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import comfy.model_management as model_management
import gc
import torch
import time
# Note: This doesn't work with reroute for some reason?
class AnyType(str):
    """A str subclass that never reports inequality.

    ComfyUI validates connections by comparing socket type strings with
    ``!=``; since this type claims it is never unequal to anything, it
    acts as a "*" wildcard compatible with every socket.
    """

    def __ne__(self, other: object) -> bool:
        # Always deny inequality so type validation accepts any match.
        return False


# Shared wildcard instance used by the node input/return declarations.
any = AnyType("*")
class UnloadModelNode:
    """Passthrough node that force-unloads one model to reclaim VRAM.

    ``value`` is forwarded untouched so the node can be spliced into any
    workflow edge; the optional ``model`` input is the object to evict.
    If the model is tracked by ComfyUI's model management it is removed
    from the loaded list; otherwise, if it is a dict holding a ``model``
    entry, that entry is deleted so the cache flush can free its memory.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"value": (any, )},  # For passthrough
            "optional": {"model": (any, )},
        }

    @classmethod
    def VALIDATE_INPUTS(cls, **kwargs):
        # Wildcard sockets: accept whatever the graph connects here.
        return True

    RETURN_TYPES = (any, )
    FUNCTION = "route"
    CATEGORY = "Unload Model"

    def route(self, **kwargs):
        """Unload ``model`` (if given), flush caches, and pass inputs through.

        Returns the input values as a list (ComfyUI unpacks it against
        RETURN_TYPES).
        """
        print("Unload Model:")
        loaded_models = model_management.loaded_models()
        model = kwargs.get("model")
        if model in loaded_models:
            print(" - Model found in memory, unloading...")
            loaded_models.remove(model)
        else:
            # Just delete it, I need the VRAM!
            if isinstance(model, dict):
                # Snapshot (key, type-name) pairs first so we can delete
                # entries without mutating the dict mid-iteration.
                keys = [(key, type(value).__name__) for key, value in model.items()]
                for key, model_type in keys:
                    if key == 'model':
                        print(f"Unloading model of type {model_type}")
                        del model[key]
        # Emptying the cache after this should free the memory.
        model_management.free_memory(1e30, model_management.get_torch_device(), loaded_models)
        model_management.soft_empty_cache(True)
        try:
            print(" - Clearing Cache...")
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            # Best-effort: CUDA calls fail on CPU-only torch builds.
            print(" - Unable to clear cache")
        #time.sleep(2) # why?
        return list(kwargs.values())
class UnloadAllModelsNode:
    """Passthrough node that unloads every model ComfyUI has in memory.

    ``value`` is forwarded untouched so the node can sit on any workflow
    edge; on execution it asks model management to unload all models and
    then flushes the allocator caches.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"value": (any, )},
        }

    @classmethod
    def VALIDATE_INPUTS(cls, **kwargs):
        # Wildcard socket: accept whatever the graph connects here.
        return True

    RETURN_TYPES = (any, )
    FUNCTION = "route"
    CATEGORY = "Unload Model"

    def route(self, **kwargs):
        """Unload all models, flush caches, and pass inputs through."""
        print("Unload Model:")
        print(" - Unloading all models...")
        model_management.unload_all_models()
        model_management.soft_empty_cache(True)
        try:
            print(" - Clearing Cache...")
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            # Best-effort: CUDA calls fail on CPU-only torch builds.
            print(" - Unable to clear cache")
        #time.sleep(2) # why?
        return list(kwargs.values())
# Registration tables read by ComfyUI when loading this custom-node module:
# internal node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "UnloadModel": UnloadModelNode,
    "UnloadAllModels": UnloadAllModelsNode,
}
# Internal node identifier -> human-readable label shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "UnloadModel": "Unload Model",
    "UnloadAllModels": "Unload All Models",
}