diff --git a/config/backpack.yaml b/config/backpack.yaml
index 735d40c01..0fe93b539 100644
--- a/config/backpack.yaml
+++ b/config/backpack.yaml
@@ -12,7 +12,7 @@ model:
 trainer:
   tracker:
     project: "levanter"
-    tags: [ "openwebtext", "backpack" ]
+    tags: ["openwebtext", "backpack"]
 
   mp: p=f32,c=bfloat16
 
@@ -21,5 +21,5 @@ trainer:
   model_axis_size: 1
 
 optimizer:
-  learning_rate: 6E-4
+  learning_rate: 6e-4
   weight_decay: 0.1
diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py
index 135d10dd5..9ff544bd8 100644
--- a/src/levanter/callbacks.py
+++ b/src/levanter/callbacks.py
@@ -17,6 +17,7 @@
 from tqdm_loggable import tqdm_logging
 from tqdm_loggable.auto import tqdm
 
+import haliax as hax
 import haliax.nn
 from haliax import NamedArray, is_named_array
 from haliax.jax_utils import is_jax_array_like
@@ -30,6 +31,8 @@
 from levanter.utils import flop_utils, jax_utils
 from levanter.utils.jax_utils import barrier_sync, jnp_to_python
 from levanter.utils.logging import save_xla_dumps_to_wandb
+from levanter.utils.stat_utils import RunningMean
+from levanter.utils.types import Extras
 from levanter.visualization import compute_and_visualize_log_probs as viz_probs
 
 
@@ -145,10 +148,8 @@ async def compute_length():
 
 
 def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
-    total_loss = 0.0
-    total_load_time = 0.0
-    total_loss_time = 0.0
-    n = 0
+    loss = RunningMean(jnp.zeros(()), jnp.zeros(()))
+    extras: Extras = {}
 
     if name is not None:
         desc = f"eval {name}"
@@ -159,28 +160,27 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n
     pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches)
 
     iter_ = iter(pbar)
+    n = 0
     while True:
-        time_in = time.time()
+        n += 1
         batch = next(iter_, None)
         if batch is None:
             break
-        load_time = time.time() - time_in
-        total_load_time += load_time
-        loss = loss_fn(model, batch)
-        total_loss += loss.item()
-        n += 1
-        loss_time = time.time() - time_in - load_time
-        total_loss_time += loss_time
+        # Bind the per-batch extras to their own name: reusing `extras` would discard the
+        # running totals and then add every value to itself.
+        losses, where, batch_extras = loss_fn(model, batch)
+        mean_loss = hax.mean(losses, where=where)
+        # where.sum() is handed to RunningMean as this batch's `total`
+        loss += RunningMean(mean_loss, where.sum())
+        for k, v in batch_extras.items():
+            extras[k] = extras[k] + v if k in extras else v
 
-        pbar.set_postfix(loss=total_loss / n)
+        pbar.set_postfix(loss=loss.mean.item())
 
         if max_batches is not None and n >= max_batches:
             break
 
-    if n > 0:
-        total_loss /= n
-
-    return total_loss
+    return loss.item(), {k: v.item() for k, v in extras.items()}
 
 
 def compute_validation_loss(
@@ -190,12 +190,14 @@ def compute_validation_loss(
     name: Optional[str] = None,
 ):
     def compute_loss(info: StepInfo):
-        loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name)
+        loss, extras = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name)
 
         prefix = "eval"
         if name:
             prefix += "/" + name
-        levanter.tracker.log({f"{prefix}/loss": loss}, step=info.step)
+        levanter.tracker.log(
+            {f"{prefix}/loss": loss} | {f"{prefix}/{k}": v for k, v in extras.items()}, step=info.step
+        )
 
         if name:
             logger.info(f"{name} validation loss: {loss:.3f}")
diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py
index cdcfe68cd..6cd14d496 100644
--- a/src/levanter/doremi.py
+++ b/src/levanter/doremi.py
@@ -134,12 +134,17 @@ def doremi_step(state: DoremiState, ref, batch, domains):
         proxy = inference_mode(state.model, False)
         with hax.axis_mapping(trainer.compute_axis_mapping):
             # calculate per-token losses for proxy and ref
-            proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy)
-            ref_losses = loss_fn(ref, batch, reduction_axis=())
+            def scalar_loss_fn(p, batch):
+                ret, _, _ = loss_fn(p, batch)
+                return ret
+
+            proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: scalar_loss_fn(p, batch), proxy)
+            ref_losses = scalar_loss_fn(ref, batch)
 
             # calculate excess losses, aggregate per-domain losses
             excess_losses = proxy_losses - ref_losses
             clipped_losses = hax.maximum(excess_losses, 0)
+            # NOTE(review): the `where` mask returned by loss_fn is discarded in scalar_loss_fn — confirm intended.
             per_domain_losses = _compute_per_domain_losses(clipped_losses, Domain, domains)
 
             # Update domain weights
diff --git a/src/levanter/eval.py b/src/levanter/eval.py
index ada22bc14..f58a9e2c8 100644
--- a/src/levanter/eval.py
+++ b/src/levanter/eval.py
@@ -301,8 +301,7 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample,
                 m = self.mp.cast_to_compute(m)
 
             with hax.axis_mapping(axis_mapping):
-                losses = compute_next_token_loss(m, batch, reduction=None, reduction_axis=())
-                mask = batch.loss_mask  # [Batch, Pos]
+                losses, mask, _extras = compute_next_token_loss(m, batch)
                 this_tokens = hax.sum(mask)
                 this_loss = hax.einsum("->", losses, mask)  # to scalar
 
diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py
index b9f4381fe..05ed209ff 100644
--- a/src/levanter/eval_harness.py
+++ b/src/levanter/eval_harness.py
@@ -37,7 +37,6 @@
 import levanter.tracker
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
 from levanter.models.gpt2 import Gpt2Config
-from levanter.models.loss import next_token_loss
 from levanter.utils.hf_utils import HfTokenizer
 
 
@@ -58,7 +57,7 @@
 import levanter.config
 from levanter.checkpoint import load_checkpoint
 from levanter.data import AsyncDataset, DataLoader
-from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
+from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss
 from levanter.trainer import StepInfo, TrainerConfig
 from levanter.utils.jax_utils import use_cpu_device
 from levanter.utils.tree_utils import inference_mode
@@ -157,15 +156,10 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr
             logits = logits.astype(jnp.float32)
             Pos = logits.resolve_axis(self.EvalPos.name)
 
-            loss = next_token_loss(
-                Pos=Pos,
-                Vocab=model.Vocab,
-                logits=logits,
-                true_ids=example.tokens,
-                loss_mask=example.loss_mask,
-                reduction=hax.sum,
-                reduction_axis=Pos,
-            )
+            # compute_next_token_loss returns unreduced losses plus a mask; keep the masked sum over Pos
+            # that the replaced next_token_loss(reduction=hax.sum, reduction_axis=Pos) call asked for.
+            losses, where, _ = compute_next_token_loss(model, example)
+            loss = hax.sum(losses, axis=Pos, where=where)
 
             not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool)
             pred_targets = hax.argmax(logits, axis=model.Vocab)
diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py
index 4d167c562..3b77a75bd 100644
--- a/src/levanter/grad_accum.py
+++ b/src/levanter/grad_accum.py
@@ -1,4 +1,3 @@
-import enum
 import functools
 from typing import Callable, Optional, ParamSpec, TypeVar
 
@@ -9,34 +8,31 @@
 from jax.sharding import PartitionSpec
 
 import haliax as hax
+import haliax.quantization as hq
 from haliax import Axis
 from haliax.partitioning import ResourceAxis
 from haliax.util import is_named_array
 
 from levanter.utils.jax_utils import zeros_like_tree
+from levanter.utils.types import ComputeLossFunction
 
 
 Args = ParamSpec("Args")
 R = TypeVar("R")
-
-
-class ReductionType(enum.Enum):
-    SUM = enum.auto()
-    MEAN = enum.auto()
-    # TODO: add MAX?
+M_con = TypeVar("M_con", contravariant=True)  # Model
+X = TypeVar("X", contravariant=True)  # Input
 
 
 # TODO: should we use a custom_jvp on microbatched?
 
 # cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617
 def microbatched(
-    fn: Callable[Args, R],
+    loss_fn: ComputeLossFunction[M_con, X],
     Batch: Axis,
     microbatch_size: int,
     accum_axis_mapping,
     compute_axis_mapping,
     patch_in_rng_key: Optional[str] = "key",
-    reduce: ReductionType = ReductionType.MEAN,
     accum_dtype: Optional[jnp.dtype] = None,
 ) -> Callable[Args, R]:
     """
@@ -78,20 +74,32 @@ def microbatched(
     num_micro_steps = batch_size // microbatch_size
 
     if num_micro_steps == 1:
-        return fn
+
+        @functools.wraps(loss_fn)
+        def no_accum_loss_fn(*args, **kwargs):
+            losses, where, extras = loss_fn(*args, **kwargs)
+            seen_tokens = where.sum().scalar()
+            extras["seen_tokens"] = seen_tokens
+            return hax.mean(losses, where=where).scalar(), extras
+
+        return eqx.filter_value_and_grad(no_accum_loss_fn, has_aux=True)
 
     Microbatch = Batch.resize(microbatch_size)
     AccumStep = Axis("accum_step", num_micro_steps)
     assert num_micro_steps * microbatch_size == batch_size
 
-    if reduce not in ReductionType:
-        raise ValueError(f"accum_type must be one of {ReductionType}")
+    @functools.wraps(loss_fn)
+    def accum_loss_fn(*args, **kwargs):
+        losses, where, extras = loss_fn(*args, **kwargs)
+        return hax.sum(losses, where=where).scalar(), (where.sum(), extras)
 
-    @functools.wraps(fn)
+    grad_fn = eqx.filter_value_and_grad(accum_loss_fn, has_aux=True)
+
+    @functools.wraps(grad_fn)
     def wrapped_fn(*args, **kwargs):
 
         # first, determine the shape and make accumulator arrays
-        r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
+        r_shape = eqx.filter_eval_shape(grad_fn, *args, **kwargs)
         acc = zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype)
 
         # then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...)
@@ -106,30 +114,34 @@ def wrapped_fn(*args, **kwargs):
         args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping)
 
         def loop(acc, microbatch_and_key):
+            (loss, (total, extras)), grads = acc
             microbatch, microbatch_kwargs, key = microbatch_and_key
             with jax.named_scope("compute"):
                 microbatch_kwargs = microbatch_kwargs.copy()
                 if key is not None:
                     microbatch_kwargs[patch_in_rng_key] = key
-                this_r = fn(*microbatch, **microbatch_kwargs)
+                (loss_mb, (n_mb, extras_mb)), grads_mb = grad_fn(*microbatch, **microbatch_kwargs)
 
             with jax.named_scope("accum"):
-                import haliax.quantization as hq
 
                 # TODO: this uses the latest value for the scale for fp8, which seems not ideal but probably ok?
-                overwrites, updates = hq.partition_for_grad_overwrite(this_r)
-                acc = hq.apply_updates(acc, updates, overwrites)
-                acc = hax.shard_with_axis_mapping(acc, accum_axis_mapping)
+                overwrites, updates = hq.partition_for_grad_overwrite(grads_mb)
+                grads = hq.apply_updates(grads, updates, overwrites)
+                grads = hax.shard_with_axis_mapping(grads, accum_axis_mapping)
+                loss += loss_mb
+                total += n_mb
 
-            return acc
+            return (loss, (total, {k: v + extras_mb[k] for k, v in extras.items()})), grads
 
         with jax.named_scope("microbatched"):
-            acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key))
-
-            if reduce == ReductionType.MEAN:
-                acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc)
-
-        return acc
+            # run `loop` once per AccumStep; the sums it carries are normalized by `total` just below
+            fold_over_steps = hax.fold(loop, AccumStep)
+            (loss, (total, extras)), grads = fold_over_steps(acc, (args, kwargs, key))
+            grads = jax.tree_util.tree_map(lambda x: x / total, grads)
+            loss /= total
+            extras["seen_tokens"] = total
+
+        return (loss, extras), grads
 
     return wrapped_fn
 
diff --git a/src/levanter/main/train_asr.py b/src/levanter/main/train_asr.py
index c0d52eea2..0f246a302 100644
--- a/src/levanter/main/train_asr.py
+++ b/src/levanter/main/train_asr.py
@@ -88,10 +88,8 @@ def compute_loss(
         example: AudioTextExample,
         *,
         key=None,
-        reduction: Optional[hax.ReductionFunction] = hax.mean,
-        reduction_axis: Optional[hax.AxisSelection] = None,
     ) -> jax.numpy.ndarray | hax.NamedArray:
-        return m.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis)
+        return m.compute_loss(example, key=key)
 
     # Using the trainer as a context manager does 3 things:
     # 1. Sets the device mesh
diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py
index e0aa1596d..682a29f6d 100644
--- a/src/levanter/main/viz_logprobs.py
+++ b/src/levanter/main/viz_logprobs.py
@@ -74,7 +74,7 @@ def main(config: VizGpt2Config):
         def compute_log_probs(model: LmHeadModel, example: LmExample):
             model = inference_mode(model, True)
             model = mp.cast_to_compute(model)
-            logprobs = compute_next_token_loss(model, example, reduction=None)
+            logprobs, where, _ = compute_next_token_loss(model, example)
             # roll forward to get the loss for each predicted token
             logprobs = hax.roll(logprobs, 1, Pos)
             return logprobs.rearrange((EvalBatch, Pos)).array
diff --git a/src/levanter/models/asr_model.py b/src/levanter/models/asr_model.py
index 9955dbfa5..c21acaba2 100644
--- a/src/levanter/models/asr_model.py
+++ b/src/levanter/models/asr_model.py
@@ -11,6 +11,7 @@
 
 from levanter.models.attention import AttentionMask
 from levanter.models.lm_model import LmConfig
+from levanter.utils.types import Extras
 
 
 class AudioTextExample(eqx.Module):
@@ -97,9 +98,7 @@ def compute_loss(
         example: AudioTextExample,
         *,
         key=None,
-        reduction: Optional[hax.ReductionFunction] = hax.mean,
-        reduction_axis: Optional[hax.AxisSelection] = None,
-    ) -> jnp.ndarray | NamedArray:
+    ) -> tuple[jnp.ndarray | NamedArray, NamedArray, Extras]:
         """
         Computes the cross-entropy loss for predicted ASR tokens. If reduction is not None, the loss is reduced
         across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
@@ -110,10 +109,13 @@ def compute_loss(
         targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
         target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
         loss = cross_entropy_loss(
-            logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
+            logits,
+            self.Vocab,
+            target_y,
+            reduction=None,
         )
 
-        return loss
+        return loss, example.loss_mask, {}
 
     @property
     def vocab_size(self) -> int:
diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py
index 7f5c0e3d8..52f895593 100644
--- a/src/levanter/models/lm_model.py
+++ b/src/levanter/models/lm_model.py
@@ -88,6 +88,9 @@ def from_prompt_and_completion(
 
         return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)
 
+    def num_elements(self):
+        return self.loss_mask.sum()
+
 
 # TODO: for some reason, mypy doesn't like the discover_packages_path argument?
 @dataclass(frozen=True)
@@ -219,19 +222,21 @@ def compute_next_token_loss(
     example: LmExample,
     *,
     key=None,
-    reduction: Optional[hax.ReductionFunction] = hax.mean,
-    reduction_axis: Optional[hax.AxisSelection] = None,
     logsumexp_weight: Optional[float] = None,
     loss_dtype: Optional[Type[jnp.dtype]] = jnp.float32,
-) -> jnp.ndarray | NamedArray:
+) -> tuple[NamedArray, NamedArray, dict]:
     """
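-    Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced
-    across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
-    reduced, and the result is a named array with axes (*batch axes, sequence_length).
+    Computes the unreduced next-token cross-entropy loss for a language modeling example. Returns a tuple
+    ``(loss, where, extras)``: the loss and mask produced by `maybe_fused_next_token_loss` (callers reduce, e.g.
+    ``hax.mean(loss, where=where)``) and the extras dict unpacked from the model's activations (empty if none).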
     """
     activations = model.activations(example.tokens, example.attn_mask, key=key)
+    if isinstance(activations, tuple):
+        activations, extras = activations
+    else:
+        extras = {}
 
-    loss = maybe_fused_next_token_loss(
+    loss, where = maybe_fused_next_token_loss(
         model.Pos,
         model.Embed,
         model.Vocab,
@@ -239,11 +244,9 @@ def compute_next_token_loss(
         model.get_lm_head(),
         example.tokens,
         loss_mask=example.loss_mask,
-        reduction=reduction,
-        reduction_axis=reduction_axis,
         logsumexp_weight=logsumexp_weight,
         dtype=loss_dtype,
         block_size=model.config.cross_entropy_block_size,
     )
 
-    return loss
+    return loss, where, extras
diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py
index bf0bd380e..bf84b0ba7 100644
--- a/src/levanter/models/loss.py
+++ b/src/levanter/models/loss.py
@@ -1,4 +1,5 @@
 import functools
+import logging
 from typing import Optional
 
 import equinox
@@ -10,6 +11,9 @@
 from haliax.nn import cross_entropy_loss_and_log_normalizers
 
 
+logger = logging.getLogger(__name__)
+
+
 def maybe_fused_next_token_loss(
     Pos: hax.AxisSelector,
     Embed: hax.AxisSelector,
@@ -18,12 +22,10 @@ def maybe_fused_next_token_loss(
     pred_lm_head: NamedArray,
     true_ids: NamedArray,
     loss_mask: Optional[NamedArray] = None,
-    reduction: Optional[hax.ReductionFunction] = hax.mean,
-    reduction_axis: Optional[hax.AxisSelection] = None,
     logsumexp_weight: Optional[float] = None,
     block_size: Optional[int] = None,
     dtype: Optional[jnp.dtype] = jnp.float32,
-) -> NamedArray:
+) -> tuple[NamedArray, NamedArray]:
     """
     Compute the next token loss with optional block-wise processing.
 
@@ -36,6 +38,7 @@ def maybe_fused_next_token_loss(
         loss_mask (Optional[NamedArray]): Mask to apply to the loss.
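-        reduction (Optional[hax.ReductionFunction]): Reduction function.
-        reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction.
+            The mask is not applied here: it is combined with a mask that drops the last position and returned
+            alongside the unreduced loss, so callers reduce themselves (e.g. ``hax.mean(loss, where=mask)``).
+            When no mask is given, only the last position is masked out.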
         logsumexp_weight (Optional[float]): Weight for logsumexp penalty.
         block_size (Optional[int]): Size of each block for processing.
 
@@ -45,6 +48,13 @@ def maybe_fused_next_token_loss(
     # Resolve axes
     Pos = pred_embeddings.resolve_axis(Pos)
     Vocab = pred_lm_head.resolve_axis(Vocab)
+    not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)  # type: ignore
+    if loss_mask is not None:
+        loss_mask = loss_mask * not_last_loss_mask
+    else:
+        loss_mask = not_last_loss_mask
+
+    target_y = hax.roll(true_ids, -1, Pos)
 
     if block_size is None:
         # Full softmax computation
@@ -53,42 +63,29 @@ def maybe_fused_next_token_loss(
             logits = logits.astype(dtype)
 
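-        # Shift target tokens to predict the next token
+        # Targets were already shifted above; return the unreduced loss together with the combined mask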
-        return next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight)
-
-    # Shift target tokens to predict the next token
-    target_y = hax.roll(true_ids, -1, Pos)
-
-    # Create a mask that excludes the last token
-    not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)  # type: ignore
-    if loss_mask is not None:
-        loss_mask = loss_mask * not_last_loss_mask
+        return next_token_loss(Pos, Vocab, logits, target_y, logsumexp_weight), loss_mask
     else:
-        loss_mask = not_last_loss_mask
-
-    # Compute the loss with optional block-wise processing
-    return fused_cross_entropy_loss_and_logsumexp_penalty(
-        pred_embeddings,
-        pred_lm_head,
-        Contract=Embed,
-        Label=Vocab,
-        target_y=target_y,
-        reduction=reduction,
-        reduction_axis=reduction_axis,
-        where=loss_mask,
-        logsumexp_weight=logsumexp_weight,
-        block_size=block_size,
-        dtype=dtype,
-    )
+        # Compute the loss with optional block-wise processing
+        return (
+            fused_cross_entropy_loss_and_logsumexp_penalty(
+                pred_embeddings,
+                pred_lm_head,
+                Contract=Embed,
+                Label=Vocab,
+                target_y=target_y,
+                logsumexp_weight=logsumexp_weight,
+                block_size=block_size,
+                dtype=dtype,
+            ),
+            loss_mask,
+        )
 
 
 def next_token_loss(
     Pos: hax.AxisSelector,
     Vocab: hax.AxisSelector,
     logits: NamedArray,
-    true_ids: NamedArray,
-    loss_mask: Optional[NamedArray] = None,
-    reduction: Optional[hax.ReductionFunction] = hax.mean,
-    reduction_axis: Optional[hax.AxisSelection] = None,
+    target_y: NamedArray,
     logsumexp_weight: Optional[float] = None,
 ):
     """
@@ -108,23 +105,12 @@ def next_token_loss(
     """
     Pos = logits.resolve_axis(Pos)
 
-    target_y = hax.roll(true_ids, -1, Pos)
     target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=logits.dtype)
 
-    # Create a mask that excludes the last token
-    not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)  # type: ignore
-    if loss_mask is not None:
-        loss_mask = loss_mask * not_last_loss_mask
-    else:
-        loss_mask = not_last_loss_mask
-
     return cross_entropy_and_logsumexp_penalty(
         Vocab=Vocab,
         pred_y=logits,
         target_y=target_y_full,
-        reduction=reduction,
-        reduction_axis=reduction_axis,
-        where=loss_mask,
         logsumexp_weight=logsumexp_weight,
     )
 
@@ -134,9 +120,6 @@ def cross_entropy_and_logsumexp_penalty(
     pred_y: NamedArray,
     target_y: NamedArray,
     *,
-    reduction: Optional[hax.ReductionFunction] = hax.mean,
-    reduction_axis: Optional[hax.AxisSelection] = None,
-    where: Optional[NamedArray] = None,
     logsumexp_weight=0.0,
 ) -> NamedArray:
     """A loss function that combines cross entropy loss with a logsumexp penalty."""
@@ -146,7 +129,7 @@ def cross_entropy_and_logsumexp_penalty(
     if logsumexp_weight is not None and logsumexp_weight != 0.0:
         loss = loss + logsumexp_weight * (log_normalizers**2)
 
-    return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where)
+    return loss
 
 
 def fused_cross_entropy_loss_and_logsumexp_penalty(
@@ -156,9 +139,6 @@ def fused_cross_entropy_loss_and_logsumexp_penalty(
     Label: hax.AxisSelector,
     target_y: NamedArray,
     *,
-    reduction: Optional[hax.ReductionFunction] = hax.mean,
-    reduction_axis: Optional[hax.AxisSelection] = None,
-    where: Optional[NamedArray] = None,
     logsumexp_weight: float | None = 0.0,
     block_size: int,
     dtype: Optional[jnp.dtype] = jnp.float32,
@@ -192,7 +172,7 @@ def fused_cross_entropy_loss_and_logsumexp_penalty(
     if logsumexp_weight is not None and (not isinstance(logsumexp_weight, (int, float)) or logsumexp_weight != 0.0):
         loss = loss + logsumexp_weight * (log_normalizers**2)
 
-    return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where)
+    return loss
 
 
 @equinox.filter_custom_vjp
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 82f32422a..cbb272dcd 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -43,7 +43,7 @@
 from levanter.utils import cloud_utils, fsspec_utils
 from levanter.utils.jax_utils import create_fsdp_mesh, zeros_like_tree
 from levanter.utils.tree_utils import inference_mode
-from levanter.utils.types import ComputeLossFunction, FilterSpec
+from levanter.utils.types import ComputeLossFunction, Extras, FilterSpec
 
 
 logger = pylogging.getLogger(__name__)
@@ -380,7 +380,7 @@ def checkpoint_path(self) -> str:
             checkpoint_path = self.config.checkpointer.expanded_path(self.run_id)
         return checkpoint_path
 
-    def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
+    def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]:
         """
         Performs a single training step.
         """
@@ -391,10 +391,10 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
 
         with capture_time() as step_time:
             if hooks_this_time:
-                loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs)
+                loss, new_state, extras, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs)
                 # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?)
             else:
-                loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
+                loss, new_state, extras = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
             loss = loss.item()  # type: ignore
 
             info = StepInfo(new_state, loss, step_time())
@@ -404,7 +404,8 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
                 if hooks_this_time:
                     self.hooks.run_jit_hooks_outside_step(info, cb_states)
 
-            levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step)
+            log_items = {k: v.item() for k, v in extras.items()} | {"throughput/hook_time": hook_time()}
+            levanter.tracker.log(log_items, step=info.step)
 
         return info
 
@@ -525,11 +526,13 @@ def _jit_train_step_fn_no_hook(self):
 
     def _train_step(
         self, state: S, batch, batch_kwargs, _no_hooks=False
-    ) -> tuple[Scalar, S, Sequence[CBInfo]] | tuple[Scalar, S]:
+    ) -> tuple[Scalar, S, Extras, Sequence[CBInfo]] | tuple[Scalar, S, Extras]:
         key, new_key = jax.random.split(state.training_key)
         model = inference_mode(state.model, False)
 
-        loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
+        (loss, extras), grads = self._compute_gradients_microbatched(
+            self.loss_fn, model, batch, **batch_kwargs, key=key
+        )
 
         with hax.axis_mapping(self.parameter_axis_mapping):
             if not _no_hooks:
@@ -545,22 +548,23 @@ def obj_fun(trainable_model):
         new_state = state.take_step(grads, obj_fun=obj_fun)
         new_state = hax.shard(new_state, self.parameter_axis_mapping)
         if _no_hooks:
-            return loss, new_state
+            return loss, new_state, extras
         else:
-            return loss, new_state, hook_infos
+            return loss, new_state, extras, hook_infos
 
-    def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]:
-        grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
+    def _compute_gradients_microbatched(
+        self, loss_fn, model: M, batch: X, **batch_kwargs
+    ) -> tuple[tuple[Scalar, Extras], M]:
         mbs = self.config.microbatch_size
         grad_fn = microbatched(
-            grad_fn,
+            loss_fn,
             self.TrainBatch,
             mbs,
             self.parameter_axis_mapping,
             self.compute_axis_mapping,
-        )
+        )  # type: ignore
         with hax.axis_mapping(self.compute_axis_mapping):
-            return grad_fn(model, *batch, **batch_kwargs)
+            return grad_fn(model, batch, **batch_kwargs)
 
 
 def _initialize_global_tracker(config, run_id):
diff --git a/src/levanter/utils/stat_utils.py b/src/levanter/utils/stat_utils.py
index 6111be42e..895003cf4 100644
--- a/src/levanter/utils/stat_utils.py
+++ b/src/levanter/utils/stat_utils.py
@@ -1,13 +1,26 @@
-import typing
+from typing import TypeAlias
 
 import equinox as eqx
 import jax.numpy as jnp
 import numpy as np
+from typing_extensions import Self
 
 import haliax as hax
 
+from levanter.utils.types import Accumulatable
 
-Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray
+
+Arrayish: TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray
+
+
+class SumScalar(Accumulatable):
+    value: jnp.ndarray
+
+    def item(self) -> float:
+        return self.value.item()
+
+    def __add__(self, other: Self) -> Self:
+        return SumScalar(self.value + other.value)
 
 
 class RunningMean(eqx.Module):
@@ -27,6 +40,9 @@ def add(self, x: Arrayish, total: Arrayish) -> "RunningMean":
         new_total = self.total + total
         return RunningMean(new_mean, new_total)
 
+    def item(self) -> float:
+        return self.mean.item()
+
     def __add__(self, other: "RunningMean"):
         return self.add(other.mean, other.total)
 
diff --git a/src/levanter/utils/types.py b/src/levanter/utils/types.py
index 46ccac2b5..dafc0791d 100644
--- a/src/levanter/utils/types.py
+++ b/src/levanter/utils/types.py
@@ -1,6 +1,10 @@
-from typing import Any, Callable, Optional, Protocol, Tuple, TypeVar, Union
+import abc
+from typing import Any, Callable, Dict, Protocol, Tuple, TypeAlias, TypeVar, Union
 
+import equinox as eqx
+import jax
 from jaxtyping import PyTree
+from typing_extensions import Self
 
 import haliax as hax
 from haliax.types import Scalar
@@ -10,6 +14,19 @@
 M_con = TypeVar("M_con", contravariant=True)  # Model
 X = TypeVar("X", contravariant=True)  # Input
 
+
+class Accumulatable(abc.ABC, eqx.Module):
+    @abc.abstractmethod
+    def item(self) -> float:
+        pass
+
+    @abc.abstractmethod
+    def __add__(self, other: Self) -> Self:
+        pass
+
+
+Extras: TypeAlias = Dict[str, jax.Array | Accumulatable]
+
 try:
     from haliax.nn.scan import BlockFoldable
 except ImportError:
@@ -51,9 +68,7 @@ class ComputeLossFunction(Protocol[M_con, X]):
     def __call__(
         self,
         model: M_con,
-        *inputs: X,
-        reduction: Optional[hax.ReductionFunction] = hax.mean,
-        reduction_axis: Optional[hax.AxisSelection] = None,
+        input: X,
         **kwargs,
-    ) -> Scalar | hax.NamedArray:
+    ) -> tuple[hax.NamedArray, hax.NamedArray, Extras]:
         ...
diff --git a/tests/test_doremi.py b/tests/test_doremi.py
index 3ad4aa9ab..e60554480 100644
--- a/tests/test_doremi.py
+++ b/tests/test_doremi.py
@@ -128,10 +128,11 @@ def test_estimate_mixture_weights():
         ds3 = LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys))
 
     # TODO: remove key as a requirement for models
-    def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None):
+    def compute_loss_fn(model, example, key=None):
         del key
         y_pred = model(example.x)
-        return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis)
+        losses = hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=None)
+        return losses, hax.ones_like(losses), {}
 
     tiny_trainer_config = TrainerConfig(
         num_train_steps=300,
diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py
index a6e96f2d8..77ef4df7e 100644
--- a/tests/test_grad_accum.py
+++ b/tests/test_grad_accum.py
@@ -45,7 +45,13 @@ def test_accumulate_gradients_sharded(parallelism, accum_steps):
     mlp = Mlp.init(In, Out, Mid, key=jax.random.PRNGKey(0))
 
     def loss_fn(mlp, x):
-        return mlp(x).mean().scalar()
+        out = mlp(x)
+        where = hax.ones_like(out)
+        return out, where, {}
+
+    def scalar_loss_fn(mlp, x):
+        out, where, _ = loss_fn(mlp, x)
+        return hax.mean(out, where=where).scalar(), {}
 
     x = hax.random.normal(jax.random.PRNGKey(0), (Batch, In))
 
@@ -57,20 +63,15 @@ def loss_fn(mlp, x):
 
     @hax.partitioning.named_jit(axis_resources=axis_mapping)
     def jit_grad_accum(mlp, x):
-        grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
-        grad_fn = microbatched(grad_fn, Batch, parallelism, axis_mapping, axis_mapping)
-        acc_v, acc_g = grad_fn(
-            mlp,
-            x,
-        )
-        return acc_v, acc_g
+        grad_fn = microbatched(loss_fn, Batch, parallelism, axis_mapping, axis_mapping)
+        return grad_fn(mlp, x)
 
     with mesh:
         mlp = haliax.shard(mlp, axis_mapping)
         x = haliax.shard(x, axis_mapping)
-        grad_fn = eqx.filter_value_and_grad(loss_fn)
-        acc_v, acc_g = jit_grad_accum(mlp, x)
-        v, g = grad_fn(mlp, x)
+        grad_fn = eqx.filter_value_and_grad(scalar_loss_fn, has_aux=True)
+        (acc_v, _), acc_g = jit_grad_accum(mlp, x)
+        (v, _), g = grad_fn(mlp, x)
 
         assert_trees_all_close(acc_v, v, atol=1e-3, rtol=1e-3)
 
diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py
index 9cc46ca0d..29074504c 100644
--- a/tests/test_hf_gpt2_serialize.py
+++ b/tests/test_hf_gpt2_serialize.py
@@ -135,7 +135,8 @@ def torch_loss(model, input_ids) -> torch.Tensor:
 
     def compute_loss(model: LmHeadModel, input_ids):
         example = LmExample.causal(input_ids, eos_id=converter.tokenizer.eos_token_id)
-        return compute_next_token_loss(model, example, key=None).scalar()
+        loss, where, _ = compute_next_token_loss(model, example, key=None)
+        return hax.mean(loss, where=where).scalar()
 
     jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False)
     jax_grad: Gpt2LMHeadModel
diff --git a/tests/test_text.py b/tests/test_text.py
index 63d0afedb..046b7a7d2 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -40,12 +40,14 @@ def test_lm_example_handles_ignore_id():
     lm_head = hax.zeros((Embed, Vocab))
     lm_head = lm_head.at[Vocab, ignore_id].set(-100)
 
-    ignored_loss = maybe_fused_next_token_loss(
+    ignored_loss, ignored_where = maybe_fused_next_token_loss(
         Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask
     )
-    no_ignore_loss = maybe_fused_next_token_loss(
+    ignored_loss = hax.sum(ignored_loss, where=ignored_where)
+    no_ignore_loss, no_ignore_where = maybe_fused_next_token_loss(
         Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask
     )
+    no_ignore_loss = hax.sum(no_ignore_loss, where=no_ignore_where)
 
     assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size