Added Support for Apple Silicon #1289

Open · wants to merge 6 commits into base: main
Changes from 1 commit
Minor fixes and enhancements:
- lazy loading of the model
- minor refactoring
- optimizers and learning-rate schedulers
- explicit garbage collection (gc)
- should reduce memory consumption (a generic sketch of this pattern follows the commit info below)
shashikanth-a committed Nov 25, 2024
commit 066c2278215367de20d712000827e9ad2bc793be
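The commit message above mentions lazy model loading and explicit garbage collection to cut memory use. As a rough, generic illustration of that pattern (not code from this PR; all names below are made up):

import gc

_model = None  # hypothetical module-level cache

def get_model(loader, *args, **kwargs):
    # Lazy loading: construct the model only on first use.
    global _model
    if _model is None:
        _model = loader(*args, **kwargs)
    return _model

def free_model():
    # Drop the cached reference and collect immediately so large weight
    # buffers are reclaimed before the next allocation.
    global _model
    _model = None
    gc.collect()

The diff below applies the same idea concretely: on the MPS path the MLX model is loaded via mlx_utils.load_pretrained, then deleted and collected before the merged model is written out.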
42 changes: 11 additions & 31 deletions unsloth-cli.py
@@ -43,14 +43,9 @@ def run(args):
    import logging
    logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
    if has_mps:
        import mlx.optimizers as optim
        import mlx.core as mx
        from unsloth.models import mlx_utils as lora_utils
        from unsloth.models import mlx_lora
        import numpy as np
        from unsloth.models.mlx_models import LoRALinear
        from mlx.utils import tree_flatten
        from pathlib import Path
        from unsloth.mlx import mlx_utils
        from unsloth.mlx import lora as mlx_lora
        import gc

    if not has_mps:
        # Load model and tokenizer
@@ -61,26 +56,9 @@ def run(args):
            load_in_4bit=args.load_in_4bit,
        )
    else:
        np.random.seed(args.seed)

        # Building tokenizer_config
        tokenizer_config = {}

        print("Loading pretrained model")
        model, tokenizer, config = lora_utils.load(args.model_name, tokenizer_config)
        # Freeze all layers other than LoRA linears
        model.freeze()
        for l in model.model.layers[len(model.model.layers) - args.r :]:
            l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
            l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
            if hasattr(l, "block_sparse_moe"):
                l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)

        p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
        print(f"Total parameters {p:.3f}M")
        p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
        print(f"Trainable parameters {p:.3f}M")

        model, tokenizer, config = mlx_utils.load_pretrained(args.model_name)

    # Configure PEFT model
    if not has_mps:
        model = FastLanguageModel.get_peft_model(
@@ -159,8 +137,7 @@ def formatting_prompts_func(examples):
        trainer_stats = trainer.train()
    else:
        datasets = dataset.train_test_split(test_size=0.1)
        opt = optim.Adam(learning_rate=args.learning_rate)
        mlx_lora.train(model, datasets["train"], datasets["test"], opt, mlx_lora.loss, tokenizer, args)
        mlx_lora.train_model(args, model, tokenizer, datasets["train"], datasets["test"])


    # Save model
@@ -192,8 +169,11 @@ def formatting_prompts_func(examples):
        )
    else:
        if has_mps:
            mx.savez(Path(args.save_path, args.adapter_file), **dict(tree_flatten(model.trainable_parameters())))
            model.save_merged_model(args)
            del model
            gc.collect()
            mlx_utils.save_merged_model(args)
            if args.push_model:
                mlx_utils.push_to_hub(args, config["_name_or_path"], config["model_type"])
        else:
            model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
            if args.push_model:
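Read end to end, the MPS branch of unsloth-cli.py after this commit boils down to roughly the following. This is a condensed sketch of the diff above (one reading of it, not a drop-in replacement), assuming the unsloth.mlx helpers behave as their names suggest:

import gc
from unsloth.mlx import mlx_utils
from unsloth.mlx import lora as mlx_lora

def run_mps_path(args, dataset):
    # Load the MLX model, tokenizer and config on the MPS path.
    model, tokenizer, config = mlx_utils.load_pretrained(args.model_name)

    # Train via the new high-level entry point instead of hand-built LoRA layers.
    datasets = dataset.train_test_split(test_size=0.1)
    mlx_lora.train_model(args, model, tokenizer, datasets["train"], datasets["test"])

    # Free the in-memory model before merging/saving to keep peak memory down.
    del model
    gc.collect()
    mlx_utils.save_merged_model(args)
    if args.push_model:
        mlx_utils.push_to_hub(args, config["_name_or_path"], config["model_type"])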
105 changes: 105 additions & 0 deletions unsloth/mlx/lora.py
@@ -0,0 +1,105 @@
from pathlib import Path
import math
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.core as mx
from .trainer.trainer import TrainingArgs, TrainingCallback, train
from .trainer.utils import (
    build_schedule,
    linear_to_lora_layers,
    print_trainable_parameters,
)
from .mlx_utils import save_config

def train_model(
    args,
    model: nn.Module,
    tokenizer,
    train_set,
    valid_set,
    training_callback: TrainingCallback = None,
):
    model.freeze()
    # Apply LoRA to at most half of the layers, and to no more than args.r of them.
    lora_parameters = {
        "rank": args.r,
        "alpha": args.lora_alpha,
        "dropout": args.lora_dropout,
        # rsLoRA scales by alpha / sqrt(r); standard LoRA scales by alpha / r.
        "scale": float(args.lora_alpha) / math.sqrt(float(args.r))
        if args.use_rslora
        else float(args.lora_alpha) / float(args.r),
    }
    num_lora_layers = min(args.r, len(model.layers) // 2)
    linear_to_lora_layers(model, num_lora_layers, lora_parameters)
    print_trainable_parameters(model)

    adapter_path = Path(args.save_path)
    adapter_path.mkdir(parents=True, exist_ok=True)

    adapter_file = adapter_path / args.adapter_file
    config = {
        "num_layers": num_lora_layers,
        "lora_parameters": lora_parameters,
    }
    save_config(config, adapter_path / "adapter_config.json")

    # init training args
    training_args = TrainingArgs(
        batch_size=args.per_device_train_batch_size,
        iters=args.max_steps,
        val_batches=25,
        steps_per_report=10,
        steps_per_eval=200,
        steps_per_save=100,
        adapter_file=adapter_file,
        max_seq_length=args.max_seq_length,
        grad_checkpoint=args.use_gradient_checkpointing,
    )

    mx.random.seed(args.seed)
    model.train()
    if args.lr_scheduler_type == "linear":
        arguments = [0.0, args.learning_rate, args.warmup_steps]
    elif args.lr_scheduler_type == "exponential_decay":
        arguments = [args.learning_rate, args.weight_decay]
    elif args.lr_scheduler_type == "step_decay":
        arguments = [args.learning_rate, args.weight_decay, args.warmup_steps]
    elif args.lr_scheduler_type == "cosine_decay":
        arguments = [args.learning_rate, args.max_steps]
    else:
        arguments = [args.learning_rate]

    schedule_config = {
        "name": "linear_schedule" if args.lr_scheduler_type == "linear" else args.lr_scheduler_type,
        "warmup": args.warmup_steps,
        "arguments": arguments,
    }

    lr = build_schedule(schedule_config) if args.lr_scheduler_type else args.learning_rate

    optim_name = args.optim.lower()
    if optim_name.startswith("sgd"):
        opt = optim.SGD(learning_rate=lr, weight_decay=args.weight_decay)
    elif optim_name.startswith("rmsprop"):
        opt = optim.RMSprop(learning_rate=lr)
    elif optim_name.startswith("adagrad"):
        opt = optim.Adagrad(learning_rate=lr)
    elif optim_name.startswith("adadelta"):
        opt = optim.AdaDelta(learning_rate=lr)
    elif optim_name.startswith("adamw"):
        opt = optim.AdamW(learning_rate=lr, weight_decay=args.weight_decay)
    elif optim_name.startswith("adamax"):
        # Check "adamax" before "adam", otherwise the "adam" prefix always matches first.
        opt = optim.Adamax(learning_rate=lr)
    elif optim_name.startswith("adam"):
        opt = optim.Adam(learning_rate=lr)
    elif optim_name.startswith("lion"):
        opt = optim.Lion(learning_rate=lr, weight_decay=args.weight_decay)
    elif optim_name.startswith("adafactor"):
        opt = optim.Adafactor(learning_rate=lr, weight_decay=args.weight_decay)
    else:
        raise ValueError("The optimizer type provided is not supported")

    # Train model
    train(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        optimizer=opt,
        train_dataset=train_set,
        val_dataset=valid_set,
        training_callback=training_callback,
    )
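For orientation, here is a hedged sketch of how train_model could be driven outside the CLI. The argument names mirror the attributes read above; the model id, hyperparameter values and adapter filename are placeholders, and load_pretrained is assumed to return (model, tokenizer, config) as it does in unsloth-cli.py:

from types import SimpleNamespace

from unsloth.mlx import mlx_utils
from unsloth.mlx.lora import train_model

def finetune(train_set, valid_set):
    # Placeholder values; only the attribute names are taken from lora.py above.
    args = SimpleNamespace(
        model_name="mlx-community/some-4bit-model",   # hypothetical model id
        r=16, lora_alpha=16, lora_dropout=0.0, use_rslora=False,
        save_path="outputs", adapter_file="adapters.safetensors",
        per_device_train_batch_size=2, max_steps=100, max_seq_length=2048,
        use_gradient_checkpointing=True, seed=3407,
        optim="adamw", learning_rate=2e-4, weight_decay=0.01,
        lr_scheduler_type="cosine_decay", warmup_steps=5,
    )
    model, tokenizer, config = mlx_utils.load_pretrained(args.model_name)
    train_model(args, model, tokenizer, train_set, valid_set)

With lr_scheduler_type="cosine_decay", the branches in train_model build a cosine schedule from learning_rate over max_steps and pair it with AdamW plus weight decay.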
