Merge pull request #216 from sunlabuiuc/develop
Base fairness metrics and example
Showing 8 changed files with 208 additions and 4 deletions.
@@ -0,0 +1,53 @@
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import readmission_prediction_mimic3_fn
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.metrics import fairness_metrics_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer
from pyhealth.metrics.fairness_utils.utils import sensitive_attributes_from_patient_ids

# STEP 1: load data
base_dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
)
base_dataset.stat()

# STEP 2: set task
sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn)
sample_dataset.stat()

train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

# STEP 3: define model
model = Transformer(
    dataset=sample_dataset,
    # look up what is available for "feature_keys" and "label_key" in dataset.samples[0]
    feature_keys=["conditions", "procedures"],
    label_key="label",
    mode="binary",
)

# STEP 4: define trainer
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=3,
    monitor="pr_auc",
)

# STEP 5: inference, return patient_ids
y_true, y_prob, loss, patient_ids = trainer.inference(test_dataloader, return_patient_ids=True)

# STEP 6: get sensitive attribute array from patient_ids
sensitive_attribute_array = sensitive_attributes_from_patient_ids(base_dataset, patient_ids,
                                                                  'gender', 'F')

# STEP 7: use pyhealth.metrics to evaluate fairness
fairness_metrics = fairness_metrics_fn(y_true, y_prob, sensitive_attribute_array,
                                        favorable_outcome=0)
print(fairness_metrics)
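For readers skimming the example: fairness_metrics_fn returns a plain dict keyed by metric name, so the final print shows something of the following shape (a sketch for orientation only; the values are placeholders, not output from this PR):

# fairness_metrics is a dict keyed by metric name, e.g.
# {"disparate_impact": ..., "statistical_parity_difference": ...}
# A disparate_impact near 1 and a statistical_parity_difference near 0 mean the
# protected and unprotected groups receive the favorable outcome at similar rates.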
@@ -0,0 +1,52 @@
from typing import Dict, List, Optional

import numpy as np

from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference

def fairness_metrics_fn(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    sensitive_attributes: np.ndarray,
    favorable_outcome: int = 1,
    metrics: Optional[List[str]] = None,
    threshold: float = 0.5,
) -> Dict[str, float]:
    """Computes fairness metrics for binary classification.

    The user can specify which metrics to compute by passing a list of metric names.
    The accepted metric names are:
        - disparate_impact
        - statistical_parity_difference
    If no metrics are specified, disparate_impact and statistical_parity_difference are computed by default.

    Args:
        y_true: True target values of shape (n_samples,).
        y_prob: Predicted probabilities of shape (n_samples,).
        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
        favorable_outcome: Label value which is considered favorable (i.e. "positive").
        metrics: List of metrics to compute. Default is ["disparate_impact", "statistical_parity_difference"].
        threshold: Threshold for binary classification. Default is 0.5.

    Returns:
        Dictionary of metrics whose keys are the metric names and values are
            the metric values.
    """
    if metrics is None:
        metrics = ["disparate_impact", "statistical_parity_difference"]

    # binarize the predicted probabilities at the given threshold
    y_pred = y_prob.copy()
    y_pred[y_pred >= threshold] = 1
    y_pred[y_pred < threshold] = 0

    output = {}
    for metric in metrics:
        if metric == "disparate_impact":
            output[metric] = disparate_impact(sensitive_attributes, y_pred, favorable_outcome)
        elif metric == "statistical_parity_difference":
            output[metric] = statistical_parity_difference(sensitive_attributes, y_pred, favorable_outcome)
        else:
            raise ValueError(f"Unknown metric for fairness: {metric}")
    return output
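As a minimal standalone sketch of calling the new metric function, consider the toy arrays below (the numbers are invented for illustration; only fairness_metrics_fn and its defaults come from this PR):

import numpy as np
from pyhealth.metrics import fairness_metrics_fn

# toy data: 1 marks the protected group, 0 the unprotected group
y_true = np.array([1, 0, 1, 0, 1, 0])
y_prob = np.array([0.9, 0.2, 0.7, 0.6, 0.3, 0.1])
sensitive = np.array([1, 1, 1, 0, 0, 0])

result = fairness_metrics_fn(y_true, y_prob, sensitive)
# With the default threshold=0.5, y_pred = [1, 0, 1, 1, 0, 0]:
#   P(y_pred = 1 | protected)   = 2/3
#   P(y_pred = 1 | unprotected) = 1/3
# so result == {"disparate_impact": 2.0, "statistical_parity_difference": 0.333...}
print(result)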
@@ -0,0 +1,2 @@
from .group import disparate_impact, statistical_parity_difference
from .utils import sensitive_attributes_from_patient_ids
@@ -0,0 +1,58 @@
import numpy as np

"""
Notation:
    - Protected group: P
    - Unprotected group: U
"""

def disparate_impact(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1, allow_zero_division=False, epsilon: float = 1e-8) -> float:
    """
    Computes the disparate impact between the protected and unprotected group.
        disparate_impact = P(y_pred = favorable_outcome | P) / P(y_pred = favorable_outcome | U)

    Args:
        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
        y_pred: Predicted target values of shape (n_samples,).
        favorable_outcome: Label value which is considered favorable (i.e. "positive").
        allow_zero_division: If True, use epsilon instead of 0 in the denominator if the denominator is 0. Otherwise, raise a ValueError.
        epsilon: Small constant used in place of a zero denominator when allow_zero_division is True.

    Returns:
        The disparate impact between the protected and unprotected group.
    """
    # rate of favorable outcomes in the unprotected (0) and protected (1) groups
    p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
    p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])

    if p_fav_unpr == 0:
        if allow_zero_division:
            p_fav_unpr = epsilon
        else:
            raise ValueError("Unprotected group has no instances with a favorable outcome. Disparate impact is undefined.")

    disparate_impact_value = p_fav_prot / p_fav_unpr

    return disparate_impact_value

def statistical_parity_difference(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1) -> float:
    """
    Computes the statistical parity difference between the protected and unprotected group.
        statistical_parity_difference = P(y_pred = favorable_outcome | P) - P(y_pred = favorable_outcome | U)

    Args:
        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
        y_pred: Predicted target values of shape (n_samples,).
        favorable_outcome: Label value which is considered favorable (i.e. "positive").

    Returns:
        The statistical parity difference between the protected and unprotected group.
    """
    p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
    p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])

    statistical_parity_difference_value = p_fav_prot - p_fav_unpr

    return statistical_parity_difference_value
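To make the allow_zero_division flag concrete, here is a toy sketch (array values invented for illustration) of the case where the unprotected group never receives the favorable outcome:

import numpy as np
from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference

y_pred = np.array([1, 1, 0, 0, 0, 0])
sensitive = np.array([1, 1, 1, 0, 0, 0])  # 1 = protected, 0 = unprotected

# The unprotected group has no favorable (1) predictions, so the denominator
# P(y_pred = 1 | U) is 0 and disparate_impact(sensitive, y_pred) raises ValueError.
di = disparate_impact(sensitive, y_pred, allow_zero_division=True)
# denominator replaced by epsilon=1e-8, so di is huge: (2/3) / 1e-8, roughly 6.7e7

spd = statistical_parity_difference(sensitive, y_pred)
# 2/3 - 0 = 0.667: the protected group receives the favorable outcome far more often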
@@ -0,0 +1,31 @@
from typing import List
import numpy as np

from pyhealth.datasets import BaseEHRDataset

def sensitive_attributes_from_patient_ids(dataset: BaseEHRDataset,
                                          patient_ids: List[str],
                                          sensitive_attribute: str,
                                          protected_group: str) -> np.ndarray:
    """
    Returns the desired sensitive attribute array from patient_ids.

    Args:
        dataset: Dataset object.
        patient_ids: List of patient IDs.
        sensitive_attribute: Sensitive attribute to extract.
        protected_group: Value of the protected group.

    Returns:
        Sensitive attribute array of shape (n_samples,), where 1 marks the protected group and 0 the unprotected group.
    """
    sensitive_attribute_array = np.zeros(len(patient_ids))
    for idx, patient_id in enumerate(patient_ids):
        sensitive_attribute_value = getattr(dataset.patients[patient_id], sensitive_attribute)
        if sensitive_attribute_value == protected_group:
            sensitive_attribute_array[idx] = 1
    return sensitive_attribute_array
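A small hedged illustration of the mapping this helper performs, assuming gender is recorded on the Patient objects as in the example script above (the IDs and values below are hypothetical):

# Suppose dataset.patients holds three patients whose `gender` attributes are
# "F", "M", "F". Then, with protected_group="F":
#   sensitive_attributes_from_patient_ids(dataset, ["p1", "p2", "p3"], "gender", "F")
#   -> array([1., 0., 1.])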