Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new ResourceEnvs from Haliax #444

Open
wants to merge 236 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
236 commits
Select commit Hold shift + click to select a range
fff4dfb
wip
dlwh Oct 25, 2023
740ad68
wip
dlwh Nov 7, 2023
4ad74a6
almost got new logger working
dlwh Nov 7, 2023
ad708e3
move the metrics stuff to its own file
dlwh Nov 8, 2023
6930fa9
refactor and move stuff around
dlwh Nov 8, 2023
abf7ec3
use generic infrastructure for summary
dlwh Nov 8, 2023
547cea8
wip towards a clean tracker package
dlwh Nov 8, 2023
2f481ed
wip
dlwh Nov 9, 2023
0b080fb
remove more wandb deps
dlwh Nov 9, 2023
a324ae5
tiny cleanup
dlwh Nov 9, 2023
cfdcbb9
add some tests
dlwh Nov 9, 2023
2ddc558
migrate alpaca-lora to new logger
dlwh Nov 9, 2023
9b0df08
sort of get tb to work
dlwh Nov 10, 2023
4fd2526
wip
dlwh Nov 14, 2023
a608a65
wip
dlwh Nov 16, 2023
176e5fa
Merge remote-tracking branch 'origin/main' into generic_logger
dlwh Nov 17, 2023
8d34f6f
update configs, expose a method to find trackers
dlwh Nov 17, 2023
42d7f2c
use `trainer` more to set logging
dlwh Nov 17, 2023
b887761
test the tracker get name stuff
dlwh Nov 17, 2023
3ebd161
minor
dlwh Nov 17, 2023
0d2efbc
making speccing the loss function simpler
dlwh Nov 18, 2023
f085287
stop requiring a loss function for every model definition
dlwh Nov 18, 2023
f21cf4b
wip
dlwh Nov 19, 2023
01c8b87
jkacjkac
dlwh Nov 19, 2023
e374697
tweak
dlwh Nov 22, 2023
921acf8
register default hooks by default...
dlwh Nov 22, 2023
c8a5d6c
wip
dlwh Nov 24, 2023
639d334
make it so we can evaluate if we have a cache but no sources
dlwh Nov 24, 2023
a3fdbaf
Merge branch 'cache_only' into extensible_trainer
dlwh Nov 24, 2023
ec35e9b
about got the checkpoint refactor done
dlwh Nov 25, 2023
ed13502
about got the checkpoint refactor done
dlwh Nov 25, 2023
634407e
minor dead code removal
dlwh Nov 25, 2023
4208e03
fix tests
dlwh Nov 26, 2023
9584884
cleanup
dlwh Nov 26, 2023
5a18678
cleanup
dlwh Nov 26, 2023
c355106
minor
dlwh Nov 26, 2023
7a2ffc3
Merge branch 'extensible_trainer' into doremi
dlwh Nov 26, 2023
d2e0de1
wip
dlwh Nov 26, 2023
be99631
register default hooks by default...
dlwh Nov 22, 2023
5d033eb
wip
dlwh Nov 24, 2023
c4a9160
make it so we can evaluate if we have a cache but no sources
dlwh Nov 24, 2023
b888065
about got the checkpoint refactor done
dlwh Nov 25, 2023
c47ae97
about got the checkpoint refactor done
dlwh Nov 25, 2023
f0613c7
minor dead code removal
dlwh Nov 25, 2023
85c5678
fix tests
dlwh Nov 26, 2023
8f84822
cleanup
dlwh Nov 26, 2023
e54bad0
cleanup
dlwh Nov 26, 2023
85dd89b
minor
dlwh Nov 26, 2023
c61824e
generalize and extract the checkpoint loading logic so it can be used…
dlwh Nov 27, 2023
7391475
Revert "Temporarily Revert "Generic Tracker interface, support for TB…
dlwh Nov 28, 2023
2387f26
wip
dlwh Nov 28, 2023
6446bc0
just about workable logger stuff
dlwh Nov 28, 2023
1b821d1
fix logging of config with a new levanter.initialize
dlwh Nov 28, 2023
afb6459
missed a sopt
dlwh Nov 28, 2023
9d916bd
on second thought, don't use tb in small_fast
dlwh Nov 29, 2023
3d67552
Merge remote-tracking branch 'origin/dev' into extensible_trainer
dlwh Nov 30, 2023
4d8cd68
main->dev (#375)
dlwh Dec 1, 2023
272e1e1
Merge remote-tracking branch 'origin/dev' into extensible_trainer
dlwh Dec 1, 2023
48ccdd3
Merge remote-tracking branch 'origin/main' into dev
dlwh Dec 1, 2023
3b27a08
supporting new trainer in gsm8k example
dlwh Dec 1, 2023
dcbed88
Merge branch 'dev' into extensible_trainer
dlwh Dec 2, 2023
bbac4ef
Merge remote-tracking branch 'origin/main' into extensible_trainer
dlwh Dec 2, 2023
f2842e9
Add Sophia-H, some WIP support for Sophia-G (#372)
dlwh Dec 7, 2023
6d6ae21
Merge remote-tracking branch 'origin/main' into dev
dlwh Dec 10, 2023
83bea6e
fix missing test changes
dlwh Dec 10, 2023
92a615f
should use a tempdir
dlwh Dec 11, 2023
cbee427
update gsm8k lora for sophia refactors
dlwh Dec 11, 2023
e048581
Allow val change wandb dev (#384)
dlwh Dec 13, 2023
2bdf08b
oops
dlwh Dec 13, 2023
8f4aff3
do loss in fp32
dlwh Dec 14, 2023
91eb588
Merge remote-tracking branch 'origin/dev' into extensible_trainer
dlwh Dec 17, 2023
2002832
more dead code removal
dlwh Dec 17, 2023
efa70a1
refix merge issues
dlwh Dec 17, 2023
4cca0d1
refix merge issues
dlwh Dec 17, 2023
15e223d
Merge remote-tracking branch 'origin/main' into dev
dlwh Dec 19, 2023
904497b
Merge branch 'dev' into extensible_trainer
dlwh Dec 19, 2023
2a90f57
allow train_batch_size to be -1 if per_device_parallelism isn't -1
dlwh Dec 19, 2023
f05739a
wip
dlwh Dec 21, 2023
321bb30
Merge remote-tracking branch 'origin/main' into extensible_trainer
dlwh Dec 21, 2023
38db3d5
fix performance regression in trainer.py
dlwh Dec 21, 2023
9b2813b
wth
dlwh Dec 21, 2023
e014c45
mdkladmlkad
dlwh Dec 21, 2023
95a391f
jfakmfa
dlwh Dec 21, 2023
9f40f10
try this other approach to steps in TrainerState
dlwh Dec 21, 2023
6df53f4
fix checkpoint tests
dlwh Dec 21, 2023
94aa8fa
fix gsm8k
dlwh Dec 21, 2023
5af6cb2
update for new Haliax reduction functions
dlwh Dec 24, 2023
e2b086f
Merge branch 'extensible_trainer' into doremi
dlwh Dec 24, 2023
84d3b33
wip
dlwh Dec 24, 2023
85b42b0
refactor grad_accum to have a separate microbatched
dlwh Dec 25, 2023
c47c188
remove accumulate_gradients_sharded and just use microbatched directly
dlwh Dec 27, 2023
70b766f
add dtype for grad accum
dlwh Dec 27, 2023
57725ea
small refactor
dlwh Dec 27, 2023
85f777b
small refactor
dlwh Dec 27, 2023
f8d98fc
fix key handling in grad accum
dlwh Dec 27, 2023
5a8c77a
make sophia work with non-trainables again
dlwh Dec 28, 2023
ff59e51
factor out some methods in train_step
dlwh Dec 28, 2023
c1718dd
Merge branch 'extensible_trainer' into doremi
dlwh Dec 28, 2023
d7a060d
make the initialize_from logic just use load_checkpoint_or_initialize
dlwh Dec 29, 2023
8c44e64
on second thought load_from_checkpoint_or_initialize is the wrong abs…
dlwh Dec 30, 2023
72f1e47
wip
dlwh Dec 30, 2023
add3df4
on second thought load_from_checkpoint_or_initialize is the wrong abs…
dlwh Dec 30, 2023
3ba7bf1
Merge branch 'extensible_trainer' into doremi
dlwh Dec 30, 2023
b6535b5
wip factoring out the initial state stuff, again
dlwh Dec 30, 2023
0d6f357
almost ready to try out doremi
dlwh Dec 30, 2023
7395e3c
almost ready to try out doremi
dlwh Jan 2, 2024
08996e6
cleanup typing.overloads
dlwh Jan 3, 2024
710900c
use auto_sharded internally, undeprecate it b/c it has a point
dlwh Jan 3, 2024
5f9d96d
fix docs
dlwh Jan 4, 2024
04a74a1
use new dot syntax in doremi
dlwh Jan 4, 2024
3249ca1
Merge remote-tracking branch 'origin/main' into doremi
dlwh Jan 8, 2024
6a20c95
fix mixture init with prngkey
dlwh Jan 9, 2024
fd6d343
add a simple InMemoryDataset that takes a list
dlwh Jan 9, 2024
f5b8d00
make keyiterator support just an int seed
dlwh Jan 9, 2024
288e7fb
dumb bug in grad accum
dlwh Jan 9, 2024
c4da125
fix some dumb bugs in new trainer
dlwh Jan 9, 2024
9257597
test for doremi and associated fixes
dlwh Jan 9, 2024
317b10d
depend on haliax dev for levanter dev
dlwh Jan 9, 2024
e4d1385
fix gsm8k_lora
dlwh Jan 9, 2024
ddcdac7
add a small_pile configuration
dlwh Jan 9, 2024
792f769
make it len 2048
dlwh Jan 9, 2024
e16b3af
add doremi main
dlwh Jan 10, 2024
a272ca9
we install haliax from source with the pyprojec.toml
dlwh Jan 10, 2024
e8d4b9d
fix doremi test when doing multidevice
dlwh Jan 10, 2024
5c489c1
add a pile_mixture.yaml
dlwh Jan 10, 2024
1672148
add a config for the small pile mixture
dlwh Jan 10, 2024
f485c5f
reduce default rows per chunk and see if that helps with these big su…
dlwh Jan 10, 2024
b2d8a58
add some more logging to see if we can figure out why it's running ou…
dlwh Jan 10, 2024
f76e466
add some more logging to see if we can figure out why it's running ou…
dlwh Jan 11, 2024
fc78716
dumb
dlwh Jan 11, 2024
4927f67
don't run the slow tests in CI
dlwh Jan 11, 2024
1ceb00a
wip
dlwh Jan 12, 2024
bc7108c
move the script, make it read off fsspec
dlwh Jan 13, 2024
69ca4a4
update for reverted Haliax change
dlwh Jan 13, 2024
ff5cb6d
update for reverted Haliax change
dlwh Jan 13, 2024
d6bf2c0
update paths for pile mixture
dlwh Jan 15, 2024
cc6044c
fix new import
dlwh Jan 15, 2024
415158a
sigh
dlwh Jan 15, 2024
d2a90ae
isjfo
dlwh Jan 15, 2024
058a9e0
mdklmdlm
dlwh Jan 15, 2024
9f16fbe
make logging list names of caches
dlwh Jan 15, 2024
b80ef6a
lower resource requirements to see if this gets us processing faster
dlwh Jan 15, 2024
6983ff0
let's make the chunkcachebuilders free
dlwh Jan 15, 2024
5f42ad8
minimize use of optax internals
dlwh Jan 15, 2024
e6e8d27
fix a crash i don't understand
dlwh Jan 16, 2024
ab29e92
let's reduce requirements some more to see if we can keep everything …
dlwh Jan 16, 2024
83f0616
let's reduce requirements some more to see if we can keep everything …
dlwh Jan 16, 2024
def45cc
silly
dlwh Jan 16, 2024
de821ca
ok so we're ok maybe
dlwh Jan 16, 2024
cbddab8
don't fetch local
dlwh Jan 16, 2024
5d0f987
wtf
dlwh Jan 16, 2024
13cc556
what
dlwh Jan 16, 2024
5afac01
ok, think we figured it out
dlwh Jan 16, 2024
41ac362
less logging
dlwh Jan 16, 2024
c621a08
toward turning the reader process into an actor too
dlwh Jan 17, 2024
4d92af9
did we do it?
dlwh Jan 17, 2024
257dfa7
wandb: only force a step if commit is true
dlwh Jan 17, 2024
23865a1
don't crash if n == 0
dlwh Jan 17, 2024
70c00f1
wandb: maybe this gives the behavior i want?
dlwh Jan 17, 2024
4c54365
mklafmlkafml
dlwh Jan 17, 2024
1edeeef
Merge branch 'main' into dev
dlwh Jan 17, 2024
6148381
minimize use of optax internals
dlwh Jan 15, 2024
8274cad
what
dlwh Jan 18, 2024
b980c9f
actually this is probably better
dlwh Jan 18, 2024
36f25a0
actually this is probably better
dlwh Jan 18, 2024
4da7112
dumb
dlwh Jan 18, 2024
6147520
mkladmlkad
dlwh Jan 18, 2024
e166a78
fix key order for doremi
dlwh Jan 18, 2024
e6b581b
remove excess log
dlwh Jan 18, 2024
8c64be5
remove a redundant log message
dlwh Jan 18, 2024
e89e709
fixed more bugs
dlwh Jan 18, 2024
33600fd
almost there
dlwh Jan 18, 2024
efbdd31
don't log a value for domains with no data on a step
dlwh Jan 18, 2024
a810242
bring over the trainer-abstraction doc
dlwh Jan 30, 2024
e49fb38
remove the wrapped loss_fn thing from trainer
dlwh Jan 30, 2024
13dc392
factor out a take_opt_step. need to decide where to put it
dlwh Jan 30, 2024
514da05
explicitly expose microbatch_size, use it in microbatched
dlwh Jan 30, 2024
f797a85
comment about custom_jvp on microbatched
dlwh Jan 31, 2024
4301930
unneeded cast
dlwh Jan 31, 2024
d3416b1
rename to mixed-precision.md
dlwh Jan 31, 2024
9552909
cleanup ctors for BatchLoaders some
dlwh Jan 31, 2024
888d35e
misc cleanup
dlwh Jan 31, 2024
49a409b
wip
dlwh Jan 31, 2024
78d9342
stable point: migrating to resourceenvs
dlwh Jan 31, 2024
8e4e183
require the jamp branch
dlwh Jan 31, 2024
9a0ea6d
knknajkdnjakd
dlwh Jan 31, 2024
7c19f47
try this?
dlwh Jan 31, 2024
d98a885
cleanup and explain the issue
dlwh Jan 31, 2024
015dfb3
see if we get the just-in-time conversion to bf16 that we want
dlwh Jan 31, 2024
cddaf20
wtf
dlwh Feb 1, 2024
27949f8
bypass microbatching if we don't need it?
dlwh Feb 1, 2024
c3a9ce1
switch to using hnn.Embedding in gpt2, which means we get the mixed p…
dlwh Feb 1, 2024
0e91352
switch to using compute_envs where posisble use .shard instead
dlwh Feb 1, 2024
b57e1c7
please pre-commit
dlwh Feb 1, 2024
7fd46cb
ok maybe we can do it?
dlwh Feb 1, 2024
2ca4d97
sigh
dlwh Feb 1, 2024
b1e99e5
Merge branch 'dev' into use_jamp
dlwh Feb 1, 2024
a237a57
fix test_weight_decay_mask.py
dlwh Feb 1, 2024
5282694
use param_env everywhere
dlwh Feb 1, 2024
a013c4c
makldmlkad
dlwh Feb 1, 2024
3049f89
Merge remote-tracking branch 'origin/main' into use_jamp
dlwh Feb 2, 2024
e4fcd67
Merge remote-tracking branch 'origin/main' into dev
dlwh Feb 2, 2024
4a8d07a
Merge branch 'dev' into use_jamp
dlwh Feb 2, 2024
1983a1f
Merge remote-tracking branch 'origin/main' into dev
dlwh Feb 2, 2024
5312b87
Merge remote-tracking branch 'origin/main' into dev
dlwh Feb 2, 2024
90ed9cd
Merge remote-tracking branch 'origin/main' into use_jamp
dlwh Feb 2, 2024
74096fa
Merge branch 'dev' into use_jamp
dlwh Feb 2, 2024
58ca1d7
wip debugging devices
dlwh Feb 2, 2024
6ee6d8f
let's try this?
dlwh Feb 2, 2024
b485673
so confused
dlwh Feb 2, 2024
de3162b
sigh
dlwh Feb 2, 2024
343367f
ok i think i got it
dlwh Feb 2, 2024
1198bb2
Merge branch 'dev' into simple_first_cleanup
dlwh Feb 3, 2024
a2d5934
Merge branch 'simple_first_cleanup' into use_jamp
dlwh Feb 3, 2024
71b755e
wtf
dlwh Feb 5, 2024
07b5797
this async seems like a bad idea
dlwh Feb 5, 2024
b6e0c1d
log perf numbers?
dlwh Feb 5, 2024
a3f9c7f
more logging
dlwh Feb 5, 2024
ce2db7b
moar
dlwh Feb 5, 2024
ea57bde
oops
dlwh Feb 5, 2024
d352b37
reduce logging some, try to figure out this stupid books problem
dlwh Feb 5, 2024
0effef0
ka dkla dkl
dlwh Feb 5, 2024
1e85d16
admaldl
dlwh Feb 5, 2024
a25a8ce
fix the unnecessarily long time outs
dlwh Feb 6, 2024
7c163a8
break really long docs into shorter docs b/c tokenizers is quadratic
dlwh Feb 6, 2024
4125d3f
kmklamdklad
dlwh Feb 6, 2024
99a87e8
maybe don't do the workaround so often?
dlwh Feb 6, 2024
5245e10
is this the leak?!?
dlwh Feb 6, 2024
8a6f59b
update for latest datasets
dlwh Feb 6, 2024
002989b
add a test to ensure we use the workaround for llama tokenizer
dlwh Feb 6, 2024
3dfebe2
tweak timeouts in test
dlwh Feb 6, 2024
5a5a1f1
less spammy logging
dlwh Feb 6, 2024
4e6df52
cleanup, see if we can avoid crashing when one cache finishes
dlwh Feb 6, 2024
c2dccf2
tweaks to tokenization/shard_cache throughput (#456)
dlwh Feb 6, 2024
f7a3d0a
Merge remote-tracking branch 'origin/main' into use_jamp_harfleur
dlwh Feb 6, 2024
c95ba8b
Merge branch 'use_jamp_harfleur' into use_jamp
dlwh Feb 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
about got the checkpoint refactor done
  • Loading branch information
dlwh committed Nov 25, 2023
commit ec35e9bf5a4e920bd6fc2235efa693c83b00c826
131 changes: 65 additions & 66 deletions src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
@@ -56,13 +56,13 @@ class Checkpointer:
_last_temporary_checkpoint: Optional[str] = None

def __init__(
self,
base_path: PathLike,
save_interval: Optional[datetime.timedelta],
step_policies: Sequence[CheckpointInterval],
*,
keep_params: PyTree[FilterSpec] = True,
dt_now_injection: Optional[Callable[[], datetime.datetime]] = None,
self,
base_path: PathLike,
save_interval: Optional[datetime.timedelta],
step_policies: Sequence[CheckpointInterval],
*,
keep_params: PyTree[FilterSpec] = True,
dt_now_injection: Optional[Callable[[], datetime.datetime]] = None,
):
"""
Class for managing checkpoints. Saves checkpoints according to two policies: time and step.
@@ -102,38 +102,36 @@ def __init__(
raise ValueError("Step policies must be sorted by 'until' value")

def load_checkpoint(
self,
state: M,
path: Optional[PathLike] = None,
*,
discover_latest: bool = True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[haliax.partitioning.Mesh] = None,
) -> Optional[Tuple[M, int]]:
self,
state: M,
path: Optional[PathLike] = None,
*,
discover_latest: bool = True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[haliax.partitioning.Mesh] = None,
) -> Optional[M]:
if path is None:
path = self.base_path
return load_checkpoint(
state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh
)
return load_checkpoint(state, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh)

def load_model(
self,
model: M,
path: Optional[str] = None,
*,
discover_latest: bool = True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[haliax.partitioning.Mesh] = None,
) -> Optional[Tuple[M, int]]:
self,
model: M,
path: Optional[str] = None,
*,
discover_latest: bool = True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[haliax.partitioning.Mesh] = None,
) -> Optional[M]:
"""
Convenience method/holdover from previous API for loading checkpoints.
Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint.
Convenience method/holdover from previous API for loading checkpoints.
Loads just the model assuming the model is in the `model` subdir of the discovered checkpoint.
"""
ret_dict = self.load_checkpoint({"model": model},
path,
discover_latest=discover_latest,
axis_mapping=axis_mapping,
mesh=mesh)
ret_dict = self.load_checkpoint(
{"model": model}, path, discover_latest=discover_latest, axis_mapping=axis_mapping, mesh=mesh
)
if ret_dict is None:
return None
return ret_dict["model"]

def on_step(self, info, force: bool = False):
@@ -219,7 +217,7 @@ def _rm_checkpoint(self, checkpoint):
def save_checkpoint(self, info, destination: str):
path = os.path.join(self.base_path, destination)
logger.info(f"Saving checkpoint at step {info.step} to {path}")
state = equinox.partition(info.state, info.state.is_trainable)
state = equinox.filter(info.state, info.state.is_trainable)
save_checkpoint(
state,
step=info.step,
@@ -275,13 +273,13 @@ def save_metadata(checkpoint_path, fs, step):


def load_checkpoint(
tree: M,
checkpoint_path: PathLike,
*,
discover_latest=True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[jax.sharding.Mesh] = None,
) -> Optional[tuple[M, int]]:
tree: M,
checkpoint_path: PathLike,
*,
discover_latest=True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[jax.sharding.Mesh] = None,
) -> Optional[M]:
fs: AbstractFileSystem
fs, _ = _get_fs_and_plain_path(checkpoint_path)

@@ -298,8 +296,10 @@ def load_checkpoint(

try:
tree = tree_deserialize_leaves_tensorstore(checkpoint_path, tree, axis_mapping=axis_mapping, mesh=mesh)
return tree
except: # noqa
from levanter.trainer import TrainerState

if not isinstance(tree, TrainerState):
raise
else:
@@ -315,30 +315,29 @@ def load_checkpoint(
key = None
else:
training_state = tree_deserialize_leaves_tensorstore(
os.path.join(checkpoint_path, "training_state"), training_state, axis_mapping=axis_mapping, mesh=mesh
os.path.join(checkpoint_path, "training_state"),
training_state,
axis_mapping=axis_mapping,
mesh=mesh,
)
opt_state, key = training_state

# TODO: pretty sure this is right, but should verify
step = metadata["step"]
new_state = dataclasses.replace(
tree, # type: ignore
step=step + 1,
model=model,
opt_state=opt_state,
training_key=key)
return new_state, step

tree, step=step + 1, model=model, opt_state=opt_state, training_key=key # type: ignore
)
return new_state


def _old_load_checkpoint(
model: M,
training_state: S,
checkpoint_path: PathLike,
*,
discover_latest=True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[jax.sharding.Mesh] = None,
model: M,
training_state: S,
checkpoint_path: PathLike,
*,
discover_latest=True,
axis_mapping: Optional[haliax.partitioning.ResourceMapping] = None,
mesh: Optional[jax.sharding.Mesh] = None,
) -> Optional[Tuple[M, S, int]]:
"""
Load a checkpoint from a given path.
@@ -422,10 +421,10 @@ def checkpoint_sort_key(ckpt_dir):


def tree_serialise_leaves(
path: PathLike,
pytree: PyTree,
filter_spec=default_serialise_filter_spec,
is_leaf: Optional[Callable[[Any], bool]] = None,
path: PathLike,
pytree: PyTree,
filter_spec=default_serialise_filter_spec,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> None:
"""Analog to `equinox.tree_serialise_leaves`, but saves the leaves of a PyTree using fsspec."""

@@ -443,11 +442,11 @@ def __serialise(y):


def tree_deserialise_leaves(
path: PathLike,
like: PyTree,
filter_spec=default_deserialise_filter_spec,
is_leaf: Optional[Callable[[Any], bool]] = None,
fs=None,
path: PathLike,
like: PyTree,
filter_spec=default_deserialise_filter_spec,
is_leaf: Optional[Callable[[Any], bool]] = None,
fs=None,
) -> PyTree:
"""
Analog to `equinox.tree_deserialise_leaves`, but loads the leaves of a PyTree using fsspec.
@@ -521,6 +520,6 @@ def __post_init__(self):
if prev_interval is not None:
assert prev_interval["until"] is not None, "Only the last checkpoint interval can be None"
assert (
interval["until"] is None or interval["until"] > prev_interval["until"]
interval["until"] is None or interval["until"] > prev_interval["until"]
), "Checkpoint intervals must be monotonic"
prev_interval = interval
128 changes: 64 additions & 64 deletions src/levanter/main/lora_lm.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from dataclasses import dataclass, field
from typing import Optional

import equinox as eqx
import jax.random as jrandom

import haliax.random
@@ -66,7 +67,12 @@ def main(config: LoraLmConfig):
Pos = model_config.Pos
KeyPos = model_config.KeyPos

with config.trainer.device_mesh:
optimizer = config.optimizer.build(config.trainer.num_train_steps)

def compute_loss(model, example: LmExample, key=None):
return model.compute_loss(example, key=key).scalar()

with Trainer(config.trainer, optimizer, compute_loss) as trainer:
# how we shard parameters across devices
parameter_axis_mapping = config.trainer.parameter_axis_mapping

@@ -82,74 +88,68 @@ def loraize_hf_model(model):

lora_param_filter = lora_trainable_params_filter(model)

def compute_loss(model, example: LmExample, key=None):
return model.compute_loss(example, key=key).scalar()

optimizer = config.optimizer.build(config.trainer.num_train_steps)

# Our trainer is a wrapper around the optimizer and compute_loss function that handles checkpointing and fsdp
with Trainer(config.trainer, optimizer, compute_loss, is_trainable=lora_param_filter) as trainer:
eval_datasets = config.data.validation_sets(Pos.size)
eval_datasets = config.data.validation_sets(Pos.size)

state = trainer.initial_state(training_key, model=model, is_trainable=lora_param_filter)

all_param_count = parameter_count(state.model)
just_lora_params = parameter_count(eqx.filter(state.model, lora_param_filter))

levanter.tracker.log_summary(
{
"parameter_count": all_param_count,
"trainable_parameter_count": just_lora_params,
"fraction_trainable": just_lora_params * 1.0 / all_param_count,
}
)

logger.info(f"Total parameter count: {all_param_count}")
logger.info(f"Trainable parameter count: {just_lora_params}")
logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")

# data loaders
if len(eval_datasets) == 0:
logger.warning("No evaluation datasets provided.")

for name, eval_dataset in eval_datasets.items():
eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos)
trainer.add_eval_hook(eval_dataset, name=name)

train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos)
train_loader = trainer.sharded_loader(train_dataset, Batch)

# boilerplate hooks and such
trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
if config.peft_save_path is not None:
full_save_path = os.path.join(config.peft_save_path, trainer.run_id)
trainer.add_hook(
save_peft_checkpoint_callback(
full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload
),
every=config.hf_save_steps,
)

if config.merged_hf_save_path is not None:
full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
trainer.add_hook(
save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
every=config.hf_save_steps,
)

state = trainer.initial_state(training_key, model=model)
# data loader. may need to seek to the right place if we're resuming
iter_data = non_caching_cycle(train_loader)

all_param_count = parameter_count(state.model)
just_lora_params = parameter_count(trainer._trainable_params_only(state.model))
if state.step > 0:
# step is after the batch, so we need to seek to step
# TODO: implement iter_data.seek(resume_step +1)
import tqdm

levanter.tracker.log_summary(
{
"parameter_count": all_param_count,
"trainable_parameter_count": just_lora_params,
"fraction_trainable": just_lora_params * 1.0 / all_param_count,
}
)
for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"):
next(iter_data)

logger.info(f"Total parameter count: {all_param_count}")
logger.info(f"Trainable parameter count: {just_lora_params}")
logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count%.3}")

# data loaders
if len(eval_datasets) == 0:
logger.warning("No evaluation datasets provided.")

for name, eval_dataset in eval_datasets.items():
eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos)
trainer.add_eval_hook(eval_dataset, name=name)

train_dataset = CausalLmDataset(config.data.train_set(Pos.size), Pos, KeyPos)
train_loader = trainer.sharded_loader(train_dataset, Batch)

# boilerplate hooks and such
trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
if config.peft_save_path is not None:
full_save_path = os.path.join(config.peft_save_path, trainer.run_id)
trainer.add_hook(
save_peft_checkpoint_callback(
full_save_path, config.lora, config.initialize_from_hf, tokenizer, config.peft_hf_upload
),
every=config.hf_save_steps,
)

if config.merged_hf_save_path is not None:
full_save_path = os.path.join(config.merged_hf_save_path, trainer.run_id)
trainer.add_hook(
save_merged_hf_checkpoint_callback(full_save_path, converter, config.merged_hf_upload),
every=config.hf_save_steps,
)

# data loader. may need to seek to the right place if we're resuming
iter_data = non_caching_cycle(train_loader)

if state.step > 0:
# step is after the batch, so we need to seek to step
# TODO: implement iter_data.seek(resume_step +1)
import tqdm

for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"):
next(iter_data)

## OK, actually run training!
trainer.train(state, iter_data)
## OK, actually run training!
trainer.train(state, iter_data)


if __name__ == "__main__":
5 changes: 4 additions & 1 deletion src/levanter/main/train_lm.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from typing import Optional, Union

import jax.random as jrandom
import wandb

import haliax as hax
from haliax import Axis
@@ -181,12 +182,14 @@ def compute_log_probs(model, example: LmExample):
# TODO: implement iter_data.seek(resume_step +1)
import tqdm

for _ in tqdm.tqdm(range(state.step + 1), desc="seeking data for resume"):
for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"):
next(train_loader)

## OK, actually run training!
trainer.add_hook(lambda s: print(s.loss), every=20)
trainer.train(state, train_loader)
# checkpointer.on_step(last_step, force=True)
wandb.finish()


if __name__ == "__main__":
Loading