diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py
index b9eabe922..f35e2d7a4 100644
--- a/src/levanter/checkpoint.py
+++ b/src/levanter/checkpoint.py
@@ -392,49 +392,16 @@ def load_checkpoint(
         checkpoint_path = discovered_checkpoint_path
 
     logger.info(f"Loading checkpoint from {checkpoint_path}")
-    metadata = load_metadata(checkpoint_path, fs)
 
     if subpath:
         checkpoint_path = os.path.join(checkpoint_path, subpath)
 
     ser, non_ser = equinox.partition(tree, is_jax_array_like)
-    try:
-        tree = tree_deserialize_leaves_tensorstore(
-            checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh, allow_missing=allow_partial
-        )
-        tree = equinox.combine(tree, non_ser)
-        return tree
-    except:  # noqa
-        from levanter.trainer_state import TrainerState
-
-        if not isinstance(tree, TrainerState):
-            raise
-        else:
-            logger.warning("Attempting to load old-style checkpoint")
-            model, training_state = tree.model, (tree.opt_state, tree.training_key)
-
-            model = tree_deserialize_leaves_tensorstore(
-                os.path.join(checkpoint_path, "model"), model, axis_mapping=axis_mapping, mesh=mesh
-            )
-
-            if training_state is None:
-                opt_state = None
-                key = None
-            else:
-                training_state = tree_deserialize_leaves_tensorstore(
-                    os.path.join(checkpoint_path, "training_state"),
-                    training_state,
-                    axis_mapping=axis_mapping,
-                    mesh=mesh,
-                )
-                opt_state, key = training_state
-
-            # TODO: pretty sure this is right, but should verify
-            step = metadata["step"]
-            new_state = dataclasses.replace(
-                tree, step=step + 1, model=model, opt_state=opt_state, training_key=key  # type: ignore
-            )
-            return new_state
+    tree = tree_deserialize_leaves_tensorstore(
+        checkpoint_path, ser, axis_mapping=axis_mapping, mesh=mesh, allow_missing=allow_partial
+    )
+    tree = equinox.combine(tree, non_ser)
+    return tree
 
 
 def load_checkpoint_or_initialize(
diff --git a/src/levanter/trainer_state.py b/src/levanter/trainer_state.py
index d0a8f9858..46ef501a1 100644
--- a/src/levanter/trainer_state.py
+++ b/src/levanter/trainer_state.py
@@ -81,7 +81,9 @@ def eval_model(self) -> M:
         Otherwise, it uses the inference mode of the model.
         """
         if self.model_averaging is not None:
+            # model averaging only gets the trainable params so we have to patch in the trainables
            m = self.model_averaging.model_params
+            m = eqx.combine(m, self.model)
         else:
             m = self.model
 
@@ -108,10 +110,13 @@ def init(
         if fp8 is not None:
             model = fp8_linear_layers(model, fp8)
 
+        trainable_model = trainables_only(model, is_trainable)
+
         if model_averaging is not None:
-            model_averaging = model_averaging.create(model)
+            model_averaging = model_averaging.create(trainable_model)
+
+        opt_state = init_optimizer_for_trainables(optimizer, trainable_model)
 
-        opt_state = init_optimizer_for_trainables(optimizer, model, is_trainable)
         return cls(
             0,
             model,
@@ -137,19 +142,18 @@ def take_step(self: S, grads: PyTree, obj_fun: Optional[Callable[[M], Scalar]] =
         )
 
         if self.model_averaging is not None:
-            ma = self.model_averaging.update(model, self.step)
+            ma = self.model_averaging.update(trainables_only(model, self.is_trainable), self.step)
         else:
             ma = None
 
         return dataclasses.replace(self, model=model, opt_state=opt_state, model_averaging=ma, step=self.step + 1)
 
 
-def init_optimizer_for_trainables(optimizer, model, is_trainable):
+def init_optimizer_for_trainables(optimizer, trainable_model):
     """
     Initializes the optimizer state for the trainable parameters of the model.
""" - trainable = trainables_only(model, is_trainable) - _, trainable = partition_for_grad_overwrite(trainable) # doesn't make a huge difference, but saves some ram + _, trainable = partition_for_grad_overwrite(trainable_model) # doesn't make a huge difference, but saves some ram opt_state = optimizer.init(trainable) return opt_state @@ -202,7 +206,7 @@ def saveable_training_mask(trainer_state: S, is_trainable_param: FilterTree = Tr is_trainable_param = make_floating_point_trainable_filter(is_trainable_param) - trainer_state = jax.tree_util.tree_map(lambda x: is_inexact_arrayish, trainer_state) + trainer_state = jax.tree_util.tree_map(lambda x: True, trainer_state) saveable_state = dataclasses.replace(trainer_state, step=True, training_key=True) # type: ignore saveable_state = dataclasses.replace(saveable_state, model=is_trainable_param) # type: ignore return saveable_state # type: ignore