Metric and Collection inherit from PyTreeNode #251

Open · wants to merge 1 commit into main
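The practical effect for downstream code: `Metric` and `Collection` subclasses no longer need an explicit `@flax.struct.dataclass` decorator, because `flax.struct.PyTreeNode` applies the dataclass transform and pytree registration when it is subclassed. A minimal usage sketch of the intended behaviour after this change (the class name and values below are illustrative only):

import jax.numpy as jnp
from clu import metrics

class TrainMetrics(metrics.Collection):  # no @flax.struct.dataclass needed anymore
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output("loss")

# Compute metrics for a single (toy) batch of model outputs.
batch_metrics = TrainMetrics.single_from_model_output(
    logits=jnp.array([[2.0, 0.1], [0.3, 1.5]]),
    labels=jnp.array([0, 1]),
    loss=jnp.array(0.25))
print(batch_metrics.compute())  # {'accuracy': ..., 'loss': ...}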
6 changes: 0 additions & 6 deletions clu/metric_writers/utils_test.py
@@ -34,7 +34,6 @@



@flax.struct.dataclass
class HistogramMetric(clu.metrics.Metric):
value: jnp.ndarray
num_buckets: int
@@ -43,15 +42,13 @@ def compute_value(self):
return values.Histogram(self.value, self.num_buckets)


@flax.struct.dataclass
class ImageMetric(clu.metrics.Metric):
value: jnp.ndarray

def compute_value(self):
return values.Image(self.value)


@flax.struct.dataclass
class AudioMetric(clu.metrics.Metric):
value: jnp.ndarray
sample_rate: int
@@ -60,23 +57,20 @@ def compute_value(self):
return values.Audio(self.value, self.sample_rate)


@flax.struct.dataclass
class TextMetric(clu.metrics.Metric):
value: str

def compute_value(self):
return values.Text(self.value)


@flax.struct.dataclass
class HyperParamMetric(clu.metrics.Metric):
value: float

def compute_value(self):
return values.HyperParam(self.value)


@flax.struct.dataclass
class SummaryMetric(clu.metrics.Metric):
value: jnp.ndarray
metadata: Any
27 changes: 5 additions & 22 deletions clu/metrics.py
@@ -35,7 +35,6 @@
import flax
import jax

@flax.struct.dataclass # required for jax.tree_*
class Metrics(metrics.Collection):
accuracy: metrics.Accuracy
loss: metrics.Average.from_output("loss")
@@ -66,6 +65,7 @@ def evaluate(model, p_variables, test_ds):
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct

# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
@@ -78,7 +78,7 @@ def _assert_same_shape(a: jnp.array, b: jnp.array):
raise ValueError(f"Expected same shape: {a.shape} != {b.shape}")


class Metric:
class Metric(struct.PyTreeNode):
"""Interface for computing metrics from intermediate values.

Refer to `Collection` for computing multiple metrics at the same time.
@@ -88,7 +88,6 @@ class Metric:
import jax.numpy as jnp
import flax

@flax.struct.dataclass
class Average(Metric):
total: jnp.array
count: jnp.array
@@ -184,7 +183,6 @@ def from_fun(cls, fun: Callable): # pylint: disable=g-bare-generic
def get_head1(head1_loss, head1_mask, **_):
return dict(loss=head1_loss, mask=head1_mask)

@flax.struct.dataclass
class MultiHeadMetrics(metrics.Collection):
head1_loss: metrics.Average.from_output("loss").from_fun(get_head1)
...
@@ -202,7 +200,6 @@ class MultiHeadMetrics(metrics.Collection):
`model_output`.
"""

@flax.struct.dataclass
class FromFun(cls):
"""Wrapper Metric class that collects output after applying `fun`."""

@@ -245,7 +242,6 @@ def from_output(cls, name: str): # pylint: disable=g-bare-generic

Synopsis:

@flax.struct.dataclass
class Metrics(Collection):
loss: Average.from_output('loss')

@@ -264,7 +260,6 @@ class Metrics(Collection):
a first argument the model output specified by `name`.
"""

@flax.struct.dataclass
class FromOutput(cls):
"""Wrapper Metric class that collects output named `name`."""

@@ -283,7 +278,6 @@ def from_model_output(cls, **model_output) -> Metric:
return FromOutput


@flax.struct.dataclass
class CollectingMetric(Metric):
"""A special metric that collects model outputs.

@@ -307,7 +301,6 @@ class CollectingMetric:

Example of computing average precision using `sklearn`:

@flax.struct.dataclass
class AveragePrecision(
metrics.CollectingMetric.from_outputs(("labels", "logits"))):

@@ -371,7 +364,6 @@ def compute(self) -> Dict[str, np.ndarray]:
def from_outputs(cls, names: Sequence[str]):
"""Returns a metric class that collects all model outputs named `names`."""

@flax.struct.dataclass
class FromOutputs(cls): # pylint:disable=missing-class-docstring

@classmethod
@@ -387,7 +379,6 @@ def make_array(value):
return FromOutputs


@flax.struct.dataclass
class _ReductionCounter(Metric):
"""Pseudo metric that keeps track of the total number of `.merge()`."""

@@ -409,15 +400,13 @@ def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter):
f"call a flax.jax_utils.unreplicate() or a Collections.reduce()?")


@flax.struct.dataclass
class Collection:
class Collection(struct.PyTreeNode):
"""Updates a collection of `Metric` from model outputs.

Refer to the module documentation for a complete example.

Synopsis:

@flax.struct.dataclass
class Metrics(Collection):
accuracy: Accuracy

@@ -437,7 +426,6 @@ def create(cls, **metrics: Type[Metric]) -> Type["Collection"]:

Instead of declaring a `Collection` dataclass:

@flax.struct.dataclass
class MyMetrics(metrics.Collection):
accuracy: metrics.Accuracy

@@ -454,8 +442,8 @@ class MyMetrics(metrics.Collection):
Returns:
A subclass of Collection with fields defined by provided `metrics`.
"""
return flax.struct.dataclass(
type("_InlineCollection", (Collection,), {"__annotations__": metrics}))
return type(
"_InlineCollection", (Collection,), {"__annotations__": metrics})

@classmethod
def create_collection(cls, **metrics: Metric) -> "Collection":
@@ -469,7 +457,6 @@ def create_collection(cls, **metrics: Metric) -> "Collection":

is equivalent to:

@flax.struct.dataclass
class MyMetrics(metrics.Collection):
accuracy: metrics.Accuracy
my_metrics = MyMetrics(_ReductionCounter(jnp.array(1)),
@@ -574,7 +561,6 @@ def unreplicate(self) -> "Collection":
return flax.jax_utils.unreplicate(self)


@flax.struct.dataclass
class LastValue(Metric):
"""Keeps the last average batch value.

@@ -604,7 +590,6 @@ def compute(self) -> Any:
return self.value


@flax.struct.dataclass
class Average(Metric):
"""Computes the average of a scalar or a batch of tensors.

@@ -665,7 +650,6 @@ def compute(self) -> Any:
return self.total / self.count


@flax.struct.dataclass
class Std(Metric):
"""Computes the standard deviation of a scalar or a batch of scalars.

@@ -722,7 +706,6 @@ def compute(self) -> Any:
return variance**.5


@flax.struct.dataclass
class Accuracy(Average):
"""Computes the accuracy from model outputs `logits` and `labels`.

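Because `create` now returns a plain `Collection` subclass instead of wrapping it in `flax.struct.dataclass`, the dynamically generated class relies on `struct.PyTreeNode`'s subclass hook for its dataclass/pytree registration. A hedged sanity-check sketch of that behaviour (names and values are illustrative, not part of this PR):

import jax
import jax.numpy as jnp
from clu import metrics

# Dynamically created Collection subclass; no explicit decoration needed.
MyMetrics = metrics.Collection.create(accuracy=metrics.Accuracy)

ms = MyMetrics.single_from_model_output(
    logits=jnp.array([[1.0, -1.0]]), labels=jnp.array([0]))
# The metric state (e.g. Accuracy's total/count) should still appear as
# pytree leaves, so jax.device_get and pmap-style reductions keep working.
print(jax.tree_util.tree_leaves(ms))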
3 changes: 0 additions & 3 deletions clu/metrics_test.py
@@ -28,7 +28,6 @@
import numpy as np


@flax.struct.dataclass
class CollectingMetricAccuracy(
metrics.CollectingMetric.from_outputs(("logits", "labels"))):

@@ -41,13 +40,11 @@ def compute(self):
return (logits.argmax(axis=-1) == labels).mean()


@flax.struct.dataclass
class Collection(metrics.Collection):
train_accuracy: metrics.Accuracy
learning_rate: metrics.LastValue.from_output("learning_rate")


@flax.struct.dataclass
class CollectionMixed(metrics.Collection):
collecting_metric_accuracy: CollectingMetricAccuracy
train_accuracy: metrics.Accuracy
65 changes: 31 additions & 34 deletions clu_synopsis.ipynb

Large diffs are not rendered by default.