
Commit

Update save.py
danielhanchen committed Jun 13, 2024
1 parent b312b3f commit 1601dca
Showing 1 changed file with 10 additions and 4 deletions.
unsloth/save.py (10 additions & 4 deletions)
@@ -22,7 +22,7 @@
import pickle
import gc
from transformers.models.llama.modeling_llama import logger
-from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters
+from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias
import subprocess
import psutil
import re
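
The only import change is `get_lora_parameters` → `get_lora_parameters_bias`, which additionally returns the layer's bias term. The real helper lives in unsloth/kernels; the sketch below is an assumption about its shape, written against standard PEFT/bitsandbytes attributes, not the actual implementation:

    def get_lora_parameters_bias(layer):
        # Sketch only: assumed to return the same tuple as get_lora_parameters,
        # plus the underlying Linear's bias (None when the layer has no bias).
        base = getattr(layer, "base_layer", layer)      # PEFT wraps the frozen Linear
        W = base.weight
        quant_state = getattr(W, "quant_state", None)   # set by bitsandbytes 4-bit
        bias = base.bias
        if not hasattr(layer, "lora_A"):                # plain layer: no adapter tensors
            return W, quant_state, None, None, None, bias
        adapter = layer.active_adapters[0] if hasattr(layer, "active_adapters") \
            else layer.active_adapter
        A = layer.lora_A[adapter].weight
        B = layer.lora_B[adapter].weight
        s = layer.scaling[adapter]                      # lora_alpha / r
        return W, quant_state, A, B, s, bias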
@@ -132,9 +132,10 @@ def _free_cached_model(model):

def _merge_lora(layer, name):

+bias = None
if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)):
# Is LoRA so we need to merge!
-W, quant_state, A, B, s = get_lora_parameters(layer)
+W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer)
if quant_state is not None:
dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2]
W = fast_dequantize(W, quant_state)
@@ -156,7 +157,7 @@ def _merge_lora(layer, name):
W = W.t().to(dtype)
else:
W = layer.weight
-return W
+return W, bias
pass

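For reference, the merge itself follows the standard LoRA formulation W' = W + s·(B @ A); LoRA never touches the bias, so it is simply carried through alongside the merged weight. A minimal standalone sketch (tensor shapes assumed, function name hypothetical):

    import torch

    def merge_lora_sketch(W, A, B, s, bias, dtype=torch.float16):
        # W: (out, in) frozen weight, already dequantized if it was 4-bit.
        # A: (r, in), B: (out, r), s: scalar scaling (lora_alpha / r).
        W_merged = W.to(torch.float32) + s * (B.to(torch.float32) @ A.to(torch.float32))
        # Only the weight receives the low-rank update; the bias passes through.
        return W_merged.to(dtype), bias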

@@ -527,7 +528,12 @@ def unsloth_save_model(
for item in LLAMA_WEIGHTS:
proj = eval(f"layer.{item}")
name = f"model.layers.{j}.{item}.weight"
-W = _merge_lora(proj, name)
+W, bias = _merge_lora(proj, name)
+
+# Bias term
+if bias is not None:
+    state_dict[f"model.layers.{j}.{item}.bias"] = bias
+pass

if (torch.cuda.memory_allocated() + W.nbytes) < max_vram:
# Save to GPU memory
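Net effect: before this commit, the merged-save loop only wrote the weight key for each LLAMA_WEIGHTS projection, so any bias on those layers was dropped; now it is written under the matching `.bias` key. A quick way to check a merged checkpoint (paths and model are illustrative assumptions):

    import torch

    state_dict = torch.load("merged/pytorch_model.bin", map_location="cpu")
    bias_keys = [k for k in state_dict if k.endswith(".bias")]
    # Previously these keys were not written for merged LoRA projections.
    print(f"{len(bias_keys)} bias tensors saved; e.g. {bias_keys[:3]}")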
