From 97b120a4f08eb726f6ba33cc663d2aa54440dfd9 Mon Sep 17 00:00:00 2001
From: ycq091044
Date: Tue, 9 May 2023 20:51:05 -0500
Subject: [PATCH 1/5] upgrade from python 3.8 to 3.11

---
 readthedocs.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/readthedocs.yml b/readthedocs.yml
index c5e05719..3ac49695 100644
--- a/readthedocs.yml
+++ b/readthedocs.yml
@@ -8,7 +8,7 @@ formats:
   - htmlzip
 
 python:
-  version: "3.8"
+  version: "3.11"
   install:
     - method: pip
       path: .

From cd728a98d1ffa615e40a870a1690c056e6b32c76 Mon Sep 17 00:00:00 2001
From: ycq091044
Date: Tue, 9 May 2023 20:54:01 -0500
Subject: [PATCH 2/5] upgrade from python 3.8 to 3.11

---
 readthedocs.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/readthedocs.yml b/readthedocs.yml
index 3ac49695..c5e05719 100644
--- a/readthedocs.yml
+++ b/readthedocs.yml
@@ -8,7 +8,7 @@ formats:
   - htmlzip
 
 python:
-  version: "3.11"
+  version: "3.8"
   install:
     - method: pip
       path: .

From 7a73ef5523a239dc59fb70b615075fa536524ce0 Mon Sep 17 00:00:00 2001
From: ycq091044
Date: Tue, 9 May 2023 21:13:21 -0500
Subject: [PATCH 3/5] add docs for molerec and mimicextract dataset

---
 .../datasets/pyhealth.datasets.MIMICExtract.rst | 15 +++++++++++++++
 1 file changed, 15 insertions(+)
 create mode 100644 docs/api/datasets/pyhealth.datasets.MIMICExtract.rst

diff --git a/docs/api/datasets/pyhealth.datasets.MIMICExtract.rst b/docs/api/datasets/pyhealth.datasets.MIMICExtract.rst
new file mode 100644
index 00000000..d38b9cfc
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.MIMICExtract.rst
@@ -0,0 +1,15 @@
+pyhealth.datasets.MIMICExtractDataset
+=====================================
+
+The open Medical Information Mart for Intensive Care (MIMIC-III) database; refer to `doc `_ for more information. We process this database into a well-structured dataset object, giving users the **best flexibility and convenience** for modeling and analysis.
+
+.. autoclass:: pyhealth.datasets.MIMICExtractDataset
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+
+
+
+
+
\ No newline at end of file

From e254710fafa0240c4a6102ad58c99170e1fed9d8 Mon Sep 17 00:00:00 2001
From: ycq091044
Date: Tue, 9 May 2023 21:21:25 -0500
Subject: [PATCH 4/5] add rst files to outline

---
 .../datasets/pyhealth.datasets.MIMICExtract.rst | 15 ---------------
 1 file changed, 15 deletions(-)
 delete mode 100644 docs/api/datasets/pyhealth.datasets.MIMICExtract.rst

diff --git a/docs/api/datasets/pyhealth.datasets.MIMICExtract.rst b/docs/api/datasets/pyhealth.datasets.MIMICExtract.rst
deleted file mode 100644
index d38b9cfc..00000000
--- a/docs/api/datasets/pyhealth.datasets.MIMICExtract.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-pyhealth.datasets.MIMICExtractDataset
-=====================================
-
-The open Medical Information Mart for Intensive Care (MIMIC-III) database; refer to `doc `_ for more information. We process this database into a well-structured dataset object, giving users the **best flexibility and convenience** for modeling and analysis.
-
-.. autoclass:: pyhealth.datasets.MIMICExtractDataset
-    :members:
-    :undoc-members:
-    :show-inheritance:
-
-
-
-
-
-
\ No newline at end of file

From e7ef27ac9fcfc9906ee593a659ea11781bdaea21 Mon Sep 17 00:00:00 2001
From: Mitchell Hermon <75547238+mhermon@users.noreply.github.com>
Date: Fri, 1 Sep 2023 11:19:56 -0500
Subject: [PATCH 5/5] Base fairness metrics and example (#190)

* upgrade from python 3.8 to 3.11

* upgrade from python 3.8 to 3.11

* restrict the urllib3 version

* add docs for molerec and mimicextract dataset

* add rst files to outline

* fix illformed dependency

* minor style update

* fix omop doc typo (#166)

* Base fairness metrics and example

---------

Co-authored-by: ycq091044
Co-authored-by: Benjamin Danek
---
 .../pyhealth.datasets.OMOPDataset.rst       |  2 +-
 examples/readmission_mimic3_fairness.py     | 53 +++++++++++++++++
 pyhealth/metrics/__init__.py                |  1 +
 pyhealth/metrics/fairness.py                | 52 +++++++++++++++++
 pyhealth/metrics/fairness_utils/__init__.py |  2 +
 pyhealth/metrics/fairness_utils/group.py    | 58 +++++++++++++++++++
 pyhealth/metrics/fairness_utils/utils.py    | 31 ++++++++++
 pyhealth/trainer.py                         | 13 ++++-
 8 files changed, 208 insertions(+), 4 deletions(-)
 create mode 100644 examples/readmission_mimic3_fairness.py
 create mode 100644 pyhealth/metrics/fairness.py
 create mode 100644 pyhealth/metrics/fairness_utils/__init__.py
 create mode 100644 pyhealth/metrics/fairness_utils/group.py
 create mode 100644 pyhealth/metrics/fairness_utils/utils.py

diff --git a/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst b/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst
index 1c140905..108f3001 100644
--- a/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst
+++ b/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst
@@ -1,7 +1,7 @@
 pyhealth.datasets.OMOPDataset
 ===================================
 
-We can process any OMOP-CDM formatted database, refer to `doc `_ for more information. We it into well-structured dataset object and give user the **best flexibility and convenience** for supporting modeling and analysis.
+We can process any OMOP-CDM formatted database; refer to `doc `_ for more information. The raw data is processed into a well-structured dataset object, giving users the **best flexibility and convenience** for modeling and analysis.
 
 .. autoclass:: pyhealth.datasets.OMOPDataset
     :members:

diff --git a/examples/readmission_mimic3_fairness.py b/examples/readmission_mimic3_fairness.py
new file mode 100644
index 00000000..d53b67c9
--- /dev/null
+++ b/examples/readmission_mimic3_fairness.py
@@ -0,0 +1,53 @@
+from pyhealth.datasets import MIMIC3Dataset
+from pyhealth.tasks import readmission_prediction_mimic3_fn
+from pyhealth.datasets import split_by_patient, get_dataloader
+from pyhealth.metrics import fairness_metrics_fn
+from pyhealth.models import Transformer
+from pyhealth.trainer import Trainer
+from pyhealth.metrics.fairness_utils.utils import sensitive_attributes_from_patient_ids
+
+# STEP 1: load data
+base_dataset = MIMIC3Dataset(
+    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
+    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
+)
+base_dataset.stat()
+
+# STEP 2: set task
+sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn)
+sample_dataset.stat()
+
+train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
+train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
+val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
+test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
+
+# STEP 3: define model
+model = Transformer(
+    dataset=sample_dataset,
+    # look up what is available for "feature_keys" and "label_key" in dataset.samples[0]
+    feature_keys=["conditions", "procedures"],
+    label_key="label",
+    mode="binary",
+)
+
+# STEP 4: define trainer
+trainer = Trainer(model=model)
+trainer.train(
+    train_dataloader=train_dataloader,
+    val_dataloader=val_dataloader,
+    epochs=3,
+    monitor="pr_auc",
+)
+
+# STEP 5: inference, return patient_ids
+y_true, y_prob, loss, patient_ids = trainer.inference(test_dataloader, return_patient_ids=True)
+
+# STEP 6: get sensitive attribute array from patient_ids
+sensitive_attribute_array = sensitive_attributes_from_patient_ids(base_dataset, patient_ids,
+                                                                  'gender', 'F')
+
+# STEP 7: use pyhealth.metrics to evaluate fairness
+fairness_metrics = fairness_metrics_fn(y_true, y_prob, sensitive_attribute_array,
+                                       favorable_outcome=0)
+print(fairness_metrics)
\ No newline at end of file

diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py
index c07c63e4..3a5c46a8 100644
--- a/pyhealth/metrics/__init__.py
+++ b/pyhealth/metrics/__init__.py
@@ -2,3 +2,4 @@ from .drug_recommendation import ddi_rate_score
 from .multiclass import multiclass_metrics_fn
 from .multilabel import multilabel_metrics_fn
+from .fairness import fairness_metrics_fn

diff --git a/pyhealth/metrics/fairness.py b/pyhealth/metrics/fairness.py
new file mode 100644
index 00000000..f244750b
--- /dev/null
+++ b/pyhealth/metrics/fairness.py
@@ -0,0 +1,52 @@
+from typing import Dict, List, Optional
+
+import numpy as np
+
+from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference
+
+def fairness_metrics_fn(
+    y_true: np.ndarray,
+    y_prob: np.ndarray,
+    sensitive_attributes: np.ndarray,
+    favorable_outcome: int = 1,
+    metrics: Optional[List[str]] = None,
+    threshold: float = 0.5,
+) -> Dict[str, float]:
+    """Computes fairness metrics for binary classification.
+
+    User can specify which metrics to compute by passing a list of metric names.
+    The accepted metric names are:
+        - disparate_impact: ratio of the favorable outcome rate of the protected group to that of the unprotected group.
+        - statistical_parity_difference: favorable outcome rate of the protected group minus that of the unprotected group.
+
+    If no metrics are specified, disparate_impact and statistical_parity_difference are computed by default.
+
+    Args:
+        y_true: True target values of shape (n_samples,).
+        y_prob: Predicted probabilities of shape (n_samples,).
+        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
+        favorable_outcome: Label value which is considered favorable (i.e. "positive").
+        metrics: List of metrics to compute. Default is ["disparate_impact", "statistical_parity_difference"].
+        threshold: Threshold for binary classification. Default is 0.5.
+
+    Returns:
+        Dictionary of metrics whose keys are the metric names and values are
+        the metric values.
+    """
+    if metrics is None:
+        metrics = ["disparate_impact", "statistical_parity_difference"]
+
+    y_pred = y_prob.copy()
+    y_pred[y_pred >= threshold] = 1
+    y_pred[y_pred < threshold] = 0
+
+    output = {}
+    for metric in metrics:
+        if metric == "disparate_impact":
+            output[metric] = disparate_impact(sensitive_attributes, y_pred, favorable_outcome)
+        elif metric == "statistical_parity_difference":
+            output[metric] = statistical_parity_difference(sensitive_attributes, y_pred, favorable_outcome)
+        else:
+            raise ValueError(f"Unknown metric for fairness: {metric}")
+    return output
+

diff --git a/pyhealth/metrics/fairness_utils/__init__.py b/pyhealth/metrics/fairness_utils/__init__.py
new file mode 100644
index 00000000..6ec6537a
--- /dev/null
+++ b/pyhealth/metrics/fairness_utils/__init__.py
@@ -0,0 +1,2 @@
+from .group import disparate_impact, statistical_parity_difference
+from .utils import sensitive_attributes_from_patient_ids
\ No newline at end of file

diff --git a/pyhealth/metrics/fairness_utils/group.py b/pyhealth/metrics/fairness_utils/group.py
new file mode 100644
index 00000000..c5543969
--- /dev/null
+++ b/pyhealth/metrics/fairness_utils/group.py
@@ -0,0 +1,58 @@
+import numpy as np
+
+"""
+Notation:
+    - Protected group: P
+    - Unprotected group: U
+"""
+
+def disparate_impact(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1, allow_zero_division: bool = False, epsilon: float = 1e-8) -> float:
+    """
+    Computes the disparate impact between the protected and unprotected group.
+
+    disparate_impact = P(y_pred = favorable_outcome | P) / P(y_pred = favorable_outcome | U)
+    Args:
+        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
+        y_pred: Predicted target values of shape (n_samples,).
+        favorable_outcome: Label value which is considered favorable (i.e. "positive").
+        allow_zero_division: If True, use epsilon instead of 0 in the denominator if the denominator is 0. Otherwise, raise a ValueError.
+    Returns:
+        The disparate impact between the protected and unprotected group.
+    """
+
+    p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
+    p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])
+
+    if p_fav_unpr == 0:
+        if allow_zero_division:
+            p_fav_unpr = epsilon
+        else:
+            raise ValueError("Unprotected group has no instances with a favorable outcome. Disparate impact is undefined.")
+
+    disparate_impact_value = p_fav_prot / p_fav_unpr
+
+    return disparate_impact_value
+
+def statistical_parity_difference(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1) -> float:
+    """
+    Computes the statistical parity difference between the protected and unprotected group.
+
+    statistical_parity_difference = P(y_pred = favorable_outcome | P) - P(y_pred = favorable_outcome | U)
+    Args:
+        sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
+        y_pred: Predicted target values of shape (n_samples,).
+        favorable_outcome: Label value which is considered favorable (i.e. "positive").
+    Returns:
+        The statistical parity difference between the protected and unprotected group.
+    """
+
+    p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
+    p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])
+
+    statistical_parity_difference_value = p_fav_prot - p_fav_unpr
+
+    return statistical_parity_difference_value
+
+
+
+
\ No newline at end of file

diff --git a/pyhealth/metrics/fairness_utils/utils.py b/pyhealth/metrics/fairness_utils/utils.py
new file mode 100644
index 00000000..b6696787
--- /dev/null
+++ b/pyhealth/metrics/fairness_utils/utils.py
@@ -0,0 +1,31 @@
+
+from typing import List
+import numpy as np
+
+from pyhealth.datasets import BaseEHRDataset
+
+def sensitive_attributes_from_patient_ids(dataset: BaseEHRDataset,
+                                          patient_ids: List[str],
+                                          sensitive_attribute: str,
+                                          protected_group: str) -> np.ndarray:
+    """
+    Returns the desired sensitive attribute array from patient_ids.
+
+    Args:
+        dataset: Dataset object.
+        patient_ids: List of patient IDs.
+        sensitive_attribute: Sensitive attribute to extract.
+        protected_group: Value of the protected group.
+
+    Returns:
+        Sensitive attribute array of shape (n_samples,).
+    """
+
+    sensitive_attribute_array = np.zeros(len(patient_ids))
+    for idx, patient_id in enumerate(patient_ids):
+        sensitive_attribute_value = getattr(dataset.patients[patient_id], sensitive_attribute)
+        if sensitive_attribute_value == protected_group:
+            sensitive_attribute_array[idx] = 1
+    return sensitive_attribute_array
+
+
\ No newline at end of file

diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py
index fd290fc0..7c564988 100644
--- a/pyhealth/trainer.py
+++ b/pyhealth/trainer.py
@@ -243,7 +243,7 @@ def train(
 
         return
 
-    def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]:
+    def inference(self, dataloader, additional_outputs=None, return_patient_ids=False) -> Dict[str, float]:
         """Model inference.
 
         Args:
@@ -256,10 +256,12 @@ def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]:
             y_prob_all: List of predicted probabilities.
             loss_mean: Mean loss over batches.
             additional_outputs (only if requested): Dict of additional results.
+            patient_ids (only if requested): List of patient ids in the same order as y_true_all/y_prob_all.
""" loss_all = [] y_true_all = [] y_prob_all = [] + patient_ids = [] if additional_outputs is not None: additional_outputs = {k: [] for k in additional_outputs} for data in tqdm(dataloader, desc="Evaluation"): @@ -275,14 +277,19 @@ def inference(self, dataloader, additional_outputs=None) -> Dict[str, float]: if additional_outputs is not None: for key in additional_outputs.keys(): additional_outputs[key].append(output[key].cpu().numpy()) + if return_patient_ids: + patient_ids.extend(data["patient_id"]) loss_mean = sum(loss_all) / len(loss_all) y_true_all = np.concatenate(y_true_all, axis=0) y_prob_all = np.concatenate(y_prob_all, axis=0) + outputs = [y_true_all, y_prob_all, loss_mean] if additional_outputs is not None: additional_outputs = {key: np.concatenate(val) for key, val in additional_outputs.items()} - return y_true_all, y_prob_all, loss_mean, additional_outputs - return y_true_all, y_prob_all, loss_mean + outputs.append(additional_outputs) + if return_patient_ids: + outputs.append(patient_ids) + return outputs def evaluate(self, dataloader) -> Dict[str, float]: """Evaluates the model.