
Commit

fix serialization of counts and such in optimizer states, remove non-trainables from model averaging
dlwh committed Jan 23, 2025
1 parent 237851b commit aa25122
Showing 2 changed files with 16 additions and 45 deletions.
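For context on the first half of the commit message: optax optimizer states carry integer leaves (for example Adam's step `count`), and a saveable mask built from a floating-point-only filter (the old `is_inexact_arrayish` behavior in `saveable_training_mask`, changed below) marks those leaves as not saveable, so they were silently skipped at checkpoint time. A minimal standalone sketch of the distinction, illustration only and not code from this commit:

    import equinox as eqx
    import jax.numpy as jnp
    import optax

    params = {"w": jnp.ones(3)}
    opt_state = optax.adam(1e-3).init(params)  # holds an int32 `count` plus float moments mu/nu

    floats_only = eqx.filter(opt_state, eqx.is_inexact_array)  # `count` filtered out -> dropped when saving
    everything = eqx.filter(opt_state, eqx.is_array)           # `count` kept -> saved and restored with the state

This is the distinction behind the saveable_training_mask change in trainer_state.py below, where the per-leaf mask switches from is_inexact_arrayish to lambda x: True.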
src/levanter/checkpoint.py: 5 additions & 38 deletions
@@ -392,49 +392,16 @@ def load_checkpoint(
         checkpoint_path = discovered_checkpoint_path
 
     logger.info(f"Loading checkpoint from {checkpoint_path}")
-    metadata = load_metadata(checkpoint_path, fs)
 
     if subpath:
         checkpoint_path = os.path.join(checkpoint_path, subpath)
 
     ser, non_ser = equinox.partition(tree, is_jax_array_like)
-    try:
-        tree = tree_deserialize_leaves_tensorstore(
-            checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh, allow_missing=allow_partial
-        )
-        tree = equinox.combine(tree, non_ser)
-        return tree
-    except:  # noqa
-        from levanter.trainer_state import TrainerState
-
-        if not isinstance(tree, TrainerState):
-            raise
-        else:
-            logger.warning("Attempting to load old-style checkpoint")
-            model, training_state = tree.model, (tree.opt_state, tree.training_key)
-
-            model = tree_deserialize_leaves_tensorstore(
-                os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh
-            )
-
-            if training_state is None:
-                opt_state = None
-                key = None
-            else:
-                training_state = tree_deserialize_leaves_tensorstore(
-                    os.path.join(checkpoint_path, "training_state"),
-                    training_state,
-                    axis_mapping=axis_mapping,
-                    mesh=mesh,
-                )
-                opt_state, key = training_state
-
-            # TODO: pretty sure this is right, but should verify
-            step = metadata["step"]
-            new_state = dataclasses.replace(
-                tree, step=step + 1, model=model, opt_state=opt_state, training_key=key  # type: ignore
-            )
-            return new_state
+    tree = tree_deserialize_leaves_tensorstore(
+        checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh, allow_missing=allow_partial
+    )
+    tree = equinox.combine(tree, non_ser)
+    return tree
 
 
 def load_checkpoint_or_initialize(
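The load path that remains leans entirely on the partition/deserialize/combine pattern: the caller's template tree is split into array leaves and everything else, only the array half goes through tensorstore, and the static leaves come back from the in-memory template. A standalone sketch of that round trip, with a toy predicate standing in for levanter's is_jax_array_like (illustration only, not levanter code):

    import equinox
    import jax.numpy as jnp

    def is_array_like(x):  # rough stand-in for levanter's is_jax_array_like predicate
        return hasattr(x, "shape") and hasattr(x, "dtype")

    tree = {"weights": jnp.zeros((2, 2)), "name": "layer0"}   # mixed array / non-array leaves

    ser, non_ser = equinox.partition(tree, is_array_like)
    # ser     = {"weights": Array(...), "name": None}   -> only this half would be read via tensorstore
    # non_ser = {"weights": None, "name": "layer0"}     -> stays in memory untouched

    restored = equinox.combine(ser, non_ser)               # non-array leaves re-attached unchanged
    assert restored["name"] == "layer0"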
src/levanter/trainer_state.py: 11 additions & 7 deletions
@@ -81,7 +81,9 @@ def eval_model(self) -> M:
         Otherwise, it uses the inference mode of the model.
         """
         if self.model_averaging is not None:
+            # model averaging only gets the trainable params so we have to patch in the trainables
             m = self.model_averaging.model_params
+            m = eqx.combine(m, self.model)
         else:
             m = self.model
 
@@ -108,10 +110,13 @@ def init(
         if fp8 is not None:
             model = fp8_linear_layers(model, fp8)
 
+        trainable_model = trainables_only(model, is_trainable)
+
         if model_averaging is not None:
-            model_averaging = model_averaging.create(model)
+            model_averaging = model_averaging.create(trainable_model)
 
-        opt_state = init_optimizer_for_trainables(optimizer, model, is_trainable)
+        opt_state = init_optimizer_for_trainables(optimizer, trainable_model)
+
         return cls(
             0,
             model,
@@ -137,19 +142,18 @@ def take_step(self: S, grads: PyTree, obj_fun: Optional[Callable[[M], Scalar]] =
         )
 
         if self.model_averaging is not None:
-            ma = self.model_averaging.update(model, self.step)
+            ma = self.model_averaging.update(trainables_only(model, self.is_trainable), self.step)
         else:
             ma = None
 
         return dataclasses.replace(self, model=model, opt_state=opt_state, model_averaging=ma, step=self.step + 1)
 
 
-def init_optimizer_for_trainables(optimizer, model, is_trainable):
+def init_optimizer_for_trainables(optimizer, trainable_model):
     """
     Initializes the optimizer state for the trainable parameters of the model.
     """
-    trainable = trainables_only(model, is_trainable)
-    _, trainable = partition_for_grad_overwrite(trainable)  # doesn't make a huge difference, but saves some ram
+    _, trainable = partition_for_grad_overwrite(trainable_model)  # doesn't make a huge difference, but saves some ram
     opt_state = optimizer.init(trainable)
     return opt_state
 
@@ -202,7 +206,7 @@ def saveable_training_mask(trainer_state: S, is_trainable_param: FilterTree = Tr
 
     is_trainable_param = make_floating_point_trainable_filter(is_trainable_param)
 
-    trainer_state = jax.tree_util.tree_map(lambda x: is_inexact_arrayish, trainer_state)
+    trainer_state = jax.tree_util.tree_map(lambda x: True, trainer_state)
     saveable_state = dataclasses.replace(trainer_state, step=True, training_key=True)  # type: ignore
     saveable_state = dataclasses.replace(saveable_state, model=is_trainable_param)  # type: ignore
     return saveable_state  # type: ignore
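The model-averaging half of the change, restated as a standalone sketch (not levanter's ModelAveraging API; the names below are illustrative): the averaging state is created from and updated with the trainable leaves only, and the non-trainable leaves are patched back in with eqx.combine when an eval model is materialized, mirroring create(trainable_model), update(trainables_only(...)), and the eqx.combine(m, self.model) line added to eval_model above.

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    model = {"body": jnp.ones(2), "frozen_head": jnp.zeros(2)}
    is_trainable = {"body": True, "frozen_head": False}     # filter tree: which leaves train

    trainable, _ = eqx.partition(model, is_trainable)       # averaging state built from this subset only
    ema = trainable

    # after a train step, fold the new trainable leaves into the average
    new_trainable, _ = eqx.partition(model, is_trainable)
    ema = jax.tree_util.tree_map(lambda e, p: 0.999 * e + 0.001 * p, ema, new_trainable)

    # eval time: averaged trainables, with non-trainable leaves taken from the live model
    eval_model = eqx.combine(ema, model)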
