WIP Kron #843

Draft
wants to merge 67 commits into base: main
Changes from 1 commit
Commits (67)
86d6556
logs
Jan 1, 2024
8b0900f
try dataset path
WhenWen Jan 1, 2024
5bf8cd0
use the jax serialization manager for deser in an attempt to fix cras…
dlwh Oct 12, 2024
20a4568
cleaner data loader but it doesn't help :(
dlwh Oct 13, 2024
2747705
ok this maybe fixed it?
dlwh Oct 14, 2024
538f0ed
cleanup
dlwh Oct 14, 2024
2c5ee4b
fix tests
dlwh Oct 14, 2024
ab543d6
fix what is probably the underlying problem
dlwh Oct 14, 2024
6395305
wip
dlwh Oct 14, 2024
b8f28f4
Merge branch 'main' of github.com:WhenWen/levanter-2024 into main
Dec 3, 2024
f98e376
Implement MARS (tested) and Muon (have bug in saving), example config…
Dec 3, 2024
e961c95
Implement MARS (tested) and Muon (have bug in saving), example config…
Dec 3, 2024
2b80af7
wip
dlwh Dec 3, 2024
87d7665
enough device puts and we're good
dlwh Dec 3, 2024
074d0ec
ok we're good
dlwh Dec 3, 2024
5668289
Merge remote-tracking branch 'origin/main' into WhenWen/main
dlwh Dec 3, 2024
d2d310e
Merge branch 'use_manager_deser' into WhenWen/main
dlwh Dec 3, 2024
722edaf
fix tree leaf stuff
dlwh Dec 3, 2024
5692611
add map_flattened_linear_layers use in muon
dlwh Dec 4, 2024
2f119bd
Merge remote-tracking branch 'origin/main' into muon
dlwh Dec 4, 2024
0f41ebb
adding kron file to optim
evanatyourservice Dec 4, 2024
fe3ecc9
testing 123
evanatyourservice Dec 5, 2024
37452c7
Update kron.py
evanatyourservice Dec 5, 2024
701956d
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 5, 2024
aac1cee
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 5, 2024
e44e7fa
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 5, 2024
476ba36
Update kron.py
evanatyourservice Dec 6, 2024
966b80e
expose precond lr and init
evanatyourservice Dec 7, 2024
6408d23
Merge branch 'stanford-crfm:main' into kron
evanatyourservice Dec 7, 2024
91e29e7
Update kron.py
evanatyourservice Dec 10, 2024
6b9ce26
Merge remote-tracking branch 'upstream/main' into kron
evanatyourservice Dec 14, 2024
53efaa3
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 14, 2024
311da92
Update README.md
evanatyourservice Dec 14, 2024
33c17f4
Update kron.py
evanatyourservice Dec 14, 2024
8da6e34
Update kron.py
evanatyourservice Dec 14, 2024
5607cec
Update kron.py
evanatyourservice Dec 14, 2024
2fb6c34
trust remote code
evanatyourservice Dec 15, 2024
336e1e1
settings defaults
evanatyourservice Dec 15, 2024
a5ff351
no key, deterministic, pass all into cond, more sharding
evanatyourservice Dec 15, 2024
3a06e1c
set key in state
evanatyourservice Dec 15, 2024
f7f2382
whoops
evanatyourservice Dec 15, 2024
07781e6
small fix
evanatyourservice Dec 15, 2024
9ef0869
Update kron.py
evanatyourservice Dec 15, 2024
ed50cce
Update kron.py
evanatyourservice Dec 15, 2024
1dc0f43
settings
evanatyourservice Dec 15, 2024
f1c1b38
small fix in init sharding
evanatyourservice Dec 16, 2024
7a6f501
trying repl only
evanatyourservice Dec 16, 2024
3473eed
Revert "trying repl only"
evanatyourservice Dec 16, 2024
7684702
trying while loop
evanatyourservice Dec 16, 2024
d284518
trying more simple psgd kron version
evanatyourservice Dec 19, 2024
0c920b0
Update kron.py
evanatyourservice Dec 19, 2024
2feff32
Merge branch 'stanford-crfm:main' into kron
evanatyourservice Dec 19, 2024
6a2e19f
trying simple version
evanatyourservice Dec 19, 2024
5108be0
take out unavailable args
evanatyourservice Dec 19, 2024
975a2d7
no extra args
evanatyourservice Dec 19, 2024
3fa70ab
trying this
evanatyourservice Dec 19, 2024
b62963e
Update kron.py
evanatyourservice Dec 19, 2024
5bafdcf
Revert "Update kron.py"
evanatyourservice Dec 19, 2024
c47c4c5
small fix
evanatyourservice Dec 19, 2024
ee747c0
settings
evanatyourservice Dec 19, 2024
59f2c10
small changes/moving to remote
evanatyourservice Dec 22, 2024
aa43e4f
Merge remote-tracking branch 'upstream/main' into kron
evanatyourservice Dec 22, 2024
25a2c20
simplified kron is working, need to test on larger pod
evanatyourservice Dec 22, 2024
88d49ed
Update kron.py
evanatyourservice Dec 23, 2024
4d630d8
get rid of norming and clipping in lieu of rms clip, retouches
evanatyourservice Dec 31, 2024
964cf19
Merge remote-tracking branch 'upstream/main' into kron
evanatyourservice Dec 31, 2024
bc3c7db
Merge branch 'stanford-crfm:main' into kron
evanatyourservice Jan 18, 2025
small changes/moving to remote
evanatyourservice committed Dec 22, 2024
commit 59f2c10d40e31ffb732d66b42afd4bb4ac6dfe37
37 changes: 11 additions & 26 deletions src/levanter/optim/kron.py
@@ -18,7 +18,7 @@ class KronConfig(OptimizerConfig):
weight_decay: Weight decay coefficient.
max_grad_norm: Optional gradient norm clipping value.
normalize_grads: Whether to normalize the incoming gradients to unit norm layer-wise.
-Can help with stability.
+Can help with stability but likely not necessary in this scenario.
preconditioner_update_probability: Final probability of updating the preconditioner. Default
is 0.05 (update every 20 steps). The `precond_update_prob_schedule` holds probability at
1.0 for `update_prob_flat_start` steps before annealing exponentially down to this
@@ -50,20 +50,14 @@ class KronConfig(OptimizerConfig):
lax_map_batch_size: Batch size for lax.map, see JAX docs for more info.
merge_small_dims: Whether to merge small dimensions to improve preconditioner efficiency.
target_merged_dim_size: Target size of merged dimensions.
-partition_grads_into_blocks: Whether to partition grads into chunks of size block_size
-for efficiency.
-block_size: Block size to use for partitioning grads.
params_sharding: Pytree same structure as params of jax.sharding.PartitionSpec.
-preconditioner_sharding: PartitionSpec for preconditioner matrices. Best practice is to
-shard first dimension across fsdp-like mesh axis, or largest/most common axis in params.
-Example: PartitionSpec('fsdp') or PartitionSpec('fsdp', 'tp').
"""
# some of these are changed from kron defaults to better suit levanter
beta1: float = 0.9
weight_decay: float = 0.1
max_grad_norm: Optional[float] = 0.0
-normalize_grads: bool = True
-preconditioner_update_probability: float = 0.03
+normalize_grads: bool = False
+preconditioner_update_probability: float = 0.05
update_prob_flat_start: int = 1000
max_size_triangular: int = 25000
min_ndim_triangular: int = 2
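
For context on the probability schedule described in the docstring above: a minimal sketch of what such a schedule could look like, assuming hypothetical parameter names (flat_start, decay) and an arbitrary decay constant; the PR's actual precond_update_prob_schedule may differ. It holds the update probability at 1.0 for the flat-start steps, then anneals exponentially toward the final value (0.05 corresponds to roughly one preconditioner update every 20 steps).

import jax.numpy as jnp

def precond_update_prob_schedule(step, flat_start=1000, max_prob=1.0, min_prob=0.05, decay=0.001):
    # Hold at max_prob for `flat_start` steps, then decay exponentially toward min_prob.
    # Illustrative sketch only; constants and names here are assumptions.
    step = jnp.asarray(step, dtype=jnp.float32)
    prob = max_prob * jnp.exp(-decay * jnp.maximum(step - flat_start, 0.0))
    return jnp.clip(prob, min_prob, max_prob)
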
@@ -72,27 +66,19 @@ class KronConfig(OptimizerConfig):
preconditioner_init_scale: float = 1.0
mu_dtype: Optional[Union[str, jnp.dtype]] = None
precond_dtype: Optional[Union[str, jnp.dtype]] = None
-precond_update_precision: Optional[str] = "float32"
+precond_update_precision: Optional[str] = "tensorfloat32"
precond_grads_precision: Optional[str] = None
scanned_layers: Optional[optax.Params] = None
lax_map_scanned_layers: bool = False
lax_map_batch_size: int = 8
merge_small_dims: bool = True
target_merged_dim_size: int = 8192
-partition_grads_into_blocks: bool = True
-block_size: int = 256
params_sharding: Optional[Any] = None
-preconditioner_sharding: Optional[tuple[str | None, str | None]] = None

def build(self, num_train_steps):
"""Creates the optimizer."""

def _optimizer(learning_rate) -> optax.GradientTransformation:
-precond_partition_spec = (
-PartitionSpec(*self.preconditioner_sharding)
-if self.preconditioner_sharding is not None
-else None
-)
components = []
if self.max_grad_norm and not self.normalize_grads:
components.append(optax.clip_by_global_norm(self.max_grad_norm))
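
To illustrate the merge_small_dims / target_merged_dim_size settings kept above (defaults True and 8192): a rough sketch of dimension merging using a hypothetical helper, not the PR's implementation. Adjacent dimensions are greedily combined while their product stays at or below the target, so the preconditioner works on fewer, larger matrices.

def merge_small_dims(shape, target_size=8192):
    # Greedily merge adjacent dims whose product stays <= target_size.
    # Hypothetical sketch; the merging logic in kron.py may differ.
    if not shape:
        return shape
    merged = [shape[0]]
    for d in shape[1:]:
        if merged[-1] * d <= target_size:
            merged[-1] *= d
        else:
            merged.append(d)
    return tuple(merged)

# e.g. merge_small_dims((4, 64, 1024)) -> (256, 1024)
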
@@ -116,14 +102,15 @@ def _optimizer(learning_rate) -> optax.GradientTransformation:
scanned_layers=self.scanned_layers,
lax_map_scanned_layers=self.lax_map_scanned_layers,
lax_map_batch_size=self.lax_map_batch_size,
-# merge_small_dims=self.merge_small_dims,
-# target_merged_dim_size=self.target_merged_dim_size,
-# partition_grads_into_blocks=self.partition_grads_into_blocks,
-# block_size=self.block_size,
-# params_sharding=self.params_sharding,
-# preconditioner_sharding=precond_partition_spec,
+merge_small_dims=self.merge_small_dims,
+target_merged_dim_size=self.target_merged_dim_size,
+params_sharding=self.params_sharding,
)
)
+# PSGD's output should be RMS=1.0, so we can clip at 1.1 in case of
+# gradient spike. This is better than clipping incoming grads because this
+# gets rid of information for the preconditioner.
+components.append(optax.clip_by_block_rms(1.1))
if self.weight_decay > 0:
components.append(
optax.add_decayed_weights(
@@ -143,11 +130,9 @@ def _optimizer(learning_rate) -> optax.GradientTransformation:
import string
import numpy as np

import chex
import jax
from jax import vmap
import jax.numpy as jnp
import flax.linen as nn
from optax import tree_utils as otu
from optax._src import base, transform
from optax._src.numerics import safe_int32_increment
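
Putting the hunks above together, build() assembles an optax chain roughly like the following sketch (a hypothetical reconstruction, not the exact code; kron_transform stands in for the Kron preconditioner transformation whose constructor is not shown in this diff): optional global-norm clipping of incoming grads, the Kron update, an RMS clip at 1.1 since PSGD's output is expected to have RMS close to 1.0, then weight decay and the learning rate.

import optax

def build_kron_chain(kron_transform, learning_rate, weight_decay=0.1, max_grad_norm=None):
    # Hypothetical reconstruction of the chain assembled in KronConfig.build().
    components = []
    if max_grad_norm:
        components.append(optax.clip_by_global_norm(max_grad_norm))
    components.append(kron_transform)
    # PSGD output should have RMS near 1.0, so clip at 1.1 to absorb occasional spikes.
    components.append(optax.clip_by_block_rms(1.1))
    if weight_decay > 0:
        components.append(optax.add_decayed_weights(weight_decay))
    components.append(optax.scale_by_learning_rate(learning_rate))
    return optax.chain(*components)

A training config would then only need to set the KronConfig fields shown above; build() assembles the chain internally.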