Merge pull request #216 from sunlabuiuc/develop
Base fairness metrics and example
ycq091044 authored Sep 1, 2023
2 parents 9cbfbcf + e7ef27a commit 4ffa2cc
Showing 8 changed files with 208 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/api/datasets/pyhealth.datasets.OMOPDataset.rst
@@ -1,7 +1,7 @@
pyhealth.datasets.OMOPDataset
===================================

We can process any OMOP-CDM formatted database, refer to `doc <https://www.ohdsi.org/data-standardization/the-common-data-model/>`_ for more information. We it into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis.
We can process any OMOP-CDM formatted database; refer to the `doc <https://www.ohdsi.org/data-standardization/the-common-data-model/>`_ for more information. The raw data is processed into a well-structured dataset object, giving users the **best flexibility and convenience** for modeling and analysis.

.. autoclass:: pyhealth.datasets.OMOPDataset
:members:
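For orientation, a hypothetical instantiation sketch (placeholder path; table names follow the OMOP CDM; the root/tables parameters mirror MIMIC3Dataset in the example below):

from pyhealth.datasets import OMOPDataset

# placeholder: directory containing an OMOP-CDM export
omop_dataset = OMOPDataset(
    root="/path/to/omop/csvs/",
    tables=["condition_occurrence", "procedure_occurrence", "drug_exposure"],
)
omop_dataset.stat()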
53 changes: 53 additions & 0 deletions examples/readmission_mimic3_fairness.py
@@ -0,0 +1,53 @@
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import readmission_prediction_mimic3_fn
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.metrics import fairness_metrics_fn
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer
from pyhealth.metrics.fairness_utils.utils import sensitive_attributes_from_patient_ids

# STEP 1: load data
base_dataset = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
)
base_dataset.stat()

# STEP 2: set task
sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn)
sample_dataset.stat()

train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

# STEP 3: define model
model = Transformer(
    dataset=sample_dataset,
    # look up what is available for "feature_keys" and "label_keys" in dataset.samples[0]
    feature_keys=["conditions", "procedures"],
    label_key="label",
    mode="binary",
)

# STEP 4: define trainer
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=3,
    monitor="pr_auc",
)

# STEP 5: inference, return patient_ids
y_true, y_prob, loss, patient_ids = trainer.inference(test_dataloader, return_patient_ids=True)

# STEP 6: get sensitive attribute array from patient_ids
sensitive_attribute_array = sensitive_attributes_from_patient_ids(base_dataset, patient_ids,
                                                                  'gender', 'F')

# STEP 7: use pyhealth.metrics to evaluate fairness
fairness_metrics = fairness_metrics_fn(y_true, y_prob, sensitive_attribute_array,
                                       favorable_outcome=0)
print(fairness_metrics)
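A hedged variation of STEP 7 (same arrays as above): the metric list and decision threshold can also be passed explicitly, matching the fairness_metrics_fn signature added in this commit:

# sketch: explicit metric list and threshold (these are the defaults)
fairness_metrics = fairness_metrics_fn(
    y_true, y_prob, sensitive_attribute_array,
    favorable_outcome=0,
    metrics=["disparate_impact", "statistical_parity_difference"],
    threshold=0.5,
)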
1 change: 1 addition & 0 deletions pyhealth/metrics/__init__.py
@@ -2,3 +2,4 @@
from .drug_recommendation import ddi_rate_score
from .multiclass import multiclass_metrics_fn
from .multilabel import multilabel_metrics_fn
from .fairness import fairness_metrics_fn
52 changes: 52 additions & 0 deletions pyhealth/metrics/fairness.py
@@ -0,0 +1,52 @@
from typing import Dict, List, Optional

import numpy as np

from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference

def fairness_metrics_fn(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    sensitive_attributes: np.ndarray,
    favorable_outcome: int = 1,
    metrics: Optional[List[str]] = None,
    threshold: float = 0.5,
) -> Dict[str, float]:
    """Computes fairness metrics for binary classification.
    User can specify which metrics to compute by passing a list of metric names.
    The accepted metric names are:
        - disparate_impact: ratio of favorable-outcome rates, protected group over unprotected group.
        - statistical_parity_difference: difference of favorable-outcome rates, protected group minus unprotected group.
    If no metrics are specified, disparate_impact and statistical_parity_difference are computed by default.
    Args:
        y_true: True target values of shape (n_samples,).
        y_prob: Predicted probabilities of shape (n_samples,).
        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
        favorable_outcome: Label value which is considered favorable (i.e. "positive").
        metrics: List of metrics to compute. Default is ["disparate_impact", "statistical_parity_difference"].
        threshold: Threshold for binary classification. Default is 0.5.
    Returns:
        Dictionary of metrics whose keys are the metric names and values are
        the metric values.
    """
    if metrics is None:
        metrics = ["disparate_impact", "statistical_parity_difference"]

    # binarize predicted probabilities at the decision threshold
    y_pred = y_prob.copy()
    y_pred[y_pred >= threshold] = 1
    y_pred[y_pred < threshold] = 0

    output = {}
    for metric in metrics:
        if metric == "disparate_impact":
            output[metric] = disparate_impact(sensitive_attributes, y_pred, favorable_outcome)
        elif metric == "statistical_parity_difference":
            output[metric] = statistical_parity_difference(sensitive_attributes, y_pred, favorable_outcome)
        else:
            raise ValueError(f"Unknown metric for fairness: {metric}")
    return output
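
As a quick sanity check (an editor's sketch, not part of the commit), the function can be exercised on toy arrays; the expected values follow directly from the definitions above:

import numpy as np
from pyhealth.metrics import fairness_metrics_fn

y_true = np.array([1, 0, 1, 0, 1, 0])  # unused by these two group metrics
y_prob = np.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.8])
sens = np.array([1, 1, 1, 0, 0, 0])    # 1 = protected group
# threshold 0.5 -> y_pred = [1, 0, 1, 0, 1, 1]
# P(fav | protected) = 2/3, P(fav | unprotected) = 2/3
print(fairness_metrics_fn(y_true, y_prob, sens))
# {'disparate_impact': 1.0, 'statistical_parity_difference': 0.0}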

2 changes: 2 additions & 0 deletions pyhealth/metrics/fairness_utils/__init__.py
@@ -0,0 +1,2 @@
from .group import disparate_impact, statistical_parity_difference
from .utils import sensitive_attributes_from_patient_ids
58 changes: 58 additions & 0 deletions pyhealth/metrics/fairness_utils/group.py
@@ -0,0 +1,58 @@
import numpy as np

"""
Notation:
- Protected group: P
- Unprotected group: U
"""

def disparate_impact(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1, allow_zero_division=False, epsilon: float = 1e-8) -> float:
    """
    Computes the disparate impact between the protected and unprotected groups.
        disparate_impact = P(y_pred = favorable_outcome | P) / P(y_pred = favorable_outcome | U)
    Args:
        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
        y_pred: Predicted target values of shape (n_samples,).
        favorable_outcome: Label value which is considered favorable (i.e. "positive").
        allow_zero_division: If True, use epsilon instead of 0 in the denominator if the denominator is 0. Otherwise, raise a ValueError.
        epsilon: Small constant substituted for a zero denominator when allow_zero_division is True.
    Returns:
        The disparate impact between the protected and unprotected groups.
    """

    # favorable-outcome rates of the unprotected (0) and protected (1) groups
    p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
    p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])

    if p_fav_unpr == 0:
        if allow_zero_division:
            p_fav_unpr = epsilon
        else:
            raise ValueError("Unprotected group has no instances with a favorable outcome. Disparate impact is undefined.")

    disparate_impact_value = p_fav_prot / p_fav_unpr

    return disparate_impact_value

def statistical_parity_difference(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1) -> float:
    """
    Computes the statistical parity difference between the protected and unprotected groups.
        statistical_parity_difference = P(y_pred = favorable_outcome | P) - P(y_pred = favorable_outcome | U)
    Args:
        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
        y_pred: Predicted target values of shape (n_samples,).
        favorable_outcome: Label value which is considered favorable (i.e. "positive").
    Returns:
        The statistical parity difference between the protected and unprotected groups.
    """

    p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
    p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])

    statistical_parity_difference_value = p_fav_prot - p_fav_unpr

    return statistical_parity_difference_value
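
A quick worked check of the two definitions (a sketch, not repository text), assuming numpy is available:

import numpy as np
from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference

y_pred = np.array([1, 1, 0, 1, 0, 0])
sens = np.array([1, 1, 1, 0, 0, 0])  # 1 = protected, 0 = unprotected
# P(fav | protected) = 2/3, P(fav | unprotected) = 1/3
print(disparate_impact(sens, y_pred))               # 2.0
print(statistical_parity_difference(sens, y_pred))  # ~0.333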

31 changes: 31 additions & 0 deletions pyhealth/metrics/fairness_utils/utils.py
@@ -0,0 +1,31 @@

from typing import List
import numpy as np

from pyhealth.datasets import BaseEHRDataset

def sensitive_attributes_from_patient_ids(dataset: BaseEHRDataset,
                                          patient_ids: List[str],
                                          sensitive_attribute: str,
                                          protected_group: str) -> np.ndarray:
    """
    Returns the desired sensitive attribute array from patient_ids.
    Args:
        dataset: Dataset object.
        patient_ids: List of patient IDs.
        sensitive_attribute: Sensitive attribute to extract.
        protected_group: Value of the protected group.
    Returns:
        Sensitive attribute array of shape (n_samples,).
    """

    sensitive_attribute_array = np.zeros(len(patient_ids))
    for idx, patient_id in enumerate(patient_ids):
        sensitive_attribute_value = getattr(dataset.patients[patient_id], sensitive_attribute)
        if sensitive_attribute_value == protected_group:
            sensitive_attribute_array[idx] = 1
    return sensitive_attribute_array
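
A usage sketch (hypothetical patient IDs 'p1'/'p2'; the gender/'F' pairing mirrors the example script):

# assuming dataset.patients['p1'].gender == 'F' and dataset.patients['p2'].gender == 'M':
sensitive_attributes_from_patient_ids(dataset, ['p1', 'p2'], 'gender', 'F')
# -> array([1., 0.])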


13 changes: 10 additions & 3 deletions pyhealth/trainer.py
@@ -243,7 +243,7 @@ def train(

        return

    def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]:
    def inference(self, dataloader, additional_outputs=None, return_patient_ids=False) -> Dict[str, float]:
        """Model inference.
        Args:
@@ -256,10 +256,12 @@ def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]:
            y_prob_all: List of predicted probabilities.
            loss_mean: Mean loss over batches.
            additional_outputs (only if requested): Dict of additional results.
            patient_ids (only if requested): List of patient ids in the same order as y_true_all/y_prob_all.
        """
        loss_all = []
        y_true_all = []
        y_prob_all = []
        patient_ids = []
        if additional_outputs is not None:
            additional_outputs = {k: [] for k in additional_outputs}
        for data in tqdm(dataloader, desc="Evaluation"):
@@ -275,14 +277,19 @@ def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]:
            if additional_outputs is not None:
                for key in additional_outputs.keys():
                    additional_outputs[key].append(output[key].cpu().numpy())
            if return_patient_ids:
                patient_ids.extend(data["patient_id"])
        loss_mean = sum(loss_all) / len(loss_all)
        y_true_all = np.concatenate(y_true_all, axis=0)
        y_prob_all = np.concatenate(y_prob_all, axis=0)
        outputs = [y_true_all, y_prob_all, loss_mean]
        if additional_outputs is not None:
            additional_outputs = {key: np.concatenate(val)
                                  for key, val in additional_outputs.items()}
            return y_true_all, y_prob_all, loss_mean, additional_outputs
        return y_true_all, y_prob_all, loss_mean
            outputs.append(additional_outputs)
        if return_patient_ids:
            outputs.append(patient_ids)
        return outputs

    def evaluate(self, dataloader) -> Dict[str, float]:
        """Evaluates the model.
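For readers of the new signature, the return value is positional (a summary of the code above, not repository text):

# y_true, y_prob, loss = trainer.inference(test_dataloader)
# y_true, y_prob, loss, patient_ids = trainer.inference(test_dataloader, return_patient_ids=True)
# if additional_outputs is also requested, its dict comes before patient_ids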
