Skip to content

Commit

Permalink
WIP: fix the state.parent for TreeReduce
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672502390
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Sep 9, 2024
1 parent 3bcf564 commit a50f9dd
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion kauldron/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ def new_get_state(self, *args, **kwargs):
return new_get_state


def skip_link_metric(fn: _FnT) -> _FnT:
"""Decorator to disable the `_link_metric_to_state` magic.
This is important for wrapper metrics like `kd.metrics.TreeReduce` where
the state.parent should remain the wrapped metric.
Args:
fn: The `get_state` or `empty` function to skip.
Returns:
The function with the `_has_link_metric` flag set to True.
"""
fn._has_link_metric = True # pylint: disable=protected-access
return fn


@flax.struct.dataclass
class TreeState(base_state.State):
"""Holds a pytree of metric states."""
Expand Down Expand Up @@ -227,7 +243,7 @@ class TreeMap(_TreeMetric):
class State(TreeState):
pass

def get_state(self, **kwargs):
def get_state(self, **kwargs) -> TreeMap.State:
state_tree = self._get_tree_state(**kwargs)
return self.State(state_tree)

Expand All @@ -239,6 +255,7 @@ class TreeReduce(_TreeMetric):
The given metric defines the aggregation method.
"""

@skip_link_metric
def get_state(self, **kwargs) -> base_state.State:
state_tree = self._get_tree_state(**kwargs)
reduced_state = jax.tree.reduce(
Expand All @@ -249,6 +266,10 @@ def get_state(self, **kwargs) -> base_state.State:
)
return reduced_state

@skip_link_metric
def empty(self) -> base_state.State:
return self.metric.empty()


def _tree_map_with_kwargs(fun, **kwargs):
"""Same as jax.tree.map but taking and passing trees to fun as kwargs."""
Expand Down

0 comments on commit a50f9dd

Please sign in to comment.