-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathinstall.py
116 lines (98 loc) · 4.21 KB
/
install.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import sys
import launch
import platform
import os
import shutil
import site
import glob
import re
# Folder that contains this installer script.
dirname = os.path.split(__file__)[0]
# Location of the kohya sd-scripts checkout: "<script folder>/kohya_ss".
repo_dir = os.path.join(dirname, "kohya_ss")
def _unpinned_requirements(requirements_text):
    """Return the requirement names in *requirements_text* without version pins.

    Blank lines, ``#`` comment lines and the bare ``.`` entry (the repository
    itself, which the caller installs explicitly) are dropped.  Lines starting
    with ``-`` are pip options (for example ``-e .``); they cannot be passed
    as a quoted requirement argument, so they are dropped as well.
    Everything from the first version operator onwards is removed.
    """
    names = []
    for raw_line in requirements_text.splitlines():
        line = raw_line.strip()
        if not line or line == "." or line.startswith(("#", "-")):
            continue
        name = re.split(r"==|~=|!=|<|>", line, maxsplit=1)[0].strip()
        if name:
            names.append(name)
    return names


def prepare_environment():
    """Install the dependencies and the kohya sd-scripts checkout.

    Configuration is read from environment variables:

    * ``TORCH_COMMAND``     - pip command used to install torch/torchvision.
    * ``SD_SCRIPTS_REPO``   - git URL cloned into ``repo_dir``.
    * ``SD_SCRIPTS_BRANCH`` - branch checked out after clone/update.
    * ``REQS_FILE``         - requirements file name inside ``repo_dir``.

    The flags ``--skip-install``, ``--disable-strict-version``,
    ``--skip-torch-cuda-test``, ``--skip-checkout-repo``, ``--update``,
    ``--reinstall-xformers`` and ``--reinstall-torch`` are pulled out of
    ``sys.argv`` through ``launch.extract_arg``.  ``--xformers`` and
    ``--ngrok`` are only tested for presence and stay in ``sys.argv``.

    Returns ``None``.  With ``--skip-install`` nothing is installed.
    """
    torch_command = os.environ.get(
        "TORCH_COMMAND",
        "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
    )
    sd_scripts_repo = os.environ.get(
        "SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git"
    )
    sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main")
    requirements_file = os.environ.get("REQS_FILE", "requirements.txt")

    sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install")
    sys.argv, disable_strict_version = launch.extract_arg(
        sys.argv, "--disable-strict-version"
    )
    sys.argv, skip_torch_cuda_test = launch.extract_arg(
        sys.argv, "--skip-torch-cuda-test"
    )
    sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo")
    sys.argv, update = launch.extract_arg(sys.argv, "--update")
    sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers")
    sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch")
    xformers = "--xformers" in sys.argv
    ngrok = "--ngrok" in sys.argv

    if skip_install:
        return

    # The pinned torch build is only installed when strict versions are wanted.
    if (
        reinstall_torch
        or not launch.is_installed("torch")
        or not launch.is_installed("torchvision")
    ) and not disable_strict_version:
        launch.run(
            f'"{launch.python}" -m {torch_command}',
            "Installing torch and torchvision",
            "Couldn't install torch",
        )

    if not skip_torch_cuda_test:
        launch.run_python(
            "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'"
        )

    if (not launch.is_installed("xformers") or reinstall_xformers) and xformers:
        launch.run_pip("install xformers --pre", "xformers")

    # Bring the sd-scripts checkout up to date, or create it on first run.
    if update and os.path.exists(repo_dir):
        launch.run(f'cd "{repo_dir}" && {launch.git} fetch --prune')
        # NOTE(review): the reset target is always origin/main, even when
        # SD_SCRIPTS_BRANCH names another branch - confirm this is intended.
        launch.run(f'cd "{repo_dir}" && {launch.git} reset --hard origin/main')
    elif not os.path.exists(repo_dir):
        launch.run(f'{launch.git} clone {sd_scripts_repo} "{repo_dir}"')

    if not skip_checkout_repo:
        launch.run(f'cd "{repo_dir}" && {launch.git} checkout {sd_scripts_branch}')

    if not launch.is_installed("gradio"):
        launch.run_pip("install gradio==3.16.2", "gradio")

    if ngrok and not launch.is_installed("pyngrok"):
        launch.run_pip("install pyngrok", "ngrok")

    if platform.system() == "Linux" and not launch.is_installed("triton"):
        launch.run_pip("install triton", "triton")

    if disable_strict_version:
        # Install the same packages as the requirements file, minus the pins.
        requirements_path = os.path.join(repo_dir, requirements_file)
        with open(requirements_path, "r", encoding="utf-8") as f:
            requirements = _unpinned_requirements(f.read())
        # Quote each requirement on its own: one pair of quotes around the
        # whole list would reach pip as a single, invalid requirement.
        quoted_requirements = " ".join(f'"{name}"' for name in requirements)
        launch.run_pip(
            f'install {quoted_requirements} "{repo_dir}"',
            "requirements for kohya sd-scripts",
        )
    else:
        # Honour REQS_FILE here too, not only in the non-strict branch.
        launch.run(
            f'cd "{repo_dir}" && "{launch.python}" -m pip install -r "{requirements_file}"',
            desc="Installing requirements for kohya sd-scripts",
            errdesc="Couldn't install requirements for kohya sd-scripts",
        )

    if platform.system() == "Windows":
        # Overlay the files shipped in the repo's bitsandbytes_windows folder
        # onto every installed bitsandbytes package.
        pattern = os.path.join(repo_dir, "bitsandbytes_windows", "*")
        for file in glob.glob(pattern):
            filename = os.path.basename(file)
            for site_dir in site.getsitepackages():
                # main.py goes into the cuda_setup subpackage, the rest into
                # the package root.
                if filename == "main.py":
                    outfile = os.path.join(
                        site_dir, "bitsandbytes", "cuda_setup", filename
                    )
                else:
                    outfile = os.path.join(site_dir, "bitsandbytes", filename)
                # Skip site-packages folders without a bitsandbytes install.
                if not os.path.exists(os.path.dirname(outfile)):
                    continue
                shutil.copy(file, outfile)
# Run the installation only when this file is executed as a script.
if __name__ == "__main__":
    prepare_environment()