add sample_size visualization and update notebooks (#667)
* add sample_size vis and update notebooks

* fix mypy error

* change Group Size to sample_size
a-kore authored Jul 24, 2024
1 parent fa9f19c commit 332898e
Showing 16 changed files with 540 additions and 253 deletions.
4 changes: 2 additions & 2 deletions benchmarks/mimiciv/discharge_prediction.ipynb
@@ -1182,9 +1182,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
"# remove the group size from the fairness results and add it to the slice name\n",
"# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
" group_size = slice_results.pop(\"Group Size\")\n",
" group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
4 changes: 2 additions & 2 deletions benchmarks/mimiciv/icu_mortality_prediction.ipynb
@@ -1159,9 +1159,9 @@
"# Reformatting the fairness metrics\n",
"fairness_results = copy.deepcopy(results[\"fairness\"])\n",
"fairness_metrics = {}\n",
"# remove the group size from the fairness results and add it to the slice name\n",
"# remove the sample_size from the fairness results and add it to the slice name\n",
"for slice_name, slice_results in fairness_results.items():\n",
" group_size = slice_results.pop(\"Group Size\")\n",
" group_size = slice_results.pop(\"sample_size\")\n",
" fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
]
},
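Both notebook diffs make the same change: the per-slice count is now read from the "sample_size" key instead of "Group Size" before being folded into the slice name. A minimal standalone sketch of that reformatting loop, using made-up fairness results:

import copy

# Toy fairness results keyed by slice name (illustrative values only).
results = {
    "fairness": {
        "age:[20 - 50)": {"accuracy": 0.91, "sample_size": 1200},
        "age:[50 - 80)": {"accuracy": 0.87, "sample_size": 800},
    }
}

fairness_results = copy.deepcopy(results["fairness"])
fairness_metrics = {}
# Pop the sample_size from each slice's results and add it to the slice name.
for slice_name, slice_results in fairness_results.items():
    group_size = slice_results.pop("sample_size")
    fairness_metrics[f"{slice_name} (Size={group_size})"] = slice_results

print(fairness_metrics)
# {'age:[20 - 50) (Size=1200)': {'accuracy': 0.91},
#  'age:[50 - 80) (Size=800)': {'accuracy': 0.87}}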
1 change: 1 addition & 0 deletions cyclops/evaluate/evaluator.py
@@ -311,6 +311,7 @@ def _compute_metrics(
model_name: str = "model_for_%s" % prediction_column
results.setdefault(model_name, {})
results[model_name][slice_name] = metric_output
results[model_name][slice_name]["sample_size"] = len(sliced_dataset)

set_decode(dataset, True) # restore decoding features

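With this one-line addition, every per-slice entry produced by the evaluator also records how many examples it was computed on. A rough sketch of the resulting structure; the model key follows the "model_for_<prediction_column>" pattern from the code above, and all names and values are hypothetical:

# Hypothetical shape of the results dict after this change.
results = {
    "model_for_preds_prob": {
        "overall": {
            "BinaryAccuracy": 0.89,
            "sample_size": 5000,  # len(sliced_dataset) for this slice
        },
        "age:[20 - 50)": {
            "BinaryAccuracy": 0.91,
            "sample_size": 1200,
        },
    }
}

for slice_name, slice_metrics in results["model_for_preds_prob"].items():
    print(slice_name, "n =", slice_metrics["sample_size"])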
4 changes: 2 additions & 2 deletions cyclops/evaluate/fairness/evaluator.py
@@ -260,7 +260,7 @@ def evaluate_fairness( # noqa: PLR0912
for prediction_column in fmt_prediction_columns:
results.setdefault(prediction_column, {})
results[prediction_column].setdefault(slice_name, {}).update(
{"Group Size": len(sliced_dataset)},
{"sample_size": len(sliced_dataset)},
)

pred_result = _get_metric_results_for_prediction_and_slice(
@@ -966,7 +966,7 @@ def _compute_parity_metrics(
parity_results[key] = {}
for slice_name, slice_result in prediction_result.items():
for metric_name, metric_value in slice_result.items():
if metric_name == "Group Size":
if metric_name == "sample_size":
continue

# add 'Parity' to the metric name before @threshold, if specified
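The fairness evaluator stores the same count under the renamed key, and the parity computation now skips "sample_size" so it is never treated as a metric. A small sketch of that filtering step, with made-up slice results and metric names:

slice_result = {
    "BinaryAccuracy": 0.88,
    "BinaryF1Score@0.5": 0.74,
    "sample_size": 800,
}

# Mirrors the guard in _compute_parity_metrics: everything except
# sample_size is carried forward into the parity metrics.
parity_inputs = {
    metric_name: metric_value
    for metric_name, metric_value in slice_result.items()
    if metric_name != "sample_size"
}
print(parity_inputs)  # {'BinaryAccuracy': 0.88, 'BinaryF1Score@0.5': 0.74}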
10 changes: 10 additions & 0 deletions cyclops/report/model_card/fields.py
@@ -380,6 +380,11 @@ class PerformanceMetric(
default_factory=list,
)

sample_size: Optional[StrictInt] = Field(
None,
description="The sample size used to compute this metric.",
)


class User(
BaseModelCardField,
@@ -599,6 +604,11 @@ class MetricCard(
description="Timestamps for each point in the history.",
)

sample_sizes: Optional[List[int]] = Field(
None,
description="Sample sizes for each point in the history.",
)


class MetricCardCollection(BaseModelCardField, composable_with="Overview"):
"""A collection of metric cards to be displayed in the model card."""
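The model-card schema grows two optional fields: a per-metric sample_size on PerformanceMetric and a per-point sample_sizes history on MetricCard. A self-contained sketch of the same idea, using stand-in pydantic models rather than the actual cyclops field classes:

from typing import List, Optional

from pydantic import BaseModel, Field, StrictInt


class PerformanceMetricSketch(BaseModel):
    """Stand-in for PerformanceMetric, showing the new sample_size field."""

    type: str
    value: float
    slice: str = "overall"
    sample_size: Optional[StrictInt] = Field(
        None,
        description="The sample size used to compute this metric.",
    )


class MetricCardSketch(BaseModel):
    """Stand-in for MetricCard, showing the new sample_sizes history field."""

    name: str
    history: List[float] = Field(default_factory=list)
    sample_sizes: Optional[List[int]] = Field(
        None,
        description="Sample sizes for each point in the history.",
    )


metric = PerformanceMetricSketch(type="BinaryAccuracy", value=0.89, sample_size=5000)
card = MetricCardSketch(name="BinaryAccuracy", history=[0.86, 0.89], sample_sizes=[4800, 5000])
print(metric.sample_size, card.sample_sizes)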
92 changes: 62 additions & 30 deletions cyclops/report/report.py
Expand Up @@ -48,6 +48,7 @@
get_histories,
get_names,
get_passed,
get_sample_sizes,
get_slices,
get_thresholds,
get_timestamps,
@@ -855,6 +856,7 @@ def log_quantitative_analysis(
pass_fail_threshold_fns: Optional[
Union[Callable[[Any, float], bool], List[Callable[[Any, float], bool]]]
] = None,
sample_size: Optional[int] = None,
**extra: Any,
) -> None:
"""Add a quantitative analysis to the report.
@@ -921,6 +923,7 @@ def log_quantitative_analysis(
"slice": metric_slice,
"decision_threshold": decision_threshold,
"description": description,
"sample_size": sample_size,
**extra,
}
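The new sample_size argument is stored alongside the other metric fields logged by log_quantitative_analysis. A hedged usage sketch; the report object and all values are hypothetical, and the argument names follow the call that appears later in this diff:

# Assuming `report` is an existing cyclops report instance exposing this method.
report.log_quantitative_analysis(
    "performance",
    name="BinaryAccuracy",
    value=0.89,
    description="Fraction of correct predictions.",
    metric_slice="test",
    pass_fail_thresholds=0.7,
    pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),
    sample_size=5000,  # new optional argument added in this commit
)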

@@ -958,42 +961,70 @@ def log_quantitative_analysis(
field_type=field_type,
)

def log_performance_metrics(self, metrics: Dict[str, Any]) -> None:
"""Add a performance metric to the `Quantitative Analysis` section.
def log_performance_metrics(
self,
results: Dict[str, Any],
metric_descriptions: Dict[str, str],
pass_fail_thresholds: Union[float, Dict[str, float]] = 0.7,
pass_fail_threshold_fn: Callable[[float, float], bool] = lambda x,
threshold: bool(x >= threshold),
) -> None:
"""
Log all performance metrics to the model card report.
Parameters
----------
metrics : Dict[str, Any]
A dictionary of performance metrics. The keys should be the name of the
metric, and the values should be the value of the metric. If the metric
is a slice metric, the key should be the slice name followed by a slash
and then the metric name (e.g. "slice_name/metric_name"). If no slice
name is provided, the slice name will be "overall".
Raises
------
TypeError
If the given metrics are not a dictionary with string keys.
results : Dict[str, Any]
Dictionary containing the results,
with keys in the format "split/metric_name".
metric_descriptions : Dict[str, str]
Dictionary mapping metric names to their descriptions.
pass_fail_thresholds : Union[float, Dict[str, float]], optional
The threshold(s) for pass/fail tests.
Can be a single float applied to all metrics,
or a dictionary mapping "split/metric_name" to individual thresholds.
Default is 0.7.
pass_fail_threshold_fn : Callable[[float, float], bool], optional
Function to determine if a metric passes or fails.
Default is lambda x, threshold: bool(x >= threshold).
Returns
-------
None
"""
_raise_if_not_dict_with_str_keys(metrics)
for metric_name, metric_value in metrics.items():
name_split = metric_name.split("/")
if len(name_split) == 1:
slice_name = "overall"
metric_name = name_split[0] # noqa: PLW2901
else: # everything before the last slash is the slice name
slice_name = "/".join(name_split[:-1])
metric_name = name_split[-1] # noqa: PLW2901

# TODO: create plot
# Extract sample sizes
sample_sizes = {
key.split("/")[0]: value
for key, value in results.items()
if "sample_size" in key.split("/")[1]
}

self._log_field(
data={"type": metric_name, "value": metric_value, "slice": slice_name},
section_name="quantitative_analysis",
field_name="performance_metrics",
field_type=PerformanceMetric,
)
# Log metrics
for name, metric in results.items():
split, metric_name = name.split("/")
if metric_name != "sample_size":
metric_value = metric.tolist() if hasattr(metric, "tolist") else metric

# Determine the threshold for this specific metric
if isinstance(pass_fail_thresholds, dict):
threshold = pass_fail_thresholds.get(
name, 0.7
) # Default to 0.7 if not specified
else:
threshold = pass_fail_thresholds

self.log_quantitative_analysis(
"performance",
name=metric_name,
value=metric_value,
description=metric_descriptions.get(
metric_name, "No description provided."
),
metric_slice=split,
pass_fail_thresholds=threshold,
pass_fail_threshold_fns=pass_fail_threshold_fn,
sample_size=sample_sizes.get(split),
)

# TODO: MERGE/COMPARE MODEL CARDS

@@ -1162,6 +1193,7 @@ def export(
"get_names": get_names,
"get_histories": get_histories,
"get_timestamps": get_timestamps,
"get_sample_sizes": get_sample_sizes,
}
template.globals.update(func_dict)

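Taken together, the reworked log_performance_metrics expects a flat results dict keyed as "split/metric_name", extracts the "split/sample_size" entries, and forwards each remaining metric to log_quantitative_analysis with its sample size attached. A hedged end-to-end sketch; the report object, metric names, and values are illustrative:

# Assuming `report` is a cyclops report instance exposing log_performance_metrics.
results = {
    "test/BinaryAccuracy": 0.89,
    "test/BinaryF1Score": 0.81,
    "test/sample_size": 5000,  # picked up separately and attached to each test metric
}
metric_descriptions = {
    "BinaryAccuracy": "Fraction of correct predictions.",
    "BinaryF1Score": "Harmonic mean of precision and recall.",
}

report.log_performance_metrics(
    results,
    metric_descriptions,
    # Per-metric threshold for one metric; the rest fall back to the 0.7 default.
    pass_fail_thresholds={"test/BinaryAccuracy": 0.8},
)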