Commit 08fd427

substantial changes to save on epochs w callback

ahmeda14960 committed Oct 24, 2024
1 parent e82eec2 commit 08fd427
Showing 3 changed files with 51 additions and 9 deletions.

src/levanter/callbacks.py (13 changes: 6 additions & 7 deletions)

@@ -27,11 +27,9 @@
 from levanter.utils.jax_utils import barrier_sync, jnp_to_python
 from levanter.visualization import compute_and_visualize_log_probs as viz_probs
 
-
 logger = pylogging.getLogger(__name__)
 
-
-def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size):
+def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size, max_epochs: Optional[int] = None):
     total_tokens = None
 
     def log_epoch(step_info: StepInfo):
@@ -45,10 +43,11 @@ def log_epoch(step_info: StepInfo):
 
         # Get the total processed tokens from the metrics logged by log_performance_stats
         processed_tokens = tokens_per_example * batch_size * step_info.step
-        if processed_tokens is None:
-            return  # No token count available yet
-
-        current_epoch = processed_tokens / total_tokens
+
+        # If we're doing multiple epochs, adjust the denominator
+        total_tokens_for_epochs = total_tokens * max_epochs if max_epochs else total_tokens
+        current_epoch = processed_tokens / total_tokens_for_epochs
+
         levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step)
 
     return log_epoch
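
For reference, a minimal sketch (not part of the commit, with made-up numbers) of the arithmetic log_epoch_progress performs once max_epochs is passed: the denominator is scaled by max_epochs, so train/current_epoch becomes the fraction of the full multi-epoch run completed rather than the fraction of a single pass over the data.

    # Illustration only: the value logged as train/current_epoch when max_epochs is set.
    # All sizes below are hypothetical.
    tokens_per_example = 2048          # e.g. the model's seq_len
    batch_size = 32
    total_tokens = 1_000_000_000       # dataset size in tokens
    max_epochs = 2

    step = 20_000
    processed_tokens = tokens_per_example * batch_size * step        # 1,310,720,000
    total_tokens_for_epochs = total_tokens * max_epochs              # 2,000,000,000
    current_epoch = processed_tokens / total_tokens_for_epochs       # ~0.655 of the 2-epoch run
    print(current_epoch)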

src/levanter/checkpoint.py (34 changes: 34 additions & 0 deletions)

@@ -27,6 +27,7 @@
 
 from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore
 from levanter.types import FilterSpec
+# from levanter.trainer import StepInfo
 
 
 logger = logging.getLogger(__name__)
@@ -261,6 +262,39 @@ def _async_checkpoint_remover(self):
             self._do_rm_checkpoint(checkpoint)
             self._checkpoint_being_removed = None
 
+# In callbacks.py - Add a new callback that handles epoch checkpointing
+class EpochCheckpointer:
+    """
+    A separate checkpointing system that saves based on epochs.
+    Works alongside the regular step-based checkpointer without modifying core state.
+    """
+    def __init__(self,
+                 checkpointer: Checkpointer,
+                 every_n_epochs: int = 1,
+                 total_dataset_size: Optional[int] = None,
+                 batch_size: int = 1):
+        self.checkpointer = checkpointer
+        self.every_n_epochs = every_n_epochs
+        self.total_dataset_size = total_dataset_size
+        self.batch_size = batch_size
+        self._last_saved_epoch = -1
+
+    def __call__(self, step_info):
+        if self.total_dataset_size is None:
+            return  # Can't calculate epochs without dataset size
+
+        # Calculate current epoch from steps without modifying StepInfo
+        current_epoch = (step_info.step * self.batch_size) // self.total_dataset_size
+
+        # Only save if we've moved to a new epoch and it matches our interval
+        if (current_epoch > self._last_saved_epoch and
+            current_epoch % self.every_n_epochs == 0):
+            # Use existing checkpointer's save_checkpoint method
+            self.checkpointer.save_checkpoint(
+                step_info,
+                f"epoch-{current_epoch}"
+            )
+            self._last_saved_epoch = current_epoch
 
 def save_checkpoint(
     tree: M,
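
To make the save condition above concrete, here is a small standalone simulation (not part of the commit; the sizes and step range are made up) of when an epoch-gated hook like this fires:

    # Illustration only (hypothetical sizes): when does the epoch-gated hook save?
    # current_epoch = (step * batch_size) // total_dataset_size, and a save happens the
    # first time current_epoch reaches a not-yet-saved multiple of every_n_epochs.
    total_dataset_size = 1000   # same units as step * batch_size
    batch_size = 8
    every_n_epochs = 1

    last_saved_epoch = -1
    for step in range(0, 400, 10):
        epoch = (step * batch_size) // total_dataset_size
        if epoch > last_saved_epoch and epoch % every_n_epochs == 0:
            print(f"step {step}: would save epoch-{epoch}")
            last_saved_epoch = epoch
    # Prints saves at steps 0, 130, 250, and 380 (epochs 0 through 3).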

src/levanter/main/train_lm.py (13 changes: 11 additions & 2 deletions)

@@ -14,7 +14,7 @@
 
 import levanter
 from levanter import callbacks
-from levanter.checkpoint import load_checkpoint
+from levanter.checkpoint import EpochCheckpointer, load_checkpoint
 from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
 from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig
 from levanter.models.gpt2 import Gpt2Config
@@ -132,9 +132,18 @@ def main(config: TrainLmConfig):
         if config.epoch > 0:
             total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
             trainer.add_hook(
-                callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1
+                callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size, max_epochs=config.epoch), every=1
             )
 
+            # Add epoch checkpoint callback
+            epoch_checkpointer = EpochCheckpointer(
+                checkpointer=trainer.config.checkpointer.create(trainer.run_id),
+                every_n_epochs=1,  # Or configure as needed
+                total_dataset_size=total_tokens_future.result(),
+                batch_size=trainer.config.train_batch_size
+            )
+            trainer.add_hook(epoch_checkpointer, every=1)
+
         # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
         # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
         # tokens: gpt-2 has 50257, for example. So we round up.
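
The checkpoint interval above is hard-coded to every_n_epochs=1, and the inline comment notes it could be made configurable. A minimal sketch of one way that might look; epoch_checkpoint_every is a hypothetical field name, not something this commit adds:

    # Hypothetical sketch, not part of the commit: expose the save interval as config.
    from dataclasses import dataclass

    @dataclass
    class EpochOptions:                  # stand-in for the relevant TrainLmConfig fields
        epoch: int = 0                   # number of epochs to train for (0 disables epoch logic)
        epoch_checkpoint_every: int = 1  # hypothetical: save a checkpoint every N epochs

    config = EpochOptions(epoch=3, epoch_checkpoint_every=2)
    if config.epoch > 0:
        # the EpochCheckpointer(...) call in the diff would then be constructed with
        # every_n_epochs=config.epoch_checkpoint_every instead of the literal 1
        print(config.epoch_checkpoint_every)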
