setup_demo.py
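"""Prepare the artifacts used by the demo: the dataset metadata pickle and
search index, a TorchScript-scripted LRURec retriever, and a 4-bit quantized
base LLM together with its tokenizer and LoRA adapter."""
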
import glob
import os
import pickle
import torch
from peft.auto import AutoPeftModelForCausalLM
from pytorch_lightning import seed_everything
from transformers import AutoTokenizer, BitsAndBytesConfig
from config import EXPERIMENT_ROOT, STATE_DICT_KEY, args, set_template
from dataloader import LRUDataloader
from demo.index import create_index
from llamarec_datasets import dataset_factory
from model import LRURec
from model.llm import AutoModelForCausalLMPatched


def main(args):
    seed_everything(args.seed)
    dataset = dataset_factory(args)
    dataset_dict = dataset.load_dataset()

    # Cache the dataset metadata and build the demo search index from it.
    if not os.path.exists("demo/dataset_meta.pkl"):
        combined_meta = {
            "meta": dataset_dict["meta"],
            "spotify_meta": dataset_dict["spotify_meta"],
        }
        with open("demo/dataset_meta.pkl", "wb") as f:
            pickle.dump(combined_meta, f)
    create_index("demo/dataset_meta.pkl", "demo/indexdir")

    # Instantiated for its preprocessing side effects; the loaders are not used here.
    LRUDataloader(args, dataset)
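
    # Export the trained LRURec retriever: load the best checkpoint and
    # script it with TorchScript for the demo.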
    model = LRURec(args)
    export_root = os.path.join(EXPERIMENT_ROOT, args.model_code, args.dataset_code)
    print(f"Loading retriever model at {export_root}")
    retriever_path = os.path.join(export_root, "models", "best_acc_model.pth")
    # A joined path is always truthy, so check that the file actually exists.
    if not os.path.exists(retriever_path):
        print("Retriever model not found.")
        return
    best_model_dict = torch.load(retriever_path).get(STATE_DICT_KEY)
    model.load_state_dict(best_model_dict)
    model.eval()
    model_scripted = torch.jit.script(model)
    model_scripted.save("demo/retriever.pth")
    print("Retriever model saved to demo directory.")
    export_root = os.path.join(
        EXPERIMENT_ROOT,
        args.llm_base_model.split("/")[-1],
        args.dataset_code,
    )
    # Use the first checkpoint-* directory in lexicographic order.
    checkpoint_dirs = sorted(glob.glob(f"{export_root}/checkpoint-*"))
    if checkpoint_dirs:
        export_root = checkpoint_dirs[0]
    else:
        print("LLM model not found.")
        return
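
    # Load the base LLM with 4-bit NF4 quantization (bf16 compute where
    # supported) and FlashAttention-2, then save it with the fine-tuned
    # tokenizer under demo/llm.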
print("Loading LLM model at ", export_root)
llm_path = "demo/llm"
device_map = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=(
torch.bfloat16 if torch.cuda.is_bf16_supported() else None
),
)
base_model = AutoModelForCausalLMPatched.from_pretrained(
args.llm_base_model,
quantization_config=bnb_config,
device_map=device_map,
attn_implementation="flash_attention_2",
)
base_model.eval()
base_model.save_pretrained(llm_path)
print("Base LLM saved to demo directory.")
tokenizer = AutoTokenizer.from_pretrained(export_root)
tokenizer.save_pretrained(llm_path)
print("Tokenizer saved to demo directory.")
    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        export_root,
        quantization_config=bnb_config,
        device_map=device_map,
        attn_implementation="flash_attention_2",
    )
    # Point the adapter config at the saved base model directory ("llm").
    peft_model.peft_config["default"].base_model_name_or_path = llm_path.split("/")[-1]
    peft_model.eval()
    peft_model.save_pretrained(llm_path)
    print("LoRA adapter saved to demo directory.")
    print("Demo setup complete.")


if __name__ == "__main__":
    args.model_code = "lru"
    args.dataset_code = "music"
    set_template(args)
    main(args)