-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhf_downloader.py
73 lines (61 loc) · 2.67 KB
/
hf_downloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import huggingface_hub
from huggingface_hub import hf_hub_download
from typing import Optional
class HuggingFaceDownloader:
    """ComfyUI node that downloads one file from a Hugging Face Hub repository
    into a sub-directory of the ComfyUI ``models`` folder.

    The node returns a 1-tuple holding the local path of the downloaded file,
    or ``("Download failed",)`` if anything goes wrong.
    """

    @staticmethod
    def _models_root() -> str:
        """Return the ``models`` directory located three levels above this file.

        NOTE(review): the three ``dirname`` calls presumably assume this module
        lives at ``<ComfyUI>/custom_nodes/<package>/hf_downloader.py`` — confirm.
        """
        return os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'models'
        )

    @classmethod
    def _resolve_download_dir(cls, download_directory: str) -> str:
        """Return the absolute path of ``download_directory`` inside the models folder.

        The check is purely lexical (``normpath`` + ``commonpath``), so symlinked
        sub-directories of ``models`` are still allowed. Relative escapes such as
        ``"../.."`` and absolute paths elsewhere are refused.

        Raises:
            ValueError: if the resolved path lies outside the models folder.
        """
        root = os.path.abspath(cls._models_root())
        target = os.path.normpath(os.path.join(root, download_directory))
        try:
            inside = os.path.commonpath([root, target]) == root
        except ValueError:
            # commonpath raises e.g. for paths on different drives (Windows).
            inside = False
        if not inside:
            raise ValueError(
                f"download directory {download_directory!r} is outside {root}"
            )
        return target

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for ComfyUI.

        ``download_directory`` is a dropdown of the existing sub-directories of
        the models folder, sorted for a stable order. If the folder is missing
        or unreadable, the single choice ``"models"`` is offered instead.
        NOTE(review): that fallback resolves to ``<models>/models`` in
        ``download_model`` — confirm this is intended.
        """
        base_models_path = cls._models_root()
        try:
            # List only directories in the models path.
            model_dirs = sorted(
                d for d in os.listdir(base_models_path)
                if os.path.isdir(os.path.join(base_models_path, d))
            )
        except OSError:
            # Missing, not a directory, or permission denied: use the fallback.
            model_dirs = []
        return {
            "required": {
                "repo_id": ("STRING", {"default": ""}),
                "filename": ("STRING", {"default": ""}),
                "download_directory": (model_dirs or ["models"], {}),
                "hf_token": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "Optional HuggingFace Read Token"
                }),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "download_model"
    CATEGORY = "Model Utilities"
    OUTPUT_NODE = False

    def download_model(self, repo_id: str, filename: str, download_directory: str,
                       hf_token: Optional[str] = None):
        """Download ``filename`` from ``repo_id`` into ``models/<download_directory>``.

        Args:
            repo_id: Hub repository id; surrounding whitespace is ignored.
            filename: Path of the file inside the repository; surrounding
                whitespace is ignored.
            download_directory: Sub-directory of the models folder. It must not
                resolve outside that folder.
            hf_token: Optional access token. A blank or whitespace-only value
                means anonymous access.

        Returns:
            ``(local_path,)`` on success, or ``("Download failed",)`` on any error.
            The error is printed rather than raised, so the node never crashes
            the graph.
        """
        try:
            repo_id = (repo_id or "").strip()
            filename = (filename or "").strip()
            if not repo_id or not filename:
                raise ValueError("repo_id and filename are required")

            # Validate the target before touching the filesystem.
            full_download_path = self._resolve_download_dir(download_directory)
            os.makedirs(full_download_path, exist_ok=True)

            # Whitespace-only tokens become None, so the request is anonymous
            # rather than sending an empty credential.
            token = (hf_token or "").strip() or None

            downloaded_file_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                token=token,
                local_dir=full_download_path,
                # Ensure an actual file, not a symlink into the HF cache, is
                # written. NOTE(review): this argument is deprecated/ignored in
                # newer huggingface_hub releases — confirm against the
                # installed version.
                local_dir_use_symlinks=False,
            )
            return (downloaded_file_path,)
        except Exception as e:
            # Best-effort node: report the error and return a sentinel string.
            print(f"Error downloading model: {e}")
            return ("Download failed",)
# ComfyUI registration tables: the first maps the node's type name to its
# implementing class, the second maps the same name to its UI label.
NODE_CLASS_MAPPINGS = dict(HuggingFaceDownloader=HuggingFaceDownloader)
NODE_DISPLAY_NAME_MAPPINGS = dict(HuggingFaceDownloader="HuggingFace Model Downloader")