WIP Kron #843

Draft · wants to merge 67 commits into base: main
Changes from 1 commit
86d6556
logs
Jan 1, 2024
8b0900f
try dataset path
WhenWen Jan 1, 2024
5bf8cd0
use the jax serialization manager for deser in an attempt to fix cras…
dlwh Oct 12, 2024
20a4568
cleaner data loader but it doesn't help :(
dlwh Oct 13, 2024
2747705
ok this maybe fixed it?
dlwh Oct 14, 2024
538f0ed
cleanup
dlwh Oct 14, 2024
2c5ee4b
fix tests
dlwh Oct 14, 2024
ab543d6
fix what is probably the underlying problem
dlwh Oct 14, 2024
6395305
wip
dlwh Oct 14, 2024
b8f28f4
Merge branch 'main' of github.com:WhenWen/levanter-2024 into main
Dec 3, 2024
f98e376
Implement MARS (tested) and Muon (have bug in saving), example config…
Dec 3, 2024
e961c95
Implement MARS (tested) and Muon (have bug in saving), example config…
Dec 3, 2024
2b80af7
wip
dlwh Dec 3, 2024
87d7665
enough device puts and we're good
dlwh Dec 3, 2024
074d0ec
ok we're good
dlwh Dec 3, 2024
5668289
Merge remote-tracking branch 'origin/main' into WhenWen/main
dlwh Dec 3, 2024
d2d310e
Merge branch 'use_manager_deser' into WhenWen/main
dlwh Dec 3, 2024
722edaf
fix tree leaf stuff
dlwh Dec 3, 2024
5692611
add map_flattened_linear_layers use in muon
dlwh Dec 4, 2024
2f119bd
Merge remote-tracking branch 'origin/main' into muon
dlwh Dec 4, 2024
0f41ebb
adding kron file to optim
evanatyourservice Dec 4, 2024
fe3ecc9
testing 123
evanatyourservice Dec 5, 2024
37452c7
Update kron.py
evanatyourservice Dec 5, 2024
701956d
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 5, 2024
aac1cee
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 5, 2024
e44e7fa
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 5, 2024
476ba36
Update kron.py
evanatyourservice Dec 6, 2024
966b80e
expose precond lr and init
evanatyourservice Dec 7, 2024
6408d23
Merge branch 'stanford-crfm:main' into kron
evanatyourservice Dec 7, 2024
91e29e7
Update kron.py
evanatyourservice Dec 10, 2024
6b9ce26
Merge remote-tracking branch 'upstream/main' into kron
evanatyourservice Dec 14, 2024
53efaa3
Update llama2_100M_kron_test.yaml
evanatyourservice Dec 14, 2024
311da92
Update README.md
evanatyourservice Dec 14, 2024
33c17f4
Update kron.py
evanatyourservice Dec 14, 2024
8da6e34
Update kron.py
evanatyourservice Dec 14, 2024
5607cec
Update kron.py
evanatyourservice Dec 14, 2024
2fb6c34
trust remote code
evanatyourservice Dec 15, 2024
336e1e1
settings defaults
evanatyourservice Dec 15, 2024
a5ff351
no key, deterministic, pass all into cond, more sharding
evanatyourservice Dec 15, 2024
3a06e1c
set key in state
evanatyourservice Dec 15, 2024
f7f2382
whoops
evanatyourservice Dec 15, 2024
07781e6
small fix
evanatyourservice Dec 15, 2024
9ef0869
Update kron.py
evanatyourservice Dec 15, 2024
ed50cce
Update kron.py
evanatyourservice Dec 15, 2024
1dc0f43
settings
evanatyourservice Dec 15, 2024
f1c1b38
small fix in init sharding
evanatyourservice Dec 16, 2024
7a6f501
trying repl only
evanatyourservice Dec 16, 2024
3473eed
Revert "trying repl only"
evanatyourservice Dec 16, 2024
7684702
trying while loop
evanatyourservice Dec 16, 2024
d284518
trying more simple psgd kron version
evanatyourservice Dec 19, 2024
0c920b0
Update kron.py
evanatyourservice Dec 19, 2024
2feff32
Merge branch 'stanford-crfm:main' into kron
evanatyourservice Dec 19, 2024
6a2e19f
trying simple version
evanatyourservice Dec 19, 2024
5108be0
take out unavailable args
evanatyourservice Dec 19, 2024
975a2d7
no extra args
evanatyourservice Dec 19, 2024
3fa70ab
trying this
evanatyourservice Dec 19, 2024
b62963e
Update kron.py
evanatyourservice Dec 19, 2024
5bafdcf
Revert "Update kron.py"
evanatyourservice Dec 19, 2024
c47c4c5
small fix
evanatyourservice Dec 19, 2024
ee747c0
settings
evanatyourservice Dec 19, 2024
59f2c10
small changes/moving to remote
evanatyourservice Dec 22, 2024
aa43e4f
Merge remote-tracking branch 'upstream/main' into kron
evanatyourservice Dec 22, 2024
25a2c20
simplified kron is working, need to test on larger pod
evanatyourservice Dec 22, 2024
88d49ed
Update kron.py
evanatyourservice Dec 23, 2024
4d630d8
get rid of norming and clipping in lieu of rms clip, retouches
evanatyourservice Dec 31, 2024
964cf19
Merge remote-tracking branch 'upstream/main' into kron
evanatyourservice Dec 31, 2024
bc3c7db
Merge branch 'stanford-crfm:main' into kron
evanatyourservice Jan 18, 2025
expose precond lr and init
evanatyourservice committed Dec 7, 2024
commit 966b80e332509674b6cd8720ab89946572008f88
18 changes: 16 additions & 2 deletions src/levanter/optim/kron.py
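The two new fields let users set the preconditioner hyperparameters from a training config instead of relying on the values hardcoded inside `scale_by_kron`. A hypothetical YAML fragment showing how they might be set (field names follow `KronConfig` in this diff; the exact nesting under `optimizer` is an assumption, not taken from this PR):

```yaml
optimizer:
  type: kron
  learning_rate: 3e-4
  preconditioner_lr: 0.1          # default; step size for the preconditioner update
  preconditioner_init_scale: 1.0  # default; scale of the initial preconditioner
```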
@@ -35,6 +35,8 @@ class KronConfig(OptimizerConfig):
 - None: All preconditioners are triangular (default)
 - 'one_diag': Largest/last dim per layer uses diagonal preconditioner
 - 'all_diag': All preconditioners are diagonal
+preconditioner_lr: Learning rate for preconditioner.
+preconditioner_init_scale: Scale for preconditioner initialization.
 mu_dtype: Dtype of the momentum buffer. Defaults to same dtype as parameters.
 precond_dtype: Dtype of the preconditioners. Defaults to 'float32'.
 precond_update_precision: Precision for matmul during preconditioner update.
@@ -66,6 +68,8 @@ class KronConfig(OptimizerConfig):
 max_size_triangular: int = 10000
 min_ndim_triangular: int = 2
 memory_save_mode: Optional[str] = None
+preconditioner_lr: float = 0.1
+preconditioner_init_scale: float = 1.0
 mu_dtype: Optional[Union[str, jnp.dtype]] = None
 precond_dtype: Optional[Union[str, jnp.dtype]] = None
 precond_update_precision: Optional[str] = "tensorfloat32"
@@ -103,6 +107,8 @@ def _optimizer(learning_rate) -> optax.GradientTransformation:
 max_size_triangular=self.max_size_triangular,
 min_ndim_triangular=self.min_ndim_triangular,
 memory_save_mode=self.memory_save_mode,
+preconditioner_lr=self.preconditioner_lr,
+preconditioner_init_scale=self.preconditioner_init_scale,
 mu_dtype=self.mu_dtype,
 precond_dtype=self.precond_dtype,
 precond_update_precision=self.precond_update_precision,
@@ -194,6 +200,8 @@ def scale_by_kron(
 max_size_triangular: int = 8192,
 min_ndim_triangular: int = 2,
 memory_save_mode: Optional[str] = None,
+preconditioner_lr: float = 0.1,
+preconditioner_init_scale: float = 1.0,
 mu_dtype: Optional[Union[str, jnp.dtype]] = None,
 precond_dtype: Optional[Union[str, jnp.dtype]] = None,
 precond_update_precision: Optional[str] = "tensorfloat32",
@@ -225,6 +233,8 @@ def scale_by_kron(
 to set all preconditioners to be triangular, 'one_diag' sets the largest
 or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
 to be diagonal.
+preconditioner_lr: float, learning rate for preconditioner.
+preconditioner_init_scale: float, scale for preconditioner initialization.
 mu_dtype: optional str or jnp.dtype, dtype of the momentum buffer. Defaults to
 same dtype as the parameters.
 precond_dtype: optional str or jnp.dtype, dtype of the preconditioners. Defaults
@@ -258,8 +268,6 @@ def scale_by_kron(
 """
 mu_dtype = canonicalize_dtype(mu_dtype)
 precond_dtype = canonicalize_dtype(precond_dtype or jnp.float32)
-preconditioner_lr = 0.1
-preconditioner_init_scale = 1.0
 lax_map = lax_map_scanned_layers
 bs = lax_map_batch_size

@@ -976,6 +984,8 @@ def kron(
 max_size_triangular: int = 8192,
 min_ndim_triangular: int = 2,
 memory_save_mode: Optional[str] = None,
+preconditioner_lr: float = 0.1,
+preconditioner_init_scale: float = 1.0,
 mu_dtype: Optional[Union[str, jnp.dtype]] = None,
 precond_dtype: Optional[Union[str, jnp.dtype]] = None,
 precond_update_precision: Optional[str] = "tensorfloat32",
@@ -1011,6 +1021,8 @@ def kron(
 to set all preconditioners to be triangular, 'one_diag' sets the largest
 or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
 to be diagonal.
+preconditioner_lr: float, learning rate for preconditioner.
+preconditioner_init_scale: float, scale for preconditioner initialization.
 mu_dtype: optional str or jnp.dtype, dtype of the momentum buffer. Defaults to
 same dtype as the parameters.
 precond_dtype: optional str or jnp.dtype, dtype of the preconditioners. Defaults
@@ -1050,6 +1062,8 @@ def kron(
 max_size_triangular=max_size_triangular,
 min_ndim_triangular=min_ndim_triangular,
 memory_save_mode=memory_save_mode,
+preconditioner_lr=preconditioner_lr,
+preconditioner_init_scale=preconditioner_init_scale,
 mu_dtype=mu_dtype,
 precond_dtype=precond_dtype,
 precond_update_precision=precond_update_precision,
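The pattern in this commit is: promote hardcoded constants to config fields with the same defaults, then thread them through every call site down to the inner transform. A minimal sketch of that pattern, without JAX (names suffixed `_sketch` are hypothetical stand-ins, not levanter or optax APIs):

```python
from dataclasses import dataclass


@dataclass
class KronConfigSketch:
    # Newly exposed hyperparameters; defaults match the values that
    # were previously hardcoded inside the optimizer body.
    preconditioner_lr: float = 0.1
    preconditioner_init_scale: float = 1.0


def scale_by_kron_sketch(preconditioner_lr: float = 0.1,
                         preconditioner_init_scale: float = 1.0) -> dict:
    # Stand-in for the real gradient transformation: it just records
    # which hyperparameters it was constructed with.
    return {
        "preconditioner_lr": preconditioner_lr,
        "preconditioner_init_scale": preconditioner_init_scale,
    }


# Config value flows through to the transform instead of being fixed at 0.1.
cfg = KronConfigSketch(preconditioner_lr=0.05)
tx = scale_by_kron_sketch(
    preconditioner_lr=cfg.preconditioner_lr,
    preconditioner_init_scale=cfg.preconditioner_init_scale,
)
```

Keeping the old constants as defaults means existing configs that never mention the new fields behave exactly as before.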