Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweaks to Muon branch #831

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
86d6556
logs
Jan 1, 2024
8b0900f
try dataset path
WhenWen Jan 1, 2024
5bf8cd0
use the jax serialization manager for deser in an attempt to fix cras…
dlwh Oct 12, 2024
20a4568
cleaner data loader but it doesn't help :(
dlwh Oct 13, 2024
2747705
ok this maybe fixed it?
dlwh Oct 14, 2024
538f0ed
cleanup
dlwh Oct 14, 2024
2c5ee4b
fix tests
dlwh Oct 14, 2024
ab543d6
fix what is probably the underlying problem
dlwh Oct 14, 2024
6395305
wip
dlwh Oct 14, 2024
b8f28f4
Merge branch 'main' of github.com:WhenWen/levanter-2024 into main
Dec 3, 2024
f98e376
Implement MARS (tested) and Muon (have bug in saving), example config…
Dec 3, 2024
e961c95
Implement MARS (tested) and Muon (have bug in saving), example config…
Dec 3, 2024
2b80af7
wip
dlwh Dec 3, 2024
87d7665
enough device puts and we're good
dlwh Dec 3, 2024
074d0ec
ok we're good
dlwh Dec 3, 2024
5668289
Merge remote-tracking branch 'origin/main' into WhenWen/main
dlwh Dec 3, 2024
d2d310e
Merge branch 'use_manager_deser' into WhenWen/main
dlwh Dec 3, 2024
722edaf
fix tree leaf stuff
dlwh Dec 3, 2024
5692611
add map_flattened_linear_layers use in muon
dlwh Dec 4, 2024
2f119bd
Merge remote-tracking branch 'origin/main' into muon
dlwh Dec 4, 2024
271479f
The training succeed and the model seems to be properly saved but hav…
WhenWen Dec 12, 2024
f36ff81
Merge remote-tracking branch 'origin/main' into muon
dlwh Dec 12, 2024
144e20f
fix deser for muon
dlwh Dec 12, 2024
083835d
Merge remote-tracking branch 'origin/muon' into muon
dlwh Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions config/llama2_100M_mars.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Training config for a llama-type model (hidden_dim 768, 12 layers, 12 heads)
# that selects the MARS optimizer via `optimizer.type: mars`.
data: !include data/dclm_gpt_neo.yaml
model:
type: llama
seq_len: 4096
hidden_dim: 768
intermediate_dim: 3072
num_layers: 12
num_heads: 12
num_kv_heads: 12
trainer:
tracker:
project: "levanter"
tags: ["pile", "llama"] # NOTE(review): the data include above is dclm_gpt_neo.yaml, so the "pile" tag looks stale — confirm
mp: p=f32,c=bfloat16
model_axis_size: 1
checkpointer:
keep:
- every: 1000
save_interval: 30m


train_batch_size: 1024
per_device_parallelism: 4 # set for v3 TPU
per_device_eval_parallelism: 4 # NOTE(review): same value as per_device_parallelism; the earlier "larger batch size for eval" note did not match — confirm
num_train_steps: 50001
optimizer:
learning_rate: 4E-3 # NOTE(review): the earlier "set low for fine-tuning" note looks copied from a fine-tuning config — confirm intended LR
weight_decay: 0.1
min_lr_ratio: 0.0
warmup: 2000
cooldown: 0.4
lr_schedule: constant
gamma: 0.025 # MARS-specific knob (MarsConfig.gamma)
type: mars
34 changes: 34 additions & 0 deletions config/llama2_100M_muon.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Training config for a llama-type model (hidden_dim 768, 12 layers, 12 heads)
# that selects the Muon optimizer via `optimizer.type: muon`.
data: !include data/dclm_gpt_neo.yaml
model:
type: llama
seq_len: 4096
hidden_dim: 768
intermediate_dim: 3072
num_layers: 12
num_heads: 12
num_kv_heads: 12
trainer:
tracker:
project: "levanter"
tags: ["pile", "llama"] # NOTE(review): the data include above is dclm_gpt_neo.yaml, so the "pile" tag looks stale — confirm
mp: p=f32,c=bfloat16
model_axis_size: 1
checkpointer:
keep:
- every: 1000
save_interval: 30m


train_batch_size: 1024
per_device_parallelism: 4 # set for v3 TPU
per_device_eval_parallelism: 4 # NOTE(review): same value as per_device_parallelism; the earlier "larger batch size for eval" note did not match — confirm
num_train_steps: 50001
optimizer:
learning_rate: 2E-2 # NOTE(review): the earlier "set low for fine-tuning" note looks copied from a fine-tuning config — confirm intended LR
weight_decay: 0
warmup: 0
cooldown: 0.1
lr_schedule: constant
min_lr_ratio: 0.0
max_grad_norm: 0.0 # NOTE(review): presumably 0.0 disables gradient clipping — confirm against the optimizer config
type: muon
10 changes: 10 additions & 0 deletions error_loading_model.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
# Launches levanter's train_lm.py through infra/babysit-tpu-vm.sh on a
# preemptible v4-128 TPU VM named "muon-debug" in us-central2-b, using the
# Muon config and resuming from the step-4000 checkpoint given below.
# NOTE(review): judging by the file name this is a repro for a checkpoint
# loading error — confirm before relying on it for anything else.

# Start an ssh-agent for this shell. The substitution is quoted so the agent's
# output reaches `eval` as one string, without word splitting or globbing.
eval "$(ssh-agent -s)"

# NOTE(review): WANDB_API_KEY=[WANDB_API_KEY] looks like a redacted
# placeholder — substitute a real key (or an env reference) before running.
bash infra/babysit-tpu-vm.sh muon-debug -z us-central2-b -t v4-128 --preemptible -- \
    WANDB_API_KEY=[WANDB_API_KEY] \
    bash levanter/infra/run.sh python \
    levanter/src/levanter/main/train_lm.py \
    --config_path levanter/config/llama2_100M_muon.yaml \
    --trainer.checkpointer.base_path gs://marin-us-central2/scratch/kaiyue/checkpoints/muon/llama2_100M_constant \
    --optimizer.type muon \
    --trainer.num_train_steps 10000 \
    --trainer.load_checkpoint_path gs://marin-us-central2/scratch/kaiyue/checkpoints/muon/llama2_100M_constant/tjo9vxfb/step-4000
8 changes: 8 additions & 0 deletions src/levanter/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,11 @@
scale_by_sophia_g,
scale_by_sophia_h,
)
from .muon import (
MuonConfig,
ScaleByMuonState
)
from .mars import (
MarsConfig,
ScaleByMarsState
)
135 changes: 135 additions & 0 deletions src/levanter/optim/mars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import abc
import functools
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, TypeVar

import equinox as eqx
import jax
import jaxtyping
import optax
from jax import numpy as jnp
from jax.random import PRNGKey
from jaxtyping import PRNGKeyArray

import levanter.tracker
from levanter.optim.config import HessianOptConfig, OptimizerConfig
from levanter.optim.util import hvp, tree_gaussian_like
from levanter.utils.jax_utils import parameter_count, tree_filter_like


@OptimizerConfig.register_subclass("mars")
@dataclass
class MarsConfig(OptimizerConfig):
    """Optimizer configuration registered under the name ``"mars"``.

    ``build`` chains ``scale_by_mars`` with optional decoupled weight decay
    and a final scaling by the negative learning rate.
    """

    weight_decay: float = 0.1
    beta1: float = 0.95
    # cf https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.optim.DecoupledAdamW.html
    # https://x.com/giffmana/status/1692641748445438301
    beta2: float = 0.99
    gamma: float = 0.025
    epsilon: float = 1e-8
    max_grad_norm: Optional[float] = 1.0
    # The two fields below are not read by build() in this class.
    haps: Optional[list[int]] = None
    schedule_list: Optional[list[str]] = None

    def build(self, num_train_steps):
        """Create the optax optimizer for a run of ``num_train_steps`` steps.

        The chain is wrapped in ``optax.inject_hyperparams`` with the schedule
        from ``self.lr_scheduler(num_train_steps)`` as ``learning_rate``, so
        the current learning rate is available as a hyperparameter for logging.
        """

        # The inner factory exists only so inject_hyperparams can feed the
        # scheduled learning rate in as a named hyperparameter.
        def _optimizer(learning_rate):
            transforms = [
                scale_by_mars(
                    self.beta1,
                    self.beta2,
                    self.gamma,
                    self.epsilon,
                    max_grad_norm=self.max_grad_norm,
                )
            ]

            # Weight decay is added only when it is strictly positive.
            if self.weight_decay > 0:
                transforms.append(
                    optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())
                )

            # Negative sign: the chain yields a descent step.
            transforms.append(optax.scale(-learning_rate))

            return optax.chain(*transforms)

        schedule = self.lr_scheduler(num_train_steps)
        return optax.inject_hyperparams(_optimizer)(learning_rate=schedule)

from optax import tree_utils as otu
import jax
import jax.numpy as jnp
from jax import jit


import chex

class ScaleByMarsState(NamedTuple):
    """Optimizer state threaded through successive ``scale_by_mars`` updates."""

    # Step counter; initialised to a scalar int32 zero and incremented once
    # per update.
    count: chex.Array  # shape=(), dtype=jnp.int32.
    # Exponential moving average of the corrected gradient (first moment).
    mu: optax.Updates
    # Exponential moving average of the squared corrected gradient
    # (second moment).
    nu: optax.Updates
    # The raw gradient seen on the previous step (zeros before the first
    # step); used to form the correction term on the next step.
    mog: optax.Updates


def scale_by_mars(
    b1: float = 0.9,
    b2: float = 0.999,
    gamma: float = 0.05,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    max_grad_norm: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> optax.GradientTransformation:
    r"""Rescale updates according to the MARS algorithm.

    https://arxiv.org/abs/2411.10438
    See :func:`optax.adam` for more details on the Adam-style moment updates.

    Each step builds a corrected gradient
    ``c = g + gamma * b1 / (1 - b1) * (g - g_prev)`` from the incoming
    gradient ``g`` and the previous step's gradient ``g_prev`` (zeros on the
    first step), optionally clips ``c`` by global norm, and then applies
    bias-corrected Adam moment estimates to ``c``.

    Args:
        b1: Decay rate for the exponentially weighted average of grads.
        b2: Decay rate for the exponentially weighted average of squared grads.
        gamma: Controls the scale of the variance-reduction correction.
        eps: Term added to the denominator to improve numerical stability.
        eps_root: Term added to the denominator inside the square-root to improve
            numerical stability when backpropagating gradients through the rescaling.
        max_grad_norm: If truthy, the corrected gradient is rescaled so that its
            global norm does not exceed this value; ``0.0`` or ``None`` disables
            clipping.
        mu_dtype: Optional dtype to be used for the first order accumulator; if
            None then the dtype is inferred from params and updates.

    Returns:
        A :class:`optax.GradientTransformation` object.
    """
    # Canonicalize only an explicitly requested dtype. Canonicalizing ``None``
    # resolves to the default float dtype (``np.dtype(None)`` is float64),
    # which would pin the accumulators to that dtype instead of inferring it
    # from the params as documented.
    if mu_dtype is not None:
        mu_dtype = jax.dtypes.canonicalize_dtype(mu_dtype)

    def init_fn(params):
        mu = otu.tree_zeros_like(params, dtype=mu_dtype)  # First moment
        nu = otu.tree_zeros_like(params)  # Second moment
        mog = otu.tree_zeros_like(params, dtype=mu_dtype)  # Previous-step gradient
        return ScaleByMarsState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, mog=mog)

    def update_fn(updates, state, params=None):
        # Corrected gradient: current gradient plus a scaled difference from
        # the gradient of the previous step.
        c = jax.tree.map(
            lambda og, g: None if g is None else g + (gamma * b1 / (1 - b1)) * (g - og),
            state.mog,
            updates,
            is_leaf=lambda x: x is None,
        )
        if max_grad_norm:
            # Clip the corrected gradient by its global norm; the 1e-6 guards
            # against division by a zero norm.
            g_norm = optax.global_norm(c)
            scale = jnp.minimum(1.0, max_grad_norm / (g_norm + 1e-6))
            c = jax.tree.map(
                lambda g: None if g is None else g * scale,
                c,
                is_leaf=lambda x: x is None,
            )
        mu = otu.tree_update_moment(c, state.mu, b1, 1)
        nu = otu.tree_update_moment_per_elem_norm(c, state.nu, b2, 2)
        count_inc = optax.safe_increment(state.count)
        # Standard Adam bias correction for both moments.
        mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
        nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
        adam_updates = jax.tree.map(
            lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps),
            mu_hat,
            nu_hat,
            is_leaf=lambda x: x is None,
        )
        mu = otu.tree_cast(mu, mu_dtype)
        # Store the raw (uncorrected, unclipped) gradient for the next step.
        return adam_updates, ScaleByMarsState(count=count_inc, mu=mu, nu=nu, mog=updates)

    return optax.GradientTransformation(init_fn, update_fn)
Loading
Loading