diff --git a/config/backpack.yaml b/config/backpack.yaml index 735d40c01..0fe93b539 100644 --- a/config/backpack.yaml +++ b/config/backpack.yaml @@ -12,7 +12,7 @@ model: trainer: tracker: project: "levanter" - tags: [ "openwebtext", "backpack" ] + tags: ["openwebtext", "backpack"] mp: p=f32,c=bfloat16 @@ -21,5 +21,5 @@ trainer: model_axis_size: 1 optimizer: - learning_rate: 6E-4 + learning_rate: 6e-4 weight_decay: 0.1 diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 135d10dd5..9ff544bd8 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -17,6 +17,7 @@ from tqdm_loggable import tqdm_logging from tqdm_loggable.auto import tqdm +import haliax as hax import haliax.nn from haliax import NamedArray, is_named_array from haliax.jax_utils import is_jax_array_like @@ -30,6 +31,8 @@ from levanter.utils import flop_utils, jax_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.utils.logging import save_xla_dumps_to_wandb +from levanter.utils.stat_utils import RunningMean +from levanter.utils.types import Extras from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -145,10 +148,8 @@ async def compute_length(): def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): - total_loss = 0.0 - total_load_time = 0.0 - total_loss_time = 0.0 - n = 0 + loss = RunningMean(jnp.zeros(()), jnp.zeros(())) + extras: Extras = {} if name is not None: desc = f"eval {name}" @@ -159,28 +160,27 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches) iter_ = iter(pbar) + n = 0 while True: - time_in = time.time() + n += 1 batch = next(iter_, None) if batch is None: break - load_time = time.time() - time_in - total_load_time += load_time - loss = loss_fn(model, batch) - total_loss += loss.item() - n += 1 - loss_time = time.time() - time_in - load_time - total_loss_time += loss_time + losses, where, extras = loss_fn(model, batch) + mean_loss = hax.mean(losses, where=where) + loss += RunningMean(mean_loss, where.sum()) + for k, v in extras.items(): + if k not in extras: + extras[k] = v + else: + extras[k] += v - pbar.set_postfix(loss=total_loss / n) + pbar.set_postfix(loss=loss.mean.item()) if max_batches is not None and n >= max_batches: break - if n > 0: - total_loss /= n - - return total_loss + return loss.item(), {k: v.item() for k, v in extras.items()} def compute_validation_loss( @@ -190,12 +190,14 @@ def compute_validation_loss( name: Optional[str] = None, ): def compute_loss(info: StepInfo): - loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) + loss, extras = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) prefix = "eval" if name: prefix += "/" + name - levanter.tracker.log({f"{prefix}/loss": loss}, step=info.step) + levanter.tracker.log( + {f"{prefix}/loss": loss} | {f"{prefix}/{k}": v for k, v in extras.items()}, step=info.step + ) if name: logger.info(f"{name} validation loss: {loss:.3f}") diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index cdcfe68cd..6cd14d496 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -134,12 +134,17 @@ def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) with hax.axis_mapping(trainer.compute_axis_mapping): # calculate per-token losses for proxy and ref - proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy) - ref_losses = loss_fn(ref, batch, reduction_axis=()) + def scalar_loss_fn(p, batch): + ret, _, _ = loss_fn(p, batch) + return ret + + proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: scalar_loss_fn(p, batch), proxy) + ref_losses = scalar_loss_fn(ref, batch) # calculate excess losses, aggregate per-domain losses excess_losses = proxy_losses - ref_losses clipped_losses = hax.maximum(excess_losses, 0) + print(clipped_losses.shape) per_domain_losses = _compute_per_domain_losses(clipped_losses, Domain, domains) # Update domain weights diff --git a/src/levanter/eval.py b/src/levanter/eval.py index ada22bc14..f58a9e2c8 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -301,8 +301,7 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, m = self.mp.cast_to_compute(m) with hax.axis_mapping(axis_mapping): - losses = compute_next_token_loss(m, batch, reduction=None, reduction_axis=()) - mask = batch.loss_mask # [Batch, Pos] + losses, mask, _extras = compute_next_token_loss(m, batch) this_tokens = hax.sum(mask) this_loss = hax.einsum("->", losses, mask) # to scalar diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index b9f4381fe..05ed209ff 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -37,7 +37,6 @@ import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer from levanter.models.gpt2 import Gpt2Config -from levanter.models.loss import next_token_loss from levanter.utils.hf_utils import HfTokenizer @@ -58,7 +57,7 @@ import levanter.config from levanter.checkpoint import load_checkpoint from levanter.data import AsyncDataset, DataLoader -from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel +from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss from levanter.trainer import StepInfo, TrainerConfig from levanter.utils.jax_utils import use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -157,15 +156,7 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr logits = logits.astype(jnp.float32) Pos = logits.resolve_axis(self.EvalPos.name) - loss = next_token_loss( - Pos=Pos, - Vocab=model.Vocab, - logits=logits, - true_ids=example.tokens, - loss_mask=example.loss_mask, - reduction=hax.sum, - reduction_axis=Pos, - ) + loss, _, _ = compute_next_token_loss(model, example) not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool) pred_targets = hax.argmax(logits, axis=model.Vocab) diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 4d167c562..3b77a75bd 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -1,4 +1,3 @@ -import enum import functools from typing import Callable, Optional, ParamSpec, TypeVar @@ -9,34 +8,31 @@ from jax.sharding import PartitionSpec import haliax as hax +import haliax.quantization as hq from haliax import Axis from haliax.partitioning import ResourceAxis from haliax.util import is_named_array from levanter.utils.jax_utils import zeros_like_tree +from levanter.utils.types import ComputeLossFunction Args = ParamSpec("Args") R = TypeVar("R") - - -class ReductionType(enum.Enum): - SUM = enum.auto() - MEAN = enum.auto() - # TODO: add MAX? +M_con = TypeVar("M_con", contravariant=True) # Model +X = TypeVar("X", contravariant=True) # Input # TODO: should we use a custom_jvp on microbatched? # cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 def microbatched( - fn: Callable[Args, R], + loss_fn: ComputeLossFunction[M_con, X], Batch: Axis, microbatch_size: int, accum_axis_mapping, compute_axis_mapping, patch_in_rng_key: Optional[str] = "key", - reduce: ReductionType = ReductionType.MEAN, accum_dtype: Optional[jnp.dtype] = None, ) -> Callable[Args, R]: """ @@ -78,20 +74,32 @@ def microbatched( num_micro_steps = batch_size // microbatch_size if num_micro_steps == 1: - return fn + + @functools.wraps(loss_fn) + def no_accum_loss_fn(*args, **kwargs): + losses, where, extras = loss_fn(*args, **kwargs) + seen_tokens = where.sum().scalar() + extras["seen_tokens"] = seen_tokens + return hax.mean(losses, where=where).scalar(), extras + + return eqx.filter_value_and_grad(no_accum_loss_fn, has_aux=True) Microbatch = Batch.resize(microbatch_size) AccumStep = Axis("accum_step", num_micro_steps) assert num_micro_steps * microbatch_size == batch_size - if reduce not in ReductionType: - raise ValueError(f"accum_type must be one of {ReductionType}") + @functools.wraps(loss_fn) + def accum_loss_fn(*args, **kwargs): + losses, where, extras = loss_fn(*args, **kwargs) + return hax.sum(losses, where=where).scalar(), (where.sum(), extras) - @functools.wraps(fn) + grad_fn = eqx.filter_value_and_grad(accum_loss_fn, has_aux=True) + + @functools.wraps(grad_fn) def wrapped_fn(*args, **kwargs): # first, determine the shape and make accumulator arrays - r_shape = eqx.filter_eval_shape(fn, *args, **kwargs) + r_shape = eqx.filter_eval_shape(grad_fn, *args, **kwargs) acc = zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype) # then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...) @@ -106,30 +114,34 @@ def wrapped_fn(*args, **kwargs): args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping) def loop(acc, microbatch_and_key): + (loss, (total, extras)), grads = acc microbatch, microbatch_kwargs, key = microbatch_and_key with jax.named_scope("compute"): microbatch_kwargs = microbatch_kwargs.copy() if key is not None: microbatch_kwargs[patch_in_rng_key] = key - this_r = fn(*microbatch, **microbatch_kwargs) + (loss_mb, (n_mb, extras_mb)), grads_mb = grad_fn(*microbatch, **microbatch_kwargs) with jax.named_scope("accum"): - import haliax.quantization as hq # TODO: this uses the latest value for the scale for fp8, which seems not ideal but probably ok? - overwrites, updates = hq.partition_for_grad_overwrite(this_r) - acc = hq.apply_updates(acc, updates, overwrites) - acc = hax.shard_with_axis_mapping(acc, accum_axis_mapping) + overwrites, updates = hq.partition_for_grad_overwrite(grads_mb) + grads = hq.apply_updates(grads, updates, overwrites) + grads = hax.shard_with_axis_mapping(grads, accum_axis_mapping) + loss += loss_mb + total += n_mb - return acc + return (loss, (total, {k: v + extras_mb[k] for k, v in extras.items()})), grads with jax.named_scope("microbatched"): - acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key)) - - if reduce == ReductionType.MEAN: - acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc) - - return acc + (loss, (total, extras)), grads, = hax.fold( + loop, AccumStep + )(acc, (args, kwargs, key)) + grads = jax.tree_util.tree_map(lambda x: x / total, grads) + loss /= total + extras["seen_tokens"] = total + + return (loss, extras), grads return wrapped_fn diff --git a/src/levanter/main/train_asr.py b/src/levanter/main/train_asr.py index c0d52eea2..0f246a302 100644 --- a/src/levanter/main/train_asr.py +++ b/src/levanter/main/train_asr.py @@ -88,10 +88,8 @@ def compute_loss( example: AudioTextExample, *, key=None, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, ) -> jax.numpy.ndarray | hax.NamedArray: - return m.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis) + return m.compute_loss(example, key=key) # Using the trainer as a context manager does 3 things: # 1. Sets the device mesh diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index e0aa1596d..682a29f6d 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -74,7 +74,7 @@ def main(config: VizGpt2Config): def compute_log_probs(model: LmHeadModel, example: LmExample): model = inference_mode(model, True) model = mp.cast_to_compute(model) - logprobs = compute_next_token_loss(model, example, reduction=None) + logprobs, where, _ = compute_next_token_loss(model, example) # roll forward to get the loss for each predicted token logprobs = hax.roll(logprobs, 1, Pos) return logprobs.rearrange((EvalBatch, Pos)).array diff --git a/src/levanter/models/asr_model.py b/src/levanter/models/asr_model.py index 9955dbfa5..c21acaba2 100644 --- a/src/levanter/models/asr_model.py +++ b/src/levanter/models/asr_model.py @@ -11,6 +11,7 @@ from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmConfig +from levanter.utils.types import Extras class AudioTextExample(eqx.Module): @@ -97,9 +98,7 @@ def compute_loss( example: AudioTextExample, *, key=None, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, - ) -> jnp.ndarray | NamedArray: + ) -> tuple[jnp.ndarray | NamedArray, NamedArray, Extras]: """ Computes the cross-entropy loss for predicted ASR tokens. If reduction is not None, the loss is reduced across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not @@ -110,10 +109,13 @@ def compute_loss( targets = hax.roll(example.tokens, -1, axis=self.Pos.name) target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) loss = cross_entropy_loss( - logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask + logits, + self.Vocab, + target_y, + reduction=None, ) - return loss + return loss, example.loss_mask, {} @property def vocab_size(self) -> int: diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 7f5c0e3d8..52f895593 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -88,6 +88,9 @@ def from_prompt_and_completion( return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + def num_elements(self): + return self.loss_mask.sum() + # TODO: for some reason, mypy doesn't like the discover_packages_path argument? @dataclass(frozen=True) @@ -219,19 +222,21 @@ def compute_next_token_loss( example: LmExample, *, key=None, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, logsumexp_weight: Optional[float] = None, loss_dtype: Optional[Type[jnp.dtype]] = jnp.float32, -) -> jnp.ndarray | NamedArray: +) -> tuple[NamedArray, NamedArray, dict]: """ Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ activations = model.activations(example.tokens, example.attn_mask, key=key) + if isinstance(activations, tuple): + activations, extras = activations + else: + extras = {} - loss = maybe_fused_next_token_loss( + loss, where = maybe_fused_next_token_loss( model.Pos, model.Embed, model.Vocab, @@ -239,11 +244,9 @@ def compute_next_token_loss( model.get_lm_head(), example.tokens, loss_mask=example.loss_mask, - reduction=reduction, - reduction_axis=reduction_axis, logsumexp_weight=logsumexp_weight, dtype=loss_dtype, block_size=model.config.cross_entropy_block_size, ) - return loss + return loss, where, extras diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index bf0bd380e..bf84b0ba7 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -1,4 +1,5 @@ import functools +import logging from typing import Optional import equinox @@ -10,6 +11,9 @@ from haliax.nn import cross_entropy_loss_and_log_normalizers +logger = logging.getLogger(__name__) + + def maybe_fused_next_token_loss( Pos: hax.AxisSelector, Embed: hax.AxisSelector, @@ -18,12 +22,10 @@ def maybe_fused_next_token_loss( pred_lm_head: NamedArray, true_ids: NamedArray, loss_mask: Optional[NamedArray] = None, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, logsumexp_weight: Optional[float] = None, block_size: Optional[int] = None, dtype: Optional[jnp.dtype] = jnp.float32, -) -> NamedArray: +) -> tuple[NamedArray, NamedArray]: """ Compute the next token loss with optional block-wise processing. @@ -36,6 +38,7 @@ def maybe_fused_next_token_loss( loss_mask (Optional[NamedArray]): Mask to apply to the loss. reduction (Optional[hax.ReductionFunction]): Reduction function. reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction. + batch_num_elements (Optional[int]): The number of elements in the batch. When passed, it is used to reduce the loss. logsumexp_weight (Optional[float]): Weight for logsumexp penalty. block_size (Optional[int]): Size of each block for processing. @@ -45,6 +48,13 @@ def maybe_fused_next_token_loss( # Resolve axes Pos = pred_embeddings.resolve_axis(Pos) Vocab = pred_lm_head.resolve_axis(Vocab) + not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore + if loss_mask is not None: + loss_mask = loss_mask * not_last_loss_mask + else: + loss_mask = not_last_loss_mask + + target_y = hax.roll(true_ids, -1, Pos) if block_size is None: # Full softmax computation @@ -53,42 +63,29 @@ def maybe_fused_next_token_loss( logits = logits.astype(dtype) # Shift target tokens to predict the next token - return next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight) - - # Shift target tokens to predict the next token - target_y = hax.roll(true_ids, -1, Pos) - - # Create a mask that excludes the last token - not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore - if loss_mask is not None: - loss_mask = loss_mask * not_last_loss_mask + return next_token_loss(Pos, Vocab, logits, target_y, logsumexp_weight), loss_mask else: - loss_mask = not_last_loss_mask - - # Compute the loss with optional block-wise processing - return fused_cross_entropy_loss_and_logsumexp_penalty( - pred_embeddings, - pred_lm_head, - Contract=Embed, - Label=Vocab, - target_y=target_y, - reduction=reduction, - reduction_axis=reduction_axis, - where=loss_mask, - logsumexp_weight=logsumexp_weight, - block_size=block_size, - dtype=dtype, - ) + # Compute the loss with optional block-wise processing + return ( + fused_cross_entropy_loss_and_logsumexp_penalty( + pred_embeddings, + pred_lm_head, + Contract=Embed, + Label=Vocab, + target_y=target_y, + logsumexp_weight=logsumexp_weight, + block_size=block_size, + dtype=dtype, + ), + loss_mask, + ) def next_token_loss( Pos: hax.AxisSelector, Vocab: hax.AxisSelector, logits: NamedArray, - true_ids: NamedArray, - loss_mask: Optional[NamedArray] = None, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, + target_y: NamedArray, logsumexp_weight: Optional[float] = None, ): """ @@ -108,23 +105,12 @@ def next_token_loss( """ Pos = logits.resolve_axis(Pos) - target_y = hax.roll(true_ids, -1, Pos) target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=logits.dtype) - # Create a mask that excludes the last token - not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore - if loss_mask is not None: - loss_mask = loss_mask * not_last_loss_mask - else: - loss_mask = not_last_loss_mask - return cross_entropy_and_logsumexp_penalty( Vocab=Vocab, pred_y=logits, target_y=target_y_full, - reduction=reduction, - reduction_axis=reduction_axis, - where=loss_mask, logsumexp_weight=logsumexp_weight, ) @@ -134,9 +120,6 @@ def cross_entropy_and_logsumexp_penalty( pred_y: NamedArray, target_y: NamedArray, *, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, - where: Optional[NamedArray] = None, logsumexp_weight=0.0, ) -> NamedArray: """A loss function that combines cross entropy loss with a logsumexp penalty.""" @@ -146,7 +129,7 @@ def cross_entropy_and_logsumexp_penalty( if logsumexp_weight is not None and logsumexp_weight != 0.0: loss = loss + logsumexp_weight * (log_normalizers**2) - return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where) + return loss def fused_cross_entropy_loss_and_logsumexp_penalty( @@ -156,9 +139,6 @@ def fused_cross_entropy_loss_and_logsumexp_penalty( Label: hax.AxisSelector, target_y: NamedArray, *, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, - where: Optional[NamedArray] = None, logsumexp_weight: float | None = 0.0, block_size: int, dtype: Optional[jnp.dtype] = jnp.float32, @@ -192,7 +172,7 @@ def fused_cross_entropy_loss_and_logsumexp_penalty( if logsumexp_weight is not None and (not isinstance(logsumexp_weight, (int, float)) or logsumexp_weight != 0.0): loss = loss + logsumexp_weight * (log_normalizers**2) - return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where) + return loss @equinox.filter_custom_vjp diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 82f32422a..cbb272dcd 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -43,7 +43,7 @@ from levanter.utils import cloud_utils, fsspec_utils from levanter.utils.jax_utils import create_fsdp_mesh, zeros_like_tree from levanter.utils.tree_utils import inference_mode -from levanter.utils.types import ComputeLossFunction, FilterSpec +from levanter.utils.types import ComputeLossFunction, Extras, FilterSpec logger = pylogging.getLogger(__name__) @@ -380,7 +380,7 @@ def checkpoint_path(self) -> str: checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) return checkpoint_path - def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]: + def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]: """ Performs a single training step. """ @@ -391,10 +391,10 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]: with capture_time() as step_time: if hooks_this_time: - loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs) + loss, new_state, extras, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) else: - loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs) + loss, new_state, extras = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs) loss = loss.item() # type: ignore info = StepInfo(new_state, loss, step_time()) @@ -404,7 +404,8 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]: if hooks_this_time: self.hooks.run_jit_hooks_outside_step(info, cb_states) - levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step) + log_items = {k: v.item() for k, v in extras.items()} | {"throughput/hook_time": hook_time()} + levanter.tracker.log(log_items, step=info.step) return info @@ -525,11 +526,13 @@ def _jit_train_step_fn_no_hook(self): def _train_step( self, state: S, batch, batch_kwargs, _no_hooks=False - ) -> tuple[Scalar, S, Sequence[CBInfo]] | tuple[Scalar, S]: + ) -> tuple[Scalar, S, Extras, Sequence[CBInfo]] | tuple[Scalar, S, Extras]: key, new_key = jax.random.split(state.training_key) model = inference_mode(state.model, False) - loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key) + (loss, extras), grads = self._compute_gradients_microbatched( + self.loss_fn, model, batch, **batch_kwargs, key=key + ) with hax.axis_mapping(self.parameter_axis_mapping): if not _no_hooks: @@ -545,22 +548,23 @@ def obj_fun(trainable_model): new_state = state.take_step(grads, obj_fun=obj_fun) new_state = hax.shard(new_state, self.parameter_axis_mapping) if _no_hooks: - return loss, new_state + return loss, new_state, extras else: - return loss, new_state, hook_infos + return loss, new_state, extras, hook_infos - def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]: - grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) + def _compute_gradients_microbatched( + self, loss_fn, model: M, batch: X, **batch_kwargs + ) -> tuple[tuple[Scalar, Extras], M]: mbs = self.config.microbatch_size grad_fn = microbatched( - grad_fn, + loss_fn, self.TrainBatch, mbs, self.parameter_axis_mapping, self.compute_axis_mapping, - ) + ) # type: ignore with hax.axis_mapping(self.compute_axis_mapping): - return grad_fn(model, *batch, **batch_kwargs) + return grad_fn(model, batch, **batch_kwargs) def _initialize_global_tracker(config, run_id): diff --git a/src/levanter/utils/stat_utils.py b/src/levanter/utils/stat_utils.py index 6111be42e..895003cf4 100644 --- a/src/levanter/utils/stat_utils.py +++ b/src/levanter/utils/stat_utils.py @@ -1,13 +1,26 @@ -import typing +from typing import TypeAlias import equinox as eqx import jax.numpy as jnp import numpy as np +from typing_extensions import Self import haliax as hax +from levanter.utils.types import Accumulatable -Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray + +Arrayish: TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray + + +class SumScalar(Accumulatable): + value: jnp.ndarray + + def item(self) -> float: + return self.value.item() + + def __add__(self, other: Self) -> Self: + return SumScalar(self.value + other.value) class RunningMean(eqx.Module): @@ -27,6 +40,9 @@ def add(self, x: Arrayish, total: Arrayish) -> "RunningMean": new_total = self.total + total return RunningMean(new_mean, new_total) + def item(self) -> float: + return self.mean.item() + def __add__(self, other: "RunningMean"): return self.add(other.mean, other.total) diff --git a/src/levanter/utils/types.py b/src/levanter/utils/types.py index 46ccac2b5..dafc0791d 100644 --- a/src/levanter/utils/types.py +++ b/src/levanter/utils/types.py @@ -1,6 +1,10 @@ -from typing import Any, Callable, Optional, Protocol, Tuple, TypeVar, Union +import abc +from typing import Any, Callable, Dict, Protocol, Tuple, TypeAlias, TypeVar, Union +import equinox as eqx +import jax from jaxtyping import PyTree +from typing_extensions import Self import haliax as hax from haliax.types import Scalar @@ -10,6 +14,19 @@ M_con = TypeVar("M_con", contravariant=True) # Model X = TypeVar("X", contravariant=True) # Input + +class Accumulatable(abc.ABC, eqx.Module): + @abc.abstractmethod + def item(self) -> float: + pass + + @abc.abstractmethod + def __add__(self, other: Self) -> Self: + pass + + +Extras: TypeAlias = Dict[str, jax.Array | Accumulatable] + try: from haliax.nn.scan import BlockFoldable except ImportError: @@ -51,9 +68,7 @@ class ComputeLossFunction(Protocol[M_con, X]): def __call__( self, model: M_con, - *inputs: X, - reduction: Optional[hax.ReductionFunction] = hax.mean, - reduction_axis: Optional[hax.AxisSelection] = None, + input: X, **kwargs, - ) -> Scalar | hax.NamedArray: + ) -> tuple[hax.NamedArray, hax.NamedArray, Extras]: ... diff --git a/tests/test_doremi.py b/tests/test_doremi.py index 3ad4aa9ab..e60554480 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -128,10 +128,11 @@ def test_estimate_mixture_weights(): ds3 = LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys)) # TODO: remove key as a requirement for models - def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None): + def compute_loss_fn(model, example, key=None): del key y_pred = model(example.x) - return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis) + losses = hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=None) + return losses, hax.ones_like(losses), {} tiny_trainer_config = TrainerConfig( num_train_steps=300, diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py index a6e96f2d8..77ef4df7e 100644 --- a/tests/test_grad_accum.py +++ b/tests/test_grad_accum.py @@ -45,7 +45,13 @@ def test_accumulate_gradients_sharded(parallelism, accum_steps): mlp = Mlp.init(In, Out, Mid, key=jax.random.PRNGKey(0)) def loss_fn(mlp, x): - return mlp(x).mean().scalar() + out = mlp(x) + where = hax.ones_like(out) + return out, where, {} + + def scalar_loss_fn(mlp, x): + out, where, _ = loss_fn(mlp, x) + return hax.mean(out, where=where).scalar(), {} x = hax.random.normal(jax.random.PRNGKey(0), (Batch, In)) @@ -57,20 +63,15 @@ def loss_fn(mlp, x): @hax.partitioning.named_jit(axis_resources=axis_mapping) def jit_grad_accum(mlp, x): - grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) - grad_fn = microbatched(grad_fn, Batch, parallelism, axis_mapping, axis_mapping) - acc_v, acc_g = grad_fn( - mlp, - x, - ) - return acc_v, acc_g + grad_fn = microbatched(loss_fn, Batch, parallelism, axis_mapping, axis_mapping) + return grad_fn(mlp, x) with mesh: mlp = haliax.shard(mlp, axis_mapping) x = haliax.shard(x, axis_mapping) - grad_fn = eqx.filter_value_and_grad(loss_fn) - acc_v, acc_g = jit_grad_accum(mlp, x) - v, g = grad_fn(mlp, x) + grad_fn = eqx.filter_value_and_grad(scalar_loss_fn, has_aux=True) + (acc_v, _), acc_g = jit_grad_accum(mlp, x) + (v, _), g = grad_fn(mlp, x) assert_trees_all_close(acc_v, v, atol=1e-3, rtol=1e-3) diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index 9cc46ca0d..29074504c 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -135,7 +135,8 @@ def torch_loss(model, input_ids) -> torch.Tensor: def compute_loss(model: LmHeadModel, input_ids): example = LmExample.causal(input_ids, eos_id=converter.tokenizer.eos_token_id) - return compute_next_token_loss(model, example, key=None).scalar() + loss, where, _ = compute_next_token_loss(model, example, key=None) + return hax.mean(loss, where=where).scalar() jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False) jax_grad: Gpt2LMHeadModel diff --git a/tests/test_text.py b/tests/test_text.py index 63d0afedb..046b7a7d2 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -40,12 +40,14 @@ def test_lm_example_handles_ignore_id(): lm_head = hax.zeros((Embed, Vocab)) lm_head = lm_head.at[Vocab, ignore_id].set(-100) - ignored_loss = maybe_fused_next_token_loss( + ignored_loss, ignored_where = maybe_fused_next_token_loss( Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask ) - no_ignore_loss = maybe_fused_next_token_loss( + ignored_loss = hax.sum(ignored_loss, where=ignored_where) + no_ignore_loss, no_ignore_where = maybe_fused_next_token_loss( Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask ) + no_ignore_loss = hax.sum(no_ignore_loss, where=no_ignore_where) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size