From d4f45e32bd3944d2de14439feb9ce51610df456e Mon Sep 17 00:00:00 2001
From: Hang
Date: Tue, 29 Oct 2024 23:07:35 -0500
Subject: [PATCH 1/6] add mimic-iii dataset support

---
 pyhealth/data/__init__.py     |   1 +
 pyhealth/data/data.py         | 463 +++++++++++++++++++++++++++++++---
 pyhealth/datasets/__init__.py |   2 +-
 pyhealth/datasets/mimic3.py   | 118 ++++-----
 4 files changed, 496 insertions(+), 88 deletions(-)

diff --git a/pyhealth/data/__init__.py b/pyhealth/data/__init__.py
index e1b29354..78b201d7 100755
--- a/pyhealth/data/__init__.py
+++ b/pyhealth/data/__init__.py
@@ -1,4 +1,5 @@
 from .data import (
     Event,
     Patient,
+    Visit,
 )

diff --git a/pyhealth/data/data.py b/pyhealth/data/data.py
index 6a061a24..a20b2444 100644
--- a/pyhealth/data/data.py
+++ b/pyhealth/data/data.py
@@ -1,50 +1,455 @@
-from dataclasses import dataclass, field
+from collections import OrderedDict
 from datetime import datetime
-from typing import Optional, List, Dict
+from typing import Optional, List


-@dataclass
 class Event:
     """Contains information about a single event.

-    An event can be anything from a diagnosis to a prescription or lab test
-    that happened for a patient at a specific time.
+    An event can be anything from a diagnosis to a prescription or a lab test
+    that happened in a visit of a patient at a specific time.

     Args:
-        type: type of the event (e.g., "diagnosis", "prescription", "lab_test").
-        timestamp: timestamp of the event.
-        attr_dict: event attributes as a dictionary.
+        code: code of the event. E.g., "428.0" for congestive heart failure.
+        table: name of the table where the event is recorded. This corresponds
+            to the raw csv file name in the dataset. E.g., "DIAGNOSES_ICD".
+        vocabulary: vocabulary of the code. E.g., "ICD9CM" for ICD-9 diagnosis codes.
+        visit_id: unique identifier of the visit.
+        patient_id: unique identifier of the patient.
+        timestamp: timestamp of the event. Default is None.
+        **attr: optional attributes to add to the event as key=value pairs.
+
+    Attributes:
+        attr_dict: Dict, dictionary of event attributes. Each key is an attribute
+            name and each value is the attribute's value.
+
+    Examples:
+        >>> from pyhealth.data import Event
+        >>> event = Event(
+        ...     code="00069153041",
+        ...     table="PRESCRIPTIONS",
+        ...     vocabulary="NDC",
+        ...     visit_id="v001",
+        ...     patient_id="p001",
+        ...     dosage="250mg",
+        ... )
+        >>> event
+        Event with NDC code 00069153041 from table PRESCRIPTIONS
+        >>> event.attr_dict
+        {'dosage': '250mg'}
     """

-    type: str
-    timestamp: Optional[datetime] = None
-    attr_dict: Dict[str, any] = field(default_factory=dict)
+    def __init__(
+        self,
+        code: str = None,
+        table: str = None,
+        vocabulary: str = None,
+        visit_id: str = None,
+        patient_id: str = None,
+        timestamp: Optional[datetime] = None,
+        **attr,
+    ):
+        assert timestamp is None or isinstance(
+            timestamp, datetime
+        ), "timestamp must be a datetime object"
+        self.code = code
+        self.table = table
+        self.vocabulary = vocabulary
+        self.visit_id = visit_id
+        self.patient_id = patient_id
+        self.timestamp = timestamp
+        self.attr_dict = dict()
+        self.attr_dict.update(attr)
+
+    def __repr__(self):
+        return f"Event with {self.vocabulary} code {self.code} from table {self.table}"
+
+    def __str__(self):
+        lines = list()
+        lines.append(f"Event from patient {self.patient_id} visit {self.visit_id}:")
+        lines.append(f"\t- Code: {self.code}")
+        lines.append(f"\t- Table: {self.table}")
+        lines.append(f"\t- Vocabulary: {self.vocabulary}")
+        lines.append(f"\t- Timestamp: {self.timestamp}")
+        for k, v in self.attr_dict.items():
+            lines.append(f"\t- {k}: {v}")
+        return "\n".join(lines)
+
+
+class Visit:
+    """Contains information about a single visit.
+
+    A visit is a period of time in which a patient is admitted to a hospital or
+    a specific department. Each visit is associated with a patient and contains
+    a list of different events.
+
+    Args:
+        visit_id: unique identifier of the visit.
+        patient_id: unique identifier of the patient.
+        encounter_time: timestamp of visit's encounter. Default is None.
+        discharge_time: timestamp of visit's discharge. Default is None.
+        discharge_status: patient's status upon discharge. Default is None.
+        **attr: optional attributes to add to the visit as key=value pairs.
+
+    Attributes:
+        attr_dict: Dict, dictionary of visit attributes. Each key is an attribute
+            name and each value is the attribute's value.
+        event_list_dict: Dict[str, List[Event]], dictionary of event lists.
+            Each key is a table name and each value is a list of events from that
+            table ordered by timestamp.
+
+    Examples:
+        >>> from pyhealth.data import Event, Visit
+        >>> event = Event(
+        ...     code="00069153041",
+        ...     table="PRESCRIPTIONS",
+        ...     vocabulary="NDC",
+        ...     visit_id="v001",
+        ...     patient_id="p001",
+        ...     dosage="250mg",
+        ... )
+        >>> visit = Visit(
+        ...     visit_id="v001",
+        ...     patient_id="p001",
+        ... )
+        >>> visit.add_event(event)
+        >>> visit
+        Visit v001 from patient p001 with 1 events from tables ['PRESCRIPTIONS']
+        >>> visit.available_tables
+        ['PRESCRIPTIONS']
+        >>> visit.num_events
+        1
+        >>> visit.get_event_list('PRESCRIPTIONS')
+        [Event with NDC code 00069153041 from table PRESCRIPTIONS]
+        >>> visit.get_code_list('PRESCRIPTIONS')
+        ['00069153041']
+    """
+
+    def __init__(
+        self,
+        visit_id: str,
+        patient_id: str,
+        encounter_time: Optional[datetime] = None,
+        discharge_time: Optional[datetime] = None,
+        discharge_status=None,
+        **attr,
+    ):
+        assert encounter_time is None or isinstance(
+            encounter_time, datetime
+        ), "encounter_time must be a datetime object"
+        assert discharge_time is None or isinstance(
+            discharge_time, datetime
+        ), "discharge_time must be a datetime object"
+        self.visit_id = visit_id
+        self.patient_id = patient_id
+        self.encounter_time = encounter_time
+        self.discharge_time = discharge_time
+        self.discharge_status = discharge_status
+        self.attr_dict = dict()
+        self.attr_dict.update(attr)
+        self.event_list_dict = dict()
+
+    def add_event(self, event: Event) -> None:
+        """Adds an event to the visit.
+
+        If the event's table is not in the visit's event list dictionary, it is
+        added as a new key. The event is then added to the list of events of
+        that table.
+
+        Args:
+            event: event to add.
+
+        Note:
+            As for now, there is no check on the order of the events. The new event
+            is simply appended to the end of the list.
+        """
+        assert event.visit_id == self.visit_id, "visit_id unmatched"
+        assert event.patient_id == self.patient_id, "patient_id unmatched"
+        table = event.table
+        if table not in self.event_list_dict:
+            self.event_list_dict[table] = list()
+        self.event_list_dict[table].append(event)
+
+    def get_event_list(self, table: str) -> List[Event]:
+        """Returns a list of events from a specific table.
+
+        If the table is not in the visit's event list dictionary, an empty list
+        is returned.
+
+        Args:
+            table: name of the table.
+
+        Returns:
+            List of events from the specified table.
+
+        Note:
+            As for now, there is no check on the order of the events. The list of
+            events is simply returned as is.
+        """
+        if table in self.event_list_dict:
+            return self.event_list_dict[table]
+        else:
+            return list()
+
+    def get_code_list(
+        self, table: str, remove_duplicate: Optional[bool] = True
+    ) -> List[str]:
+        """Returns a list of codes from a specific table.
+
+        If the table is not in the visit's event list dictionary, an empty list
+        is returned.
+
+        Args:
+            table: name of the table.
+            remove_duplicate: whether to remove duplicate codes
+                (but keep the relative order). Default is True.
+
+        Returns:
+            List of codes from the specified table.
+
+        Note:
+            As for now, there is no check on the order of the codes. The list of
+            codes is simply returned as is.
+        """
+        event_list = self.get_event_list(table)
+        code_list = [event.code for event in event_list]
+        if remove_duplicate:
+            # remove duplicate codes but keep the order
+            code_list = list(dict.fromkeys(code_list))
+        return code_list
+
+    def set_event_list(self, table: str, event_list: List[Event]) -> None:
+        """Sets the list of events from a specific table.
+
+        This function will overwrite any existing list of events from
+        the specified table.
+
+        Args:
+            table: name of the table.
+            event_list: list of events to set.
+
+        Note:
+            As for now, there is no check on the order of the events. The list of
+            events is simply set as is.
+        """
+        self.event_list_dict[table] = event_list
+
+    @property
+    def available_tables(self) -> List[str]:
+        """Returns a list of available tables for the visit.
+
+        Returns:
+            List of available tables.
+        """
+        return list(self.event_list_dict.keys())
+
+    @property
+    def num_events(self) -> int:
+        """Returns the total number of events in the visit.
+
+        Returns:
+            Total number of events.
+        """
+        return sum([len(event_list) for event_list in self.event_list_dict.values()])
+
+    def __repr__(self):
+        return (
+            f"Visit {self.visit_id} "
+            f"from patient {self.patient_id} "
+            f"with {self.num_events} events "
+            f"from tables {self.available_tables}"
+        )
+
+    def __str__(self):
+        lines = list()
+        lines.append(
+            f"Visit {self.visit_id} from patient {self.patient_id} "
+            f"with {self.num_events} events:"
+        )
+        lines.append(f"\t- Encounter time: {self.encounter_time}")
+        lines.append(f"\t- Discharge time: {self.discharge_time}")
+        lines.append(f"\t- Discharge status: {self.discharge_status}")
+        lines.append(f"\t- Available tables: {self.available_tables}")
+        for k, v in self.attr_dict.items():
+            lines.append(f"\t- {k}: {v}")
+        for table, event_list in self.event_list_dict.items():
+            for event in event_list:
+                event_str = str(event).replace("\n", "\n\t")
+                lines.append(f"\t- {event_str}")
+        return "\n".join(lines)


-@dataclass
 class Patient:
-    """Contains information about a single patient and their events.
+    """Contains information about a single patient.

-    A patient is a person who has a sequence of events over time, each associated
-    with specific health data.
+    A patient is a person who is admitted at least once to a hospital or
+    a specific department. Each patient is associated with a list of visits.

     Args:
-        patient_id: unique identifier for the patient.
-        attr_dict: patient attributes as a dictionary.
-        events: list of events for the patient.
+        patient_id: unique identifier of the patient.
+        birth_datetime: timestamp of patient's birth. Default is None.
+        death_datetime: timestamp of patient's death. Default is None.
+        gender: gender of the patient. Default is None.
+        ethnicity: ethnicity of the patient. Default is None.
+        **attr: optional attributes to add to the patient as key=value pairs.
+
+    Attributes:
+        attr_dict: Dict, dictionary of patient attributes. Each key is an attribute
+            name and each value is the attribute's value.
+        visits: OrderedDict[str, Visit], an ordered dictionary of visits. Each key
+            is a visit_id and each value is a visit.
+        index_to_visit_id: Dict[int, str], dictionary that maps the index of a visit
+            in the visits list to the corresponding visit_id.
+
+    Examples:
+        >>> from pyhealth.data import Event, Visit, Patient
+        >>> event = Event(
+        ...     code="00069153041",
+        ...     table="PRESCRIPTIONS",
+        ...     vocabulary="NDC",
+        ...     visit_id="v001",
+        ...     patient_id="p001",
+        ...     dosage="250mg",
+        ... )
+        >>> visit = Visit(
+        ...     visit_id="v001",
+        ...     patient_id="p001",
+        ... )
+        >>> visit.add_event(event)
+        >>> patient = Patient(
+        ...     patient_id="p001",
+        ... )
+        >>> patient.add_visit(visit)
+        >>> patient
+        Patient p001 with 1 visits
+        >>> patient.available_tables
+        ['PRESCRIPTIONS']
+        >>> patient.get_visit_by_index(0)
+        Visit v001 from patient p001 with 1 events from tables ['PRESCRIPTIONS']
+        >>> patient.get_visit_by_index(0).get_code_list(table="PRESCRIPTIONS")
+        ['00069153041']
     """

-    patient_id: str
-    attr_dict: Dict[str, any] = field(default_factory=dict)
-    events: List[Event] = field(default_factory=list)
+
+    def __init__(
+        self,
+        patient_id: str,
+        birth_datetime: Optional[datetime] = None,
+        death_datetime: Optional[datetime] = None,
+        gender=None,
+        ethnicity=None,
+        **attr,
+    ):
+        self.patient_id = patient_id
+        self.birth_datetime = birth_datetime
+        self.death_datetime = death_datetime
+        self.gender = gender
+        self.ethnicity = ethnicity
+        self.attr_dict = dict()
+        self.attr_dict.update(attr)
+        self.visits = OrderedDict()
+        self.index_to_visit_id = dict()
+
+    def add_visit(self, visit: Visit) -> None:
+        """Adds a visit to the patient.
+
+        If the visit's visit_id is already in the patient's visits dictionary,
+        it will be overwritten by the new visit.
+
+        Args:
+            visit: visit to add.
+
+        Note:
+            As for now, there is no check on the order of the visits. The new visit
+            is simply added to the end of the ordered dictionary of visits.
+        """
+        assert visit.patient_id == self.patient_id, "patient_id unmatched"
+        self.visits[visit.visit_id] = visit
+        # incrementing index
+        self.index_to_visit_id[len(self.visits) - 1] = visit.visit_id

     def add_event(self, event: Event) -> None:
-        """Adds an event to the patient's event sequence, maintaining order by event_time.
+        """Adds an event to the patient.

-        Events without a timestamp are placed at the end of the list.
+        If the event's visit_id is not in the patient's visits dictionary, this
+        function will raise KeyError.
+
+        Args:
+            event: event to add.
+
+        Note:
+            As for now, there is no check on the order of the events. The new event
+            is simply appended to the end of the list of events of the
+            corresponding visit.
         """
-        self.events.append(event)
-        # Sort events, placing those with None timestamps at the end
-        self.events.sort(key=lambda e: (e.timestamp is None, e.timestamp))
+        assert event.patient_id == self.patient_id, "patient_id unmatched"
+        visit_id = event.visit_id
+        if visit_id not in self.visits:
+            raise KeyError(
+                f"Visit with id {visit_id} not found in patient {self.patient_id}"
+            )
+        self.get_visit_by_id(visit_id).add_event(event)
+
+    def get_visit_by_id(self, visit_id: str) -> Visit:
+        """Returns a visit by visit_id.
+
+        Args:
+            visit_id: unique identifier of the visit.
+
+        Returns:
+            Visit with the given visit_id.
+        """
+        return self.visits[visit_id]
+
+    def get_visit_by_index(self, index: int) -> Visit:
+        """Returns a visit by its index.
+
+        Args:
+            index: int, index of the visit to return.
+
+        Returns:
+            Visit with the given index.
+        """
+        if index not in self.index_to_visit_id:
+            raise IndexError(
+                f"Visit with index {index} not found in patient {self.patient_id}"
+            )
+        visit_id = self.index_to_visit_id[index]
+        return self.get_visit_by_id(visit_id)
+
+    @property
+    def available_tables(self) -> List[str]:
+        """Returns a list of available tables for the patient.
+
+        Returns:
+            List of available tables.
+ """ + tables = [] + for visit in self: + tables.extend(visit.available_tables) + return list(set(tables)) + + def __len__(self): + """Returns the number of visits in the patient.""" + return len(self.visits) + + def __getitem__(self, index) -> Visit: + """Returns a visit by its index.""" + return self.get_visit_by_index(index) + + def __repr__(self): + return f"Patient {self.patient_id} with {len(self)} visits" + + def __str__(self): + lines = list() + # patient info + lines.append(f"Patient {self.patient_id} with {len(self)} visits:") + lines.append(f"\t- Birth datetime: {self.birth_datetime}") + lines.append(f"\t- Death datetime: {self.death_datetime}") + lines.append(f"\t- Gender: {self.gender}") + lines.append(f"\t- Ethnicity: {self.ethnicity}") + for k, v in self.attr_dict.items(): + lines.append(f"\t- {k}: {v}") + # visit info + for visit in self: + visit_str = str(visit).replace("\n", "\n\t") + lines.append(f"\t- {visit_str}") + return "\n".join(lines) - def get_events_by_type(self, event_type: str) -> List[Event]: - """Retrieve events of a specific type.""" - return [event for event in self.events if event.type == event_type] diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index d07b1ed0..30c38899 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -2,7 +2,7 @@ from .base_signal_dataset import BaseSignalDataset # from .cardiology import CardiologyDataset # from .eicu import eICUDataset -# from .mimic3 import MIMIC3Dataset +from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4Dataset # from .mimicextract import MIMICExtractDataset # from .omop import OMOPDataset diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 57912d3c..599a18e6 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -17,13 +17,13 @@ class MIMIC3Dataset(BaseEHRDataset): patients. The dataset is available at https://mimic.physionet.org/. The basic information is stored in the following tables: - - PATIENTS: defines a patient in the database, SUBJECT_ID. - - ADMISSIONS: defines a patient's hospital admission, HADM_ID. + - PATIENTS: defines a patient in the database, subject_id. + - ADMISSIONS: defines a patient's hospital admission, hadm_id. We further support the following tables: - DIAGNOSES_ICD: contains ICD-9 diagnoses (ICD9CM code) for patients. - PROCEDURES_ICD: contains ICD-9 procedures (ICD9PROC code) for patients. - - PRESCRIPTIONS: contains medication related order entries (NDC code) + - PRESCRIPTIONS: contains medication related order entries (ndc code) for patients. - LABEVENTS: contains laboratory measurements (MIMIC3_ITEMID code) for patients @@ -62,7 +62,7 @@ class MIMIC3Dataset(BaseEHRDataset): >>> dataset = MIMIC3Dataset( ... root="/srv/local/data/physionet.org/files/mimiciii/1.4", ... tables=["DIAGNOSES_ICD", "PRESCRIPTIONS"], - ... code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, + ... code_mapping={"ndc": ("ATC", {"target_kwargs": {"level": 3}})}, ... 
) >>> dataset.stat() >>> dataset.info() @@ -86,43 +86,45 @@ def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: # read patients table patients_df = pd.read_csv( os.path.join(self.root, "PATIENTS.csv"), - dtype={"SUBJECT_ID": str}, + dtype={"subject_id": str}, nrows=1000 if self.dev else None, ) # read admissions table admissions_df = pd.read_csv( os.path.join(self.root, "ADMISSIONS.csv"), - dtype={"SUBJECT_ID": str, "HADM_ID": str}, + dtype={"subject_id": str, "hadm_id": str}, ) # merge patient and admission tables - df = pd.merge(patients_df, admissions_df, on="SUBJECT_ID", how="inner") + # import pdb + # pdb.set_trace() + df = pd.merge(patients_df, admissions_df, on="subject_id", how="inner") # sort by admission and discharge time - df = df.sort_values(["SUBJECT_ID", "ADMITTIME", "DISCHTIME"], ascending=True) + df = df.sort_values(["subject_id", "admittime", "dischtime"], ascending=True) # group by patient - df_group = df.groupby("SUBJECT_ID") + df_group = df.groupby("subject_id") # parallel unit of basic information (per patient) def basic_unit(p_id, p_info): patient = Patient( patient_id=p_id, - birth_datetime=strptime(p_info["DOB"].values[0]), - death_datetime=strptime(p_info["DOD_HOSP"].values[0]), - gender=p_info["GENDER"].values[0], - ethnicity=p_info["ETHNICITY"].values[0], + birth_datetime=strptime(p_info["dob"].values[0]), + death_datetime=strptime(p_info["dod_hosp"].values[0]), + gender=p_info["gender"].values[0], + ethnicity=p_info["ethnicity"].values[0], ) # load visits - for v_id, v_info in p_info.groupby("HADM_ID"): + for v_id, v_info in p_info.groupby("hadm_id"): visit = Visit( visit_id=v_id, patient_id=p_id, - encounter_time=strptime(v_info["ADMITTIME"].values[0]), - discharge_time=strptime(v_info["DISCHTIME"].values[0]), - discharge_status=v_info["HOSPITAL_EXPIRE_FLAG"].values[0], - insurance=v_info["INSURANCE"].values[0], - language=v_info["LANGUAGE"].values[0], - religion=v_info["RELIGION"].values[0], - marital_status=v_info["MARITAL_STATUS"].values[0], - ethnicity=v_info["ETHNICITY"].values[0], + encounter_time=strptime(v_info["admittime"].values[0]), + discharge_time=strptime(v_info["dischtime"].values[0]), + discharge_status=v_info["hospital_expire_flag"].values[0], + insurance=v_info["insurance"].values[0], + language=v_info["language"].values[0], + religion=v_info["religion"].values[0], + marital_status=v_info["marital_status"].values[0], + ethnicity=v_info["ethnicity"].values[0], ) # add visit patient.add_visit(visit) @@ -130,7 +132,7 @@ def basic_unit(p_id, p_info): # parallel apply df_group = df_group.parallel_apply( - lambda x: basic_unit(x.SUBJECT_ID.unique()[0], x) + lambda x: basic_unit(x.subject_id.unique()[0], x) ) # summarize the results for pat_id, pat in df_group.items(): @@ -161,22 +163,22 @@ def parse_diagnoses_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patient # read table df = pd.read_csv( os.path.join(self.root, f"{table}.csv"), - dtype={"SUBJECT_ID": str, "HADM_ID": str, "ICD9_CODE": str}, + dtype={"subject_id": str, "hadm_id": str, "icd9_code": str}, ) # drop records of the other patients - df = df[df["SUBJECT_ID"].isin(patients.keys())] + df = df[df["subject_id"].isin(patients.keys())] # drop rows with missing values - df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "ICD9_CODE"]) + df = df.dropna(subset=["subject_id", "hadm_id", "icd9_code"]) # sort by sequence number (i.e., priority) - df = df.sort_values(["SUBJECT_ID", "HADM_ID", "SEQ_NUM"], ascending=True) + df = df.sort_values(["subject_id", "hadm_id", 
"seq_num"], ascending=True) # group by patient and visit - group_df = df.groupby("SUBJECT_ID") + group_df = df.groupby("subject_id") # parallel unit of diagnosis (per patient) def diagnosis_unit(p_id, p_info): events = [] - for v_id, v_info in p_info.groupby("HADM_ID"): - for code in v_info["ICD9_CODE"]: + for v_id, v_info in p_info.groupby("hadm_id"): + for code in v_info["icd9_code"]: event = Event( code=code, table=table, @@ -189,7 +191,7 @@ def diagnosis_unit(p_id, p_info): # parallel apply group_df = group_df.parallel_apply( - lambda x: diagnosis_unit(x.SUBJECT_ID.unique()[0], x) + lambda x: diagnosis_unit(x.subject_id.unique()[0], x) ) # summarize the results @@ -219,22 +221,22 @@ def parse_procedures_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patien # read table df = pd.read_csv( os.path.join(self.root, f"{table}.csv"), - dtype={"SUBJECT_ID": str, "HADM_ID": str, "ICD9_CODE": str}, + dtype={"subject_id": str, "hadm_id": str, "icd9_code": str}, ) # drop records of the other patients - df = df[df["SUBJECT_ID"].isin(patients.keys())] + df = df[df["subject_id"].isin(patients.keys())] # drop rows with missing values - df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "SEQ_NUM", "ICD9_CODE"]) + df = df.dropna(subset=["subject_id", "hadm_id", "seq_num", "icd9_code"]) # sort by sequence number (i.e., priority) - df = df.sort_values(["SUBJECT_ID", "HADM_ID", "SEQ_NUM"], ascending=True) + df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True) # group by patient and visit - group_df = df.groupby("SUBJECT_ID") + group_df = df.groupby("subject_id") # parallel unit of procedure (per patient) def procedure_unit(p_id, p_info): events = [] - for v_id, v_info in p_info.groupby("HADM_ID"): - for code in v_info["ICD9_CODE"]: + for v_id, v_info in p_info.groupby("hadm_id"): + for code in v_info["icd9_code"]: event = Event( code=code, table=table, @@ -247,7 +249,7 @@ def procedure_unit(p_id, p_info): # parallel apply group_df = group_df.parallel_apply( - lambda x: procedure_unit(x.SUBJECT_ID.unique()[0], x) + lambda x: procedure_unit(x.subject_id.unique()[0], x) ) # summarize the results @@ -269,33 +271,33 @@ def parse_prescriptions(self, patients: Dict[str, Patient]) -> Dict[str, Patient The updated patients dict. 
""" table = "PRESCRIPTIONS" - self.code_vocs["drugs"] = "NDC" + self.code_vocs["drugs"] = "ndc" # read table df = pd.read_csv( os.path.join(self.root, f"{table}.csv"), low_memory=False, - dtype={"SUBJECT_ID": str, "HADM_ID": str, "NDC": str}, + dtype={"subject_id": str, "hadm_id": str, "ndc": str}, ) # drop records of the other patients - df = df[df["SUBJECT_ID"].isin(patients.keys())] + df = df[df["subject_id"].isin(patients.keys())] # drop rows with missing values - df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "NDC"]) + df = df.dropna(subset=["subject_id", "hadm_id", "ndc"]) # sort by start date and end date df = df.sort_values( - ["SUBJECT_ID", "HADM_ID", "STARTDATE", "ENDDATE"], ascending=True + ["subject_id", "hadm_id", "startdate", "enddate"], ascending=True ) # group by patient and visit - group_df = df.groupby("SUBJECT_ID") + group_df = df.groupby("subject_id") # parallel unit for prescription (per patient) def prescription_unit(p_id, p_info): events = [] - for v_id, v_info in p_info.groupby("HADM_ID"): - for timestamp, code in zip(v_info["STARTDATE"], v_info["NDC"]): + for v_id, v_info in p_info.groupby("hadm_id"): + for timestamp, code in zip(v_info["startdate"], v_info["ndc"]): event = Event( code=code, table=table, - vocabulary="NDC", + vocabulary="ndc", visit_id=v_id, patient_id=p_id, timestamp=strptime(timestamp), @@ -305,7 +307,7 @@ def prescription_unit(p_id, p_info): # parallel apply group_df = group_df.parallel_apply( - lambda x: prescription_unit(x.SUBJECT_ID.unique()[0], x) + lambda x: prescription_unit(x.subject_id.unique()[0], x) ) # summarize the results @@ -331,21 +333,21 @@ def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: # read table df = pd.read_csv( os.path.join(self.root, f"{table}.csv"), - dtype={"SUBJECT_ID": str, "HADM_ID": str, "ITEMID": str}, + dtype={"subject_id": str, "hadm_id": str, "ITEMID": str}, ) # drop records of the other patients - df = df[df["SUBJECT_ID"].isin(patients.keys())] + df = df[df["subject_id"].isin(patients.keys())] # drop rows with missing values - df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "ITEMID"]) + df = df.dropna(subset=["subject_id", "hadm_id", "ITEMID"]) # sort by charttime - df = df.sort_values(["SUBJECT_ID", "HADM_ID", "CHARTTIME"], ascending=True) + df = df.sort_values(["subject_id", "hadm_id", "CHARTTIME"], ascending=True) # group by patient and visit - group_df = df.groupby("SUBJECT_ID") + group_df = df.groupby("subject_id") # parallel unit for lab (per patient) def lab_unit(p_id, p_info): events = [] - for v_id, v_info in p_info.groupby("HADM_ID"): + for v_id, v_info in p_info.groupby("hadm_id"): for timestamp, code in zip(v_info["CHARTTIME"], v_info["ITEMID"]): event = Event( code=code, @@ -360,7 +362,7 @@ def lab_unit(p_id, p_info): # parallel apply group_df = group_df.parallel_apply( - lambda x: lab_unit(x.SUBJECT_ID.unique()[0], x) + lambda x: lab_unit(x.subject_id.unique()[0], x) ) # summarize the results @@ -377,7 +379,7 @@ def lab_unit(p_id, p_info): "PRESCRIPTIONS", "LABEVENTS", ], - code_mapping={"NDC": "ATC"}, + code_mapping={"ndc": "ATC"}, dev=True, refresh_cache=True, ) @@ -388,7 +390,7 @@ def lab_unit(p_id, p_info): # root="/srv/local/data/physionet.org/files/mimiciii/1.4", # tables=["DIAGNOSES_ICD", "PRESCRIPTIONS"], # dev=True, - # code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, + # code_mapping={"ndc": ("ATC", {"target_kwargs": {"level": 3}})}, # refresh_cache=False, # ) # print(dataset.stat()) From d213ef8577801cd478767ef93ee7a37ea5104756 Mon Sep 17 
00:00:00 2001
From: Hang Yu
Date: Thu, 31 Oct 2024 08:26:17 +0000
Subject: [PATCH 2/6] added 48_ihm task support, next step to work on
 discretization and imputation

---
 pyhealth/data/__init__.py                 |   2 +-
 pyhealth/datasets/mimic3.py               |  18 +-
 pyhealth/datasets/mimic4.py               | 293 +++++++--------------
 pyhealth/tasks/__init__.py                |  17 +-
 pyhealth/tasks/mortality_prediction.py    |  62 ++++-
 pyhealth/tasks/mortality_prediction_v2.py |  49 +++-
 6 files changed, 208 insertions(+), 233 deletions(-)

diff --git a/pyhealth/data/__init__.py b/pyhealth/data/__init__.py
index 78b201d7..0447e5db 100755
--- a/pyhealth/data/__init__.py
+++ b/pyhealth/data/__init__.py
@@ -2,4 +2,4 @@
     Event,
     Patient,
     Visit,
-)
+)
\ No newline at end of file

diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py
index 599a18e6..b93e25fa 100644
--- a/pyhealth/datasets/mimic3.py
+++ b/pyhealth/datasets/mimic3.py
@@ -1,3 +1,4 @@
+# bug fix note: the column names for the mimic3 dataset should all be in lower case instead of upper case; fixed in this version
 import os
 from typing import Optional, List, Dict, Tuple, Union

@@ -25,7 +26,7 @@ class MIMIC3Dataset(BaseEHRDataset):
     - PROCEDURES_ICD: contains ICD-9 procedures (ICD9PROC code) for patients.
     - PRESCRIPTIONS: contains medication related order entries (ndc code)
         for patients.
-    - LABEVENTS: contains laboratory measurements (MIMIC3_ITEMID code)
+    - LABEVENTS: contains laboratory measurements (MIMIC3_itemid code)
        for patients

     Args:
@@ -95,8 +96,6 @@ def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
             dtype={"subject_id": str, "hadm_id": str},
         )
         # merge patient and admission tables
-        # import pdb
-        # pdb.set_trace()
         df = pd.merge(patients_df, admissions_df, on="subject_id", how="inner")
         # sort by admission and discharge time
         df = df.sort_values(["subject_id", "admittime", "dischtime"], ascending=True)
@@ -329,18 +328,19 @@ def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
             The updated patients dict.
""" table = "LABEVENTS" - self.code_vocs["labs"] = "MIMIC3_ITEMID" + self.code_vocs["labs"] = "MIMIC3_itemid" # read table df = pd.read_csv( os.path.join(self.root, f"{table}.csv"), - dtype={"subject_id": str, "hadm_id": str, "ITEMID": str}, + dtype={"subject_id": str, "hadm_id": str, "itemid": str}, ) # drop records of the other patients df = df[df["subject_id"].isin(patients.keys())] # drop rows with missing values - df = df.dropna(subset=["subject_id", "hadm_id", "ITEMID"]) + # df = df.dropna(subset=["subject_id", "hadm_id"]) + df = df.dropna(subset=["subject_id", "hadm_id", "itemid"]) # sort by charttime - df = df.sort_values(["subject_id", "hadm_id", "CHARTTIME"], ascending=True) + df = df.sort_values(["subject_id", "hadm_id", "charttime"], ascending=True) # group by patient and visit group_df = df.groupby("subject_id") @@ -348,11 +348,11 @@ def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: def lab_unit(p_id, p_info): events = [] for v_id, v_info in p_info.groupby("hadm_id"): - for timestamp, code in zip(v_info["CHARTTIME"], v_info["ITEMID"]): + for timestamp, code in zip(v_info["charttime"], v_info["itemid"]): event = Event( code=code, table=table, - vocabulary="MIMIC3_ITEMID", + vocabulary="MIMIC3_itemid", visit_id=v_id, patient_id=p_id, timestamp=strptime(timestamp), diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index f48b7c02..927de210 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -1,20 +1,18 @@ +# The original code in 202410-sunlab-hackthon branch has some issue with dataset parsing. +# Using the code from main branch is a quick fix import os -import time -from datetime import timedelta -from typing import List, Dict +from typing import Optional, List, Dict, Union, Tuple import pandas as pd -from pandarallel import pandarallel -from pyhealth.data import Event, Patient -from pyhealth.datasets.base_dataset_v2 import BaseDataset +from pyhealth.data import Event, Visit, Patient +from pyhealth.datasets import BaseEHRDataset from pyhealth.datasets.utils import strptime - # TODO: add other tables -class MIMIC4Dataset(BaseDataset): +class MIMIC4Dataset(BaseEHRDataset): """Base dataset for MIMIC-IV dataset. The MIMIC-IV dataset is a large dataset of de-identified health records of ICU @@ -74,53 +72,6 @@ class MIMIC4Dataset(BaseDataset): >>> dataset.info() """ - def __init__( - self, - root: str, - dev=False, - tables: List[str] = None, - ): - self.dev = dev - self.tables = tables - super().__init__(root) - - def process(self) -> Dict[str, Patient]: - """Parses the tables in `self.tables` and return a dict of patients. - - Will be called in `self.__init__()` if cache file does not exist or - refresh_cache is True. - - This function will first call `self.parse_basic_info()` to parse the - basic patient information, and then call `self.parse_[table_name]()` to - parse the table with name `table_name`. Both `self.parse_basic_info()` and - `self.parse_[table_name]()` should be implemented in the subclass. - - Returns: - A dict mapping patient_id to `Patient` object. 
- """ - pandarallel.initialize(progress_bar=False) - - # patients is a dict of Patient objects indexed by patient_id - patients: Dict[str, Patient] = dict() - # process basic information (e.g., patients and visits) - tic = time.time() - patients = self.parse_basic_info(patients) - print( - "finish basic patient information parsing : {}s".format(time.time() - tic) - ) - # process clinical tables - for table in self.tables: - try: - # use lower case for function name - tic = time.time() - patients = getattr(self, f"parse_{table.lower()}")(patients) - print(f"finish parsing {table} : {time.time() - tic}s") - except AttributeError: - raise NotImplementedError( - f"Parser for table {table} is not implemented yet." - ) - return patients - def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: """Helper functions which parses patients and admissions tables. @@ -161,33 +112,27 @@ def basic_unit(p_id, p_info): anchor_year = int(p_info["anchor_year"].values[0]) anchor_age = int(p_info["anchor_age"].values[0]) birth_year = anchor_year - anchor_age - attr_dict = { - # no exact month, day, and time, use Jan 1st, 00:00:00 - "birth_datetime": strptime(str(birth_year)), - # no exact time, use 00:00:00 - "death_datetime": strptime(p_info["dod"].values[0]), - "gender": p_info["gender"].values[0], - "ethnicity": p_info["race"].values[0], - "anchor_year_group": p_info["anchor_year_group"].values[0], - } patient = Patient( patient_id=p_id, - attr_dict=attr_dict, + # no exact month, day, and time, use Jan 1st, 00:00:00 + birth_datetime=strptime(str(birth_year)), + # no exact time, use 00:00:00 + death_datetime=strptime(p_info["dod"].values[0]), + gender=p_info["gender"].values[0], + ethnicity=p_info["race"].values[0], + anchor_year_group=p_info["anchor_year_group"].values[0], ) - # load admissions + # load visits for v_id, v_info in p_info.groupby("hadm_id"): - attr_dict = { - "visit_id": v_id, - "discharge_time": strptime(v_info["dischtime"].values[0]), - "discharge_status": v_info["hospital_expire_flag"].values[0], - } - event = Event( - type="admissions", - timestamp=strptime(v_info["admittime"].values[0]), - attr_dict=attr_dict, + visit = Visit( + visit_id=v_id, + patient_id=p_id, + encounter_time=strptime(v_info["admittime"].values[0]), + discharge_time=strptime(v_info["dischtime"].values[0]), + discharge_status=v_info["hospital_expire_flag"].values[0], ) # add visit - patient.add_event(event) + patient.add_visit(visit) return patient # parallel apply @@ -228,17 +173,6 @@ def parse_diagnoses_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patient df = df.dropna(subset=["subject_id", "hadm_id", "icd_code", "icd_version"]) # sort by sequence number (i.e., priority) df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True) - # load admissions table - admissions_df = pd.read_csv( - os.path.join(self.root, "admissions.csv"), - dtype={"subject_id": str, "hadm_id": str}, - ) - # merge patients and admissions tables - df = df.merge( - admissions_df[["subject_id", "hadm_id", "dischtime"]], - on=["subject_id", "hadm_id"], - how="inner" - ) # group by patient and visit group_df = df.groupby("subject_id") @@ -248,17 +182,12 @@ def diagnosis_unit(p_id, p_info): # iterate over each patient and visit for v_id, v_info in p_info.groupby("hadm_id"): for code, version in zip(v_info["icd_code"], v_info["icd_version"]): - attr_dict = { - "code": code, - "vocabulary": f"ICD{version}CM", - "visit_id": v_id, - "patient_id": p_id, - } event = Event( - type=table, - 
timestamp=strptime(v_info["dischtime"].values[0]) - timedelta( - seconds=1), - attr_dict=attr_dict, + code=code, + table=table, + vocabulary=f"ICD{version}CM", + visit_id=v_id, + patient_id=p_id, ) events.append(event) return events @@ -300,17 +229,6 @@ def parse_procedures_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patien df = df.dropna(subset=["subject_id", "hadm_id", "icd_code", "icd_version"]) # sort by sequence number (i.e., priority) df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True) - # load admissions table - admissions_df = pd.read_csv( - os.path.join(self.root, "admissions.csv"), - dtype={"subject_id": str, "hadm_id": str}, - ) - # merge patients and admissions tables - df = df.merge( - admissions_df[["subject_id", "hadm_id", "dischtime"]], - on=["subject_id", "hadm_id"], - how="inner" - ) # group by patient and visit group_df = df.groupby("subject_id") @@ -319,17 +237,12 @@ def procedure_unit(p_id, p_info): events = [] for v_id, v_info in p_info.groupby("hadm_id"): for code, version in zip(v_info["icd_code"], v_info["icd_version"]): - attr_dict = { - "code": code, - "vocabulary": f"ICD{version}PROC", - "visit_id": v_id, - "patient_id": p_id, - } event = Event( - type=table, - timestamp=strptime(v_info["dischtime"].values[0]) - timedelta( - seconds=1), - attr_dict=attr_dict, + code=code, + table=table, + vocabulary=f"ICD{version}PROC", + visit_id=v_id, + patient_id=p_id, ) # update patients events.append(event) @@ -380,16 +293,13 @@ def prescription_unit(p_id, p_info): events = [] for v_id, v_info in p_info.groupby("hadm_id"): for timestamp, code in zip(v_info["starttime"], v_info["ndc"]): - attr_dict = { - "code": code, - "vocabulary": "NDC", - "visit_id": v_id, - "patient_id": p_id, - } event = Event( - type=table, + code=code, + table=table, + vocabulary="NDC", + visit_id=v_id, + patient_id=p_id, timestamp=strptime(timestamp), - attr_dict=attr_dict, ) # update patients events.append(event) @@ -437,16 +347,13 @@ def lab_unit(p_id, p_info): events = [] for v_id, v_info in p_info.groupby("hadm_id"): for timestamp, code in zip(v_info["charttime"], v_info["itemid"]): - attr_dict = { - "code": code, - "vocabulary": "MIMIC4_ITEMID", - "visit_id": v_id, - "patient_id": p_id, - } event = Event( - type=table, + code=code, + table=table, + vocabulary="MIMIC4_ITEMID", + visit_id=v_id, + patient_id=p_id, timestamp=strptime(timestamp), - attr_dict=attr_dict, ) events.append(event) return events @@ -460,85 +367,69 @@ def lab_unit(p_id, p_info): patients = self._add_events_to_patient_dict(patients, group_df) return patients - def _add_events_to_patient_dict( - self, - patient_dict: Dict[str, Patient], - group_df: pd.DataFrame, - ) -> Dict[str, Patient]: - """Helper function which adds the events column of a df.groupby object to the patient dict. - - Will be called at the end of each `self.parse_[table_name]()` function. + def parse_hcpcsevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: + """Helper function which parses hcpcsevents table. - Args: - patient_dict: a dict mapping patient_id to `Patient` object. - group_df: a df.groupby object, having two columns: patient_id and events. - - the patient_id column is the index of the patient - - the events column is a list of objects - - Returns: - The updated patient dict. 
- """ - for _, events in group_df.items(): - for event in events: - patient_dict = self._add_event_to_patient_dict(patient_dict, event) - return patient_dict - - @staticmethod - def _add_event_to_patient_dict( - patient_dict: Dict[str, Patient], - event: Event, - ) -> Dict[str, Patient]: - """Helper function which adds an event to the patient dict. - - Will be called in `self._add_events_to_patient_dict`. + Will be called in `self.parse_tables()` - Note that if the patient of the event is not in the patient dict, or the - visit of the event is not in the patient, this function will do nothing. + Docs: + - hcpcsevents: https://mimic.mit.edu/docs/iv/modules/hosp/hcpcsevents/ Args: - patient_dict: a dict mapping patient_id to `Patient` object. - event: an event to be added to the patient dict. + patients: a dict of `Patient` objects indexed by patient_id. Returns: - The updated patient dict. + The updated patients dict. + + Note: + MIMIC-IV does not provide specific timestamps in hcpcsevents + table, so we set it to None. """ - patient_id = event.attr_dict["patient_id"] - try: - patient_dict[patient_id].add_event(event) - except KeyError: - pass - return patient_dict - - def stat(self) -> str: - """Returns some statistics of the base dataset.""" - lines = list() - lines.append("") - lines.append(f"Statistics of base dataset (dev={self.dev}):") - lines.append(f"\t- Dataset: {self.dataset_name}") - lines.append(f"\t- Number of patients: {len(self.patients)}") - num_visits = [len(p.get_events_by_type("admissions")) for p in - self.patients.values()] - lines.append(f"\t- Number of visits: {sum(num_visits)}") - lines.append( - f"\t- Number of visits per patient: {sum(num_visits) / len(num_visits):.4f}" + table = "hcpcsevents" + # read table + df = pd.read_csv( + os.path.join(self.root, f"{table}.csv"), + dtype={"subject_id": str, "hadm_id": str, "hcpcs_cd": str}, ) - for table in self.tables: - num_events = [ - len(p.get_events_by_type(table)) for p in self.patients.values() - ] - lines.append( - f"\t- Number of events per patient in {table}: " - f"{sum(num_events) / len(num_events):.4f}" - ) - lines.append("") - print("\n".join(lines)) - return "\n".join(lines) + # drop rows with missing values + df = df.dropna(subset=["subject_id", "hadm_id", "hcpcs_cd"]) + # sort by sequence number (i.e., priority) + df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True) + # group by patient and visit + group_df = df.groupby("subject_id") + # parallel unit of hcpcsevents (per patient) + def hcpcsevents_unit(p_id, p_info): + events = [] + for v_id, v_info in p_info.groupby("hadm_id"): + for code in v_info["hcpcs_cd"]: + event = Event( + code=code, + table=table, + vocabulary="MIMIC4_HCPCS_CD", + visit_id=v_id, + patient_id=p_id, + ) + # update patients + events.append(event) + return events + + # parallel apply + group_df = group_df.parallel_apply( + lambda x: hcpcsevents_unit(x.subject_id.unique()[0], x) + ) + # summarize the results + patients = self._add_events_to_patient_dict(patients, group_df) + + return patients + if __name__ == "__main__": dataset = MIMIC4Dataset( root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", - tables=["diagnoses_icd", "procedures_icd"], - dev=True, + tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents", "hcpcsevents"], + code_mapping={"NDC": "ATC"}, + refresh_cache=False, ) dataset.stat() + dataset.info() \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 5f97c0ee..d55d6dbe 100644 
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -20,13 +20,14 @@
 # length_of_stay_prediction_mimic4_fn,
 # length_of_stay_prediction_omop_fn,
 # )
-# from .mortality_prediction import (
-#     mortality_prediction_eicu_fn,
-#     mortality_prediction_eicu_fn2,
-#     mortality_prediction_mimic3_fn,
-#     mortality_prediction_mimic4_fn,
-#     mortality_prediction_omop_fn,
-# )
+from .mortality_prediction import (
+    mimic3_48_ihm,
+    mortality_prediction_eicu_fn,
+    mortality_prediction_eicu_fn2,
+    mortality_prediction_mimic3_fn,
+    mortality_prediction_mimic4_fn,
+    mortality_prediction_omop_fn,
+)
 # from .readmission_prediction import (
 # readmission_prediction_eicu_fn,
 # readmission_prediction_eicu_fn2,
@@ -43,4 +44,4 @@
 from .covid19_cxr_classification import COVID19CXRClassification
 from .medical_transcriptions_classification import MedicalTranscriptionsClassification
 from .sleep_staging_v2 import SleepStagingSleepEDF
-from .mortality_prediction_v2 import Mortality30DaysMIMIC4
\ No newline at end of file
+from .mortality_prediction_v2 import Mortality30DaysMIMIC4, MIMIC3_48_IHM
\ No newline at end of file

diff --git a/pyhealth/tasks/mortality_prediction.py b/pyhealth/tasks/mortality_prediction.py
index 3661540b..fa25ce2e 100644
--- a/pyhealth/tasks/mortality_prediction.py
+++ b/pyhealth/tasks/mortality_prediction.py
@@ -1,5 +1,62 @@
 from pyhealth.data import Patient, Visit
+from datetime import timedelta

+def mimic3_48_ihm(patient):
+    """Processes a single patient for the 48hr in-hospital mortality task.
+
+    This task predicts whether the patient will die in hospital during the
+    current visit, using only the clinical information (e.g., lab events and
+    prescriptions) recorded within the first 48 hours after admission.
+
+    Args:
+        patient: a Patient object
+
+    Returns:
+        samples: a list of samples, each sample is a dict with patient_id,
+            visit_id, and other task-specific attributes as keys
+
+    Note that we define the task as a binary classification task.
+
+    Examples:
+        >>> from pyhealth.datasets import MIMIC3Dataset
+        >>> mimic3_base = MIMIC3Dataset(
+        ...     root="/srv/local/data/physionet.org/files/mimiciii/1.4",
+        ...     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
+        ...     code_mapping={"ICD9CM": "CCSCM"},
+        ...
) + >>> from pyhealth.tasks import mimic3_48_ihm + >>> mimic3_sample = mimic3_base.set_task(mimic3_48_ihm) + >>> mimic3_sample.samples[0] + [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 0}] + """ + samples = [] + + # we will drop the last visit + for i in range(len(patient)): + visit: Visit = patient[i] + + assert visit.discharge_status in [0, 1], f"Unexpected discharge status for Visit {visit}" + mortality_label = int(visit.discharge_status) + + # exclude the event happened after 48 hrs window on admission of the hospital + labevents = [event.code for event in visit.get_event_list(table="LABEVENTS") if event.timestamp < visit.encounter_time + timedelta(days=2)] + drugs = [event.code for event in visit.get_event_list(table="PRESCRIPTIONS") if event.timestamp < visit.encounter_time + timedelta(days=2)] + # exclude: visits without lab events, and drug code + if len(labevents) * len(drugs) == 0: + continue + # TODO: should also exclude visit with age < 18 + samples.append( + { + "visit_id": visit.visit_id, + "patient_id": patient.patient_id, + "labevents": [labevents], + "drugs": [drugs], + "period_length": 48, + "label": mortality_label, + } + ) + # no cohort selection + return samples def mortality_prediction_mimic3_fn(patient: Patient): """Processes a single patient for the mortality prediction task. @@ -257,7 +314,8 @@ def mortality_prediction_eicu_fn2(patient: Patient): "label": mortality_label, } ) - print(samples) + + # print(samples) # no cohort selection return samples @@ -379,4 +437,4 @@ def mortality_prediction_omop_fn(patient: Patient): ) sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_omop_fn) sample_dataset.stat() - print(sample_dataset.available_keys) + print(sample_dataset.available_keys) \ No newline at end of file diff --git a/pyhealth/tasks/mortality_prediction_v2.py b/pyhealth/tasks/mortality_prediction_v2.py index a821cefd..4df01370 100644 --- a/pyhealth/tasks/mortality_prediction_v2.py +++ b/pyhealth/tasks/mortality_prediction_v2.py @@ -43,18 +43,43 @@ def __call__(self, patient): } ] return samples + +@dataclass(frozen=True) +class MIMIC3_48_IHM(TaskTemplate): + task_name: str = "48_IHM" + input_schema: Dict[str, str] = field(default_factory=lambda: {"diagnoses": "sequence", "procedures": "sequence"}) + output_schema: Dict[str, str] = field(default_factory=lambda: {"mortality": "label"}) + def __call__(self, patient): + death_datetime = patient.attr_dict["death_datetime"] + diagnoses = patient.get_events_by_type("diagnoses_icd") + procedures = patient.get_events_by_type("procedures_icd") + mortality = 0 + if death_datetime is not None: + mortality = 1 + # remove events 30 days before death + diagnoses = [ + diag + for diag in diagnoses + if diag.timestamp <= death_datetime - timedelta(days=30) + ] + procedures = [ + proc + for proc in procedures + if proc.timestamp <= death_datetime - timedelta(days=30) + ] + diagnoses = [diag.attr_dict["code"] for diag in diagnoses] + procedures = [proc.attr_dict["code"] for proc in procedures] + if len(diagnoses) * len(procedures) == 0: + return [] -if __name__ == "__main__": - from pyhealth.datasets import MIMIC4Dataset - - dataset = MIMIC4Dataset( - root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp", - tables=["procedures_icd"], - dev=True, - ) - task = Mortality30DaysMIMIC4() - samples = dataset.set_task(task) - print(samples[0]) - print(len(samples)) + samples = [ + { + "patient_id": patient.patient_id, 
+ "diagnoses": diagnoses, + "procedures": procedures, + "mortality": mortality, + } + ] + return samples \ No newline at end of file From 8510c1797836f6d662ea4a5190b0641b946412b0 Mon Sep 17 00:00:00 2001 From: Hang Yu Date: Fri, 1 Nov 2024 00:13:49 +0000 Subject: [PATCH 3/6] implemented discretizer, injected discretization and imputation in 48_ihm task data preprocessing --- pyhealth/datasets/base_ehr_dataset.py | 2 + pyhealth/datasets/mimic3.py | 7 +- pyhealth/datasets/sample_dataset.py | 3 +- pyhealth/tasks/mortality_prediction.py | 63 +++++--- pyhealth/tasks/mortality_prediction_v2.py | 185 ++++++++++++++++++---- 5 files changed, 202 insertions(+), 58 deletions(-) diff --git a/pyhealth/datasets/base_ehr_dataset.py b/pyhealth/datasets/base_ehr_dataset.py index e5760b70..ac1e6665 100644 --- a/pyhealth/datasets/base_ehr_dataset.py +++ b/pyhealth/datasets/base_ehr_dataset.py @@ -415,6 +415,8 @@ def set_task( self.patients.items(), desc=f"Generating samples for {task_name}" ): samples.extend(task_fn(patient)) + # import pdb + # pdb.set_trace() sample_dataset = SampleEHRDataset( samples=samples, diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index b93e25fa..ad349f99 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -332,13 +332,13 @@ def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: # read table df = pd.read_csv( os.path.join(self.root, f"{table}.csv"), - dtype={"subject_id": str, "hadm_id": str, "itemid": str}, + dtype={"subject_id": str, "hadm_id": str, "itemid": str, "valuenum": float}, ) # drop records of the other patients df = df[df["subject_id"].isin(patients.keys())] # drop rows with missing values # df = df.dropna(subset=["subject_id", "hadm_id"]) - df = df.dropna(subset=["subject_id", "hadm_id", "itemid"]) + df = df.dropna(subset=["subject_id", "hadm_id", "itemid", "valuenum"]) # sort by charttime df = df.sort_values(["subject_id", "hadm_id", "charttime"], ascending=True) # group by patient and visit @@ -348,7 +348,7 @@ def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]: def lab_unit(p_id, p_info): events = [] for v_id, v_info in p_info.groupby("hadm_id"): - for timestamp, code in zip(v_info["charttime"], v_info["itemid"]): + for timestamp, code, valuenum in zip(v_info["charttime"], v_info["itemid"], v_info["valuenum"]): event = Event( code=code, table=table, @@ -356,6 +356,7 @@ def lab_unit(p_id, p_info): visit_id=v_id, patient_id=p_id, timestamp=strptime(timestamp), + valuenum=valuenum, ) events.append(event) return events diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 2949556d..e7c388ab 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,6 +1,7 @@ from collections import Counter from typing import Dict, List import pickle +import numpy as np from torch.utils.data import Dataset @@ -354,7 +355,7 @@ def _validate(self) -> Dict: """ types = set([type(v) for v in flattened_values]) assert ( - types == set([str]) or len(types.difference(set([int, float]))) == 0 + types == set([str]) or len(types.difference(set([int, float]))) == 0 or types == set([np.ndarray]) ), f"Key {key} has mixed or unsupported types ({types}) across samples" type_ = types.pop() """ diff --git a/pyhealth/tasks/mortality_prediction.py b/pyhealth/tasks/mortality_prediction.py index fa25ce2e..21e72c71 100644 --- a/pyhealth/tasks/mortality_prediction.py +++ b/pyhealth/tasks/mortality_prediction.py @@ -29,34 +29,47 @@ def 
mimic3_48_ihm(patient): >>> mimic3_sample.samples[0] [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 0}] """ - samples = [] + raise NotImplementedError("This task function is currently deprecated!") + # selected_labitem_ids = {'51221':0, '50971':1, '50983':2, '50912':3, '50902':4} # TODO: use LABEVENTS.csv group statistics to optimize this + # samples = [] - # we will drop the last visit - for i in range(len(patient)): - visit: Visit = patient[i] + # # we will drop the last visit + # for i in range(len(patient)): + # visit: Visit = patient[i] - assert visit.discharge_status in [0, 1], f"Unexpected discharge status for Visit {visit}" - mortality_label = int(visit.discharge_status) + # assert visit.discharge_status in [0, 1], f"Unexpected discharge status for Visit {visit}" + # mortality_label = int(visit.discharge_status) - # exclude the event happened after 48 hrs window on admission of the hospital - labevents = [event.code for event in visit.get_event_list(table="LABEVENTS") if event.timestamp < visit.encounter_time + timedelta(days=2)] - drugs = [event.code for event in visit.get_event_list(table="PRESCRIPTIONS") if event.timestamp < visit.encounter_time + timedelta(days=2)] - # exclude: visits without lab events, and drug code - if len(labevents) * len(drugs) == 0: - continue - # TODO: should also exclude visit with age < 18 - samples.append( - { - "visit_id": visit.visit_id, - "patient_id": patient.patient_id, - "labevents": [labevents], - "drugs": [drugs], - "period_length": 48, - "label": mortality_label, - } - ) - # no cohort selection - return samples + # # exclude the event happened after 48 hrs window on admission of the hospital + # # import pdb + # # pdb.set_trace() + # labevents = visit.get_event_list(table="LABEVENTS") + # x_ts = [[] for _ in range(len(selected_labitem_ids))] + # t_ts = [[] for _ in range(len(selected_labitem_ids))] + # for event in labevents: + # if event.timestamp > visit.encounter_time + timedelta(days=2): + # break + # if event.code in selected_labitem_ids: + # l = selected_labitem_ids[event.code] + # x_ts[l].append(event.attr_dict['valuenum']) + # t_ts[l].append(event.timestamp) + + # # exclude: visits without lab events + # if len(labevents) == 0: + # continue + # # TODO: should also exclude visit with age < 18 + # samples.append( + # { + # "visit_id": visit.visit_id, + # "patient_id": patient.patient_id, + # "x_ts": x_ts, + # "t_ts": t_ts, + # "period_length": 48, + # "label": mortality_label, + # } + # ) + # # no cohort selection + # return samples def mortality_prediction_mimic3_fn(patient: Patient): """Processes a single patient for the mortality prediction task. 
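Aside from the diff itself, the behavior this commit deprecates (and PATCH 3/6 reworks) is easy to state: an event contributes features only if it was charted within 48 hours of the visit's encounter time, and the label is the visit's discharge status. Below is a minimal, self-contained sketch of that windowing rule; plain tuples stand in for pyhealth.data.Event objects, and first_48h_codes is an illustrative name, not part of the PyHealth API.

```python
from datetime import datetime, timedelta

def first_48h_codes(events, encounter_time, window_hours=48):
    """Keep the codes of events charted within `window_hours` of admission.

    `events` is a list of (timestamp, code) pairs; the real task iterates over
    pyhealth.data.Event objects, which carry the same two fields.
    """
    cutoff = encounter_time + timedelta(hours=window_hours)
    return [code for ts, code in events if ts is not None and ts < cutoff]

# toy usage: one lab value inside the window, one outside
admit = datetime(2024, 1, 1, 8, 0)
events = [
    (datetime(2024, 1, 1, 12, 0), "51221"),  # +4h  -> kept
    (datetime(2024, 1, 4, 9, 0), "50971"),   # +73h -> dropped
]
assert first_48h_codes(events, admit) == ["51221"]
```

The commented-out task body above applies the same cutoff to LABEVENTS and PRESCRIPTIONS separately and skips visits where either list comes back empty.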
diff --git a/pyhealth/tasks/mortality_prediction_v2.py b/pyhealth/tasks/mortality_prediction_v2.py
index 4df01370..239c2e99 100644
--- a/pyhealth/tasks/mortality_prediction_v2.py
+++ b/pyhealth/tasks/mortality_prediction_v2.py
@@ -2,6 +2,7 @@
 from typing import Dict
 from datetime import timedelta
 from pyhealth.tasks.task_template import TaskTemplate
+import numpy as np
 
 
 @dataclass(frozen=True)
@@ -43,43 +44,169 @@ def __call__(self, patient):
             }
         ]
         return samples
+
+
+class Discretizer:
+    '''
+    Discretizer class for MISTS feature extraction (https://arxiv.org/pdf/2210.12156).
+    Code modified from https://github.com/XZhang97666/MultimodalMIMIC/blob/main/preprocessing.py by Hang Yu.
+    '''
+    def __init__(self, selected_channel_ids, normal_values, timestep=0.8, impute_strategy='zero', start_time='relative'):
+        '''
+        Args:
+            selected_channel_ids: dict mapping lab item codes to channel indices
+            normal_values: dict mapping lab item codes to default values used for imputation
+            timestep: interval span (hours) of one bin
+            impute_strategy: one of 'zero', 'normal_value', 'previous', 'next'
+            start_time: 'relative' (the first event defines t=0) or 'zero' (not implemented yet)
+        '''
+        self._selected_channel_ids = selected_channel_ids
+        self._timestep = timestep
+        self._start_time = start_time
+        self._impute_strategy = impute_strategy
+        self._normal_values = normal_values
+
+        # for statistics
+        self._done_count = 0
+        self._empty_bins_sum = 0
+        self._unused_data_sum = 0
+
+    def transform(self, X, channel, timespan=None):
+        '''
+        Args:
+            X: list of [timestamp, valuenum] pairs
+            channel: the code of the lab item
+            timespan: the timespan (in hours) of the data we use
+        '''
+        eps = 1e-6
+
+        t_ts, x_ts = zip(*X)
+        for i in range(len(t_ts) - 1):
+            assert t_ts[i] < t_ts[i+1] + timedelta(hours=eps)
+
+        if self._start_time == 'relative':
+            first_time = t_ts[0]
+        elif self._start_time == 'zero':
+            raise NotImplementedError("start_time 'zero' not implemented yet")
+        else:
+            raise ValueError("start_time is invalid")
+
+        if timespan is None:
+            max_hours = (max(t_ts) - first_time).total_seconds() / 3600
+        else:
+            max_hours = timespan
+
+        N_bins = int(max_hours / self._timestep + 1.0 - eps)
+
+        data = np.zeros(shape=(N_bins,), dtype=float)
+        mask = np.zeros(shape=(N_bins,), dtype=int)
+        original_value = ["" for i in range(N_bins)]
+        total_data = 0
+        unused_data = 0
+
+        for row in X:
+            t = (row[0] - first_time).total_seconds() / 3600
+            if t > max_hours + eps:
+                continue
+            bin_id = int(t / self._timestep - eps)
+            assert 0 <= bin_id < N_bins
+
+            total_data += 1
+            if mask[bin_id] == 1:
+                unused_data += 1
+            mask[bin_id] = 1
+            data[bin_id] = row[1]
+            original_value[bin_id] = row[1]
+
+        if self._impute_strategy not in ['zero', 'normal_value', 'previous', 'next']:
+            raise ValueError("impute strategy is invalid")
+
+        if self._impute_strategy in ['normal_value', 'previous']:
+            prev_values = []
+            for bin_id in range(N_bins):
+                if mask[bin_id] == 1:
+                    prev_values.append(original_value[bin_id])
+                    continue
+                if self._impute_strategy == 'normal_value':
+                    imputed_value = self._normal_values[channel]
+                if self._impute_strategy == 'previous':
+                    if len(prev_values) == 0:
+                        imputed_value = self._normal_values[channel]
+                    else:
+                        imputed_value = prev_values[-1]
+                data[bin_id] = imputed_value
+                # write(data, bin_id, channel, imputed_value)
+
+        if self._impute_strategy == 'next':
+            prev_values = []
+            for bin_id in range(N_bins-1, -1, -1):
+                if mask[bin_id] == 1:
+                    prev_values.append(original_value[bin_id])
+                    continue
+                if len(prev_values) == 0:
+                    imputed_value = self._normal_values[channel]
+                else:
+                    imputed_value = prev_values[-1]
+                data[bin_id] = imputed_value
+                # write(data, bin_id, channel, imputed_value)
+
+        # empty_bins = np.sum([1 - min(1, np.sum(mask[i, :])) for i in range(N_bins)])
+        # self._done_count += 1
+        # self._empty_bins_sum += empty_bins / (N_bins + eps)
+        # self._unused_data_sum += unused_data / (total_data + eps)
+
+        return (data, mask)
 
 
 @dataclass(frozen=True)
 class MIMIC3_48_IHM(TaskTemplate):
     task_name: str = "48_IHM"
     input_schema: Dict[str, str] = field(default_factory=lambda: {"diagnoses": "sequence", "procedures": "sequence"})
     output_schema: Dict[str, str] = field(default_factory=lambda: {"mortality": "label"})
+    __name__: str = task_name
+    selected_labitem_ids = {'51221': 0, '50971': 1, '50983': 2, '50912': 3, '50902': 4}  # TODO: use LABEVENTS.csv group statistics to optimize this
+    normal_values = {'51221': 42.4, '50971': 4.3, '50983': 139, '50912': 0.7, '50902': 97}
+    discretizer = Discretizer(selected_labitem_ids, normal_values)
 
     def __call__(self, patient):
-        death_datetime = patient.attr_dict["death_datetime"]
-        diagnoses = patient.get_events_by_type("diagnoses_icd")
-        procedures = patient.get_events_by_type("procedures_icd")
-        mortality = 0
-        if death_datetime is not None:
-            mortality = 1
-            # remove events 30 days before death
-            diagnoses = [
-                diag
-                for diag in diagnoses
-                if diag.timestamp <= death_datetime - timedelta(days=30)
-            ]
-            procedures = [
-                proc
-                for proc in procedures
-                if proc.timestamp <= death_datetime - timedelta(days=30)
-            ]
-        diagnoses = [diag.attr_dict["code"] for diag in diagnoses]
-        procedures = [proc.attr_dict["code"] for proc in procedures]
+        samples = []
 
-        if len(diagnoses) * len(procedures) == 0:
-            return []
+        # iterate over all visits of the patient
+        for i in range(len(patient)):
+            visit: Visit = patient[i]
 
-        samples = [
-            {
-                "patient_id": patient.patient_id,
-                "diagnoses": diagnoses,
-                "procedures": procedures,
-                "mortality": mortality,
-            }
-        ]
+            assert visit.discharge_status in [0, 1], f"Unexpected discharge status for Visit {visit}"
+            mortality_label = int(visit.discharge_status)
+
+            # exclude events that happened after the 48-hour window from hospital admission
+            # import pdb
+            # pdb.set_trace()
+            labevents = visit.get_event_list(table="LABEVENTS")
+            # exclude: visits without lab events
+            if len(labevents) == 0:
+                continue
+
+            Xs = [[] for _ in range(len(self.selected_labitem_ids))]
+            for event in labevents:
+                if event.timestamp > visit.encounter_time + timedelta(days=2):
+                    break
+                if event.code in self.selected_labitem_ids:
+                    l = self.selected_labitem_ids[event.code]
+                    Xs[l].append([event.timestamp, event.attr_dict['valuenum']])
+
+            discretized_X, discretized_mask = [], []
+            for code in self.selected_labitem_ids:
+                l = self.selected_labitem_ids[code]
+                x_ts, mask_ts = self.discretizer.transform(Xs[l], code, timespan=48)  # TODO: add normalizer later
+                discretized_X.append(x_ts)
+                discretized_mask.append(mask_ts)
+            discretized_X = np.array(discretized_X)
+            discretized_mask = np.array(discretized_mask)
+
+            # TODO: should also exclude visits with age < 18
+            samples.append(
+                {
+                    "visit_id": visit.visit_id,
+                    "patient_id": patient.patient_id,
+                    "X": discretized_X,
+                    "mask": discretized_mask,
+                    "label": mortality_label,
+                }
+            )
+        # no cohort selection
        return samples
\ No newline at end of file
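
The Discretizer above turns one channel of irregularly timed lab values into fixed-width bins plus an observation mask. A minimal usage sketch (not part of the patch; the channel ids, normal values, and timestamps below are hypothetical):

    from datetime import datetime

    selected = {'51221': 0, '50971': 1}       # lab item code -> channel index (hypothetical subset)
    normals = {'51221': 42.4, '50971': 4.3}   # fallback values for imputation
    disc = Discretizer(selected, normals, timestep=0.8, impute_strategy='previous')

    X = [
        [datetime(2024, 1, 1, 0, 30), 41.0],  # [timestamp, valuenum]
        [datetime(2024, 1, 1, 5, 0), 43.2],
    ]
    data, mask = disc.transform(X, channel='51221', timespan=48)
    # data.shape == mask.shape == (60,), i.e. 48 hours at 0.8 hours per bin
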
From 171c03b56471e69db005f39d3292f9369e676e04 Mon Sep 17 00:00:00 2001
From: Hang Yu
Date: Fri, 1 Nov 2024 21:29:49 +0000
Subject: [PATCH 4/6] add examples; first runnable pipeline is available

---
 examples/MISTS_by_Hang.py                 |  56 +++++++++++
 pyhealth/datasets/base_ehr_dataset.py     |   2 -
 pyhealth/datasets/mimic3.py               | 110 +++++++++++-----------
 pyhealth/tasks/mortality_prediction_v2.py |  21 ++---
 4 files changed, 120 insertions(+), 69 deletions(-)
 create mode 100644 examples/MISTS_by_Hang.py

diff --git a/examples/MISTS_by_Hang.py b/examples/MISTS_by_Hang.py
new file mode 100644
index 00000000..212503b4
--- /dev/null
+++ b/examples/MISTS_by_Hang.py
@@ -0,0 +1,56 @@
+# import sys
+# sys.path.append('./PyHealth')
+from pyhealth.datasets import MIMIC3Dataset
+
+mimic3_ds = MIMIC3Dataset(
+    root="https://storage.googleapis.com/pyhealth/mimiciii-demo/1.4/",
+    tables=["labevents"],
+    dev=True,
+)
+
+from pyhealth.tasks import MIMIC3_48_IHM
+
+samples = mimic3_ds.set_task(MIMIC3_48_IHM())
+
+from pyhealth.datasets import split_by_sample, get_dataloader
+
+# data split
+train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.8, 0.1, 0.1])
+
+# create dataloaders (they are objects)
+train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)
+val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)
+test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)
+
+from pyhealth.models import RNN
+
+model = RNN(
+    dataset=samples,
+    # look up what is available for "feature_keys" and "label_keys" in dataset.samples[0]
+    feature_keys=["discretized_feature"],
+    label_key="mortality",
+    mode="multiclass",
+)
+
+from pyhealth.trainer import Trainer
+
+trainer = Trainer(
+    model=model,
+    metrics=["accuracy", "f1_weighted"],  # the metrics that we want to log
+)
+
+trainer.train(
+    train_dataloader=train_loader,
+    val_dataloader=val_loader,
+    epochs=20,
+    monitor="accuracy",
+    monitor_criterion="max",
+)
+
+# use our pyhealth.metrics to evaluate
+from pyhealth.metrics.multiclass import multiclass_metrics_fn
+
+y_true, y_prob, loss = trainer.inference(test_loader)
+multiclass_metrics_fn(
+    y_true,
+    y_prob,
+    metrics=["f1_weighted", "f1_micro", "cohen_kappa"],
+)
\ No newline at end of file
diff --git a/pyhealth/datasets/base_ehr_dataset.py b/pyhealth/datasets/base_ehr_dataset.py
index ac1e6665..e5760b70 100644
--- a/pyhealth/datasets/base_ehr_dataset.py
+++ b/pyhealth/datasets/base_ehr_dataset.py
@@ -415,8 +415,6 @@ def set_task(
             self.patients.items(), desc=f"Generating samples for {task_name}"
         ):
             samples.extend(task_fn(patient))
-        # import pdb
-        # pdb.set_trace()
 
         sample_dataset = SampleEHRDataset(
             samples=samples,
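
The RNN model in the example above reads its inputs by feature key, so every sample emitted by the task must carry a matching key. An illustrative sketch of the sample layout produced by MIMIC3_48_IHM after this patch (hypothetical ids, zero-filled features; see the mortality_prediction_v2.py diff below for the real construction):

    sample = {
        "patient_id": "10006",                                  # hypothetical id
        "visit_id": "142345",                                   # hypothetical id
        "discretized_feature": [[0.0] * 60 for _ in range(5)],  # 5 lab channels x 60 bins
        "mortality": 0,
    }
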
diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py
index ad349f99..92e9adc0 100644
--- a/pyhealth/datasets/mimic3.py
+++ b/pyhealth/datasets/mimic3.py
@@ -18,15 +18,15 @@ class MIMIC3Dataset(BaseEHRDataset):
     patients. The dataset is available at https://mimic.physionet.org/.
 
     The basic information is stored in the following tables:
-        - PATIENTS: defines a patient in the database, subject_id.
-        - ADMISSIONS: defines a patient's hospital admission, hadm_id.
+        - PATIENTS: defines a patient in the database, SUBJECT_ID.
+        - ADMISSIONS: defines a patient's hospital admission, HADM_ID.
 
     We further support the following tables:
         - DIAGNOSES_ICD: contains ICD-9 diagnoses (ICD9CM code) for patients.
         - PROCEDURES_ICD: contains ICD-9 procedures (ICD9PROC code) for patients.
         - PRESCRIPTIONS: contains medication related order entries (ndc code)
            for patients.
-        - LABEVENTS: contains laboratory measurements (MIMIC3_itemid code)
+        - LABEVENTS: contains laboratory measurements (MIMIC3_ITEMID code)
            for patients.
 
     Args:
@@ -87,43 +87,43 @@ def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
         # read patients table
         patients_df = pd.read_csv(
             os.path.join(self.root, "PATIENTS.csv"),
-            dtype={"subject_id": str},
+            dtype={"SUBJECT_ID": str},
             nrows=1000 if self.dev else None,
         )
         # read admissions table
         admissions_df = pd.read_csv(
             os.path.join(self.root, "ADMISSIONS.csv"),
-            dtype={"subject_id": str, "hadm_id": str},
+            dtype={"SUBJECT_ID": str, "HADM_ID": str},
         )
         # merge patient and admission tables
-        df = pd.merge(patients_df, admissions_df, on="subject_id", how="inner")
+        df = pd.merge(patients_df, admissions_df, on="SUBJECT_ID", how="inner")
         # sort by admission and discharge time
-        df = df.sort_values(["subject_id", "admittime", "dischtime"], ascending=True)
+        df = df.sort_values(["SUBJECT_ID", "ADMITTIME", "DISCHTIME"], ascending=True)
         # group by patient
-        df_group = df.groupby("subject_id")
+        df_group = df.groupby("SUBJECT_ID")
 
         # parallel unit of basic information (per patient)
        def basic_unit(p_id, p_info):
             patient = Patient(
                 patient_id=p_id,
-                birth_datetime=strptime(p_info["dob"].values[0]),
-                death_datetime=strptime(p_info["dod_hosp"].values[0]),
-                gender=p_info["gender"].values[0],
-                ethnicity=p_info["ethnicity"].values[0],
+                birth_datetime=strptime(p_info["DOB"].values[0]),
+                death_datetime=strptime(p_info["DOD_HOSP"].values[0]),
+                gender=p_info["GENDER"].values[0],
+                ethnicity=p_info["ETHNICITY"].values[0],
             )
             # load visits
-            for v_id, v_info in p_info.groupby("hadm_id"):
+            for v_id, v_info in p_info.groupby("HADM_ID"):
                 visit = Visit(
                     visit_id=v_id,
                     patient_id=p_id,
-                    encounter_time=strptime(v_info["admittime"].values[0]),
-                    discharge_time=strptime(v_info["dischtime"].values[0]),
-                    discharge_status=v_info["hospital_expire_flag"].values[0],
-                    insurance=v_info["insurance"].values[0],
-                    language=v_info["language"].values[0],
-                    religion=v_info["religion"].values[0],
-                    marital_status=v_info["marital_status"].values[0],
-                    ethnicity=v_info["ethnicity"].values[0],
+                    encounter_time=strptime(v_info["ADMITTIME"].values[0]),
+                    discharge_time=strptime(v_info["DISCHTIME"].values[0]),
+                    discharge_status=v_info["HOSPITAL_EXPIRE_FLAG"].values[0],
+                    insurance=v_info["INSURANCE"].values[0],
+                    language=v_info["LANGUAGE"].values[0],
+                    religion=v_info["RELIGION"].values[0],
+                    marital_status=v_info["MARITAL_STATUS"].values[0],
+                    ethnicity=v_info["ETHNICITY"].values[0],
                 )
                 # add visit
                 patient.add_visit(visit)
@@ -131,7 +131,7 @@ def basic_unit(p_id, p_info):
 
         # parallel apply
         df_group = df_group.parallel_apply(
-            lambda x: basic_unit(x.subject_id.unique()[0], x)
+            lambda x: basic_unit(x.SUBJECT_ID.unique()[0], x)
         )
         # summarize the results
         for pat_id, pat in df_group.items():
@@ -162,21 +162,21 @@ def parse_diagnoses_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patient
         # read table
         df = pd.read_csv(
             os.path.join(self.root, f"{table}.csv"),
-            dtype={"subject_id": str, "hadm_id": str, "icd9_code": str},
+            dtype={"SUBJECT_ID": str, "HADM_ID": str, "icd9_code": str},
         )
         # drop records of the other patients
-        df = df[df["subject_id"].isin(patients.keys())]
+        df = df[df["SUBJECT_ID"].isin(patients.keys())]
         # drop rows with missing values
-        df = df.dropna(subset=["subject_id", "hadm_id", "icd9_code"])
+        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "icd9_code"])
         # sort by sequence number (i.e., priority)
-        df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True)
+        df = df.sort_values(["SUBJECT_ID", "HADM_ID", "seq_num"], ascending=True)
         # group by patient and visit
-        group_df = df.groupby("subject_id")
+        group_df = df.groupby("SUBJECT_ID")
 
         # parallel unit of diagnosis (per patient)
         def diagnosis_unit(p_id, p_info):
             events = []
-            for v_id, v_info in p_info.groupby("hadm_id"):
+            for v_id, v_info in p_info.groupby("HADM_ID"):
                 for code in v_info["icd9_code"]:
                     event = Event(
                         code=code,
@@ -190,7 +190,7 @@ def diagnosis_unit(p_id, p_info):
 
         # parallel apply
         group_df = group_df.parallel_apply(
-            lambda x: diagnosis_unit(x.subject_id.unique()[0], x)
+            lambda x: diagnosis_unit(x.SUBJECT_ID.unique()[0], x)
         )
 
         # summarize the results
@@ -220,21 +220,21 @@ def parse_procedures_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patien
         # read table
         df = pd.read_csv(
             os.path.join(self.root, f"{table}.csv"),
-            dtype={"subject_id": str, "hadm_id": str, "icd9_code": str},
+            dtype={"SUBJECT_ID": str, "HADM_ID": str, "icd9_code": str},
         )
         # drop records of the other patients
-        df = df[df["subject_id"].isin(patients.keys())]
+        df = df[df["SUBJECT_ID"].isin(patients.keys())]
         # drop rows with missing values
-        df = df.dropna(subset=["subject_id", "hadm_id", "seq_num", "icd9_code"])
+        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "seq_num", "icd9_code"])
         # sort by sequence number (i.e., priority)
-        df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True)
+        df = df.sort_values(["SUBJECT_ID", "HADM_ID", "seq_num"], ascending=True)
         # group by patient and visit
-        group_df = df.groupby("subject_id")
+        group_df = df.groupby("SUBJECT_ID")
 
         # parallel unit of procedure (per patient)
         def procedure_unit(p_id, p_info):
             events = []
-            for v_id, v_info in p_info.groupby("hadm_id"):
+            for v_id, v_info in p_info.groupby("HADM_ID"):
                 for code in v_info["icd9_code"]:
                     event = Event(
                         code=code,
@@ -248,7 +248,7 @@ def procedure_unit(p_id, p_info):
 
         # parallel apply
         group_df = group_df.parallel_apply(
-            lambda x: procedure_unit(x.subject_id.unique()[0], x)
+            lambda x: procedure_unit(x.SUBJECT_ID.unique()[0], x)
         )
 
         # summarize the results
@@ -275,23 +275,23 @@ def parse_prescriptions(self, patients: Dict[str, Patient]) -> Dict[str, Patient
         df = pd.read_csv(
             os.path.join(self.root, f"{table}.csv"),
             low_memory=False,
-            dtype={"subject_id": str, "hadm_id": str, "ndc": str},
+            dtype={"SUBJECT_ID": str, "HADM_ID": str, "ndc": str},
         )
         # drop records of the other patients
-        df = df[df["subject_id"].isin(patients.keys())]
+        df = df[df["SUBJECT_ID"].isin(patients.keys())]
         # drop rows with missing values
-        df = df.dropna(subset=["subject_id", "hadm_id", "ndc"])
+        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "ndc"])
         # sort by start date and end date
         df = df.sort_values(
-            ["subject_id", "hadm_id", "startdate", "enddate"], ascending=True
+            ["SUBJECT_ID", "HADM_ID", "startdate", "enddate"], ascending=True
         )
         # group by patient and visit
-        group_df = df.groupby("subject_id")
+        group_df = df.groupby("SUBJECT_ID")
 
         # parallel unit for prescription (per patient)
         def prescription_unit(p_id, p_info):
             events = []
-            for v_id, v_info in p_info.groupby("hadm_id"):
+            for v_id, v_info in p_info.groupby("HADM_ID"):
                 for timestamp, code in zip(v_info["startdate"], v_info["ndc"]):
                     event = Event(
                         code=code,
@@ -306,7 +306,7 @@ def prescription_unit(p_id, p_info):
 
         # parallel apply
         group_df = group_df.parallel_apply(
-            lambda x: prescription_unit(x.subject_id.unique()[0], x)
+            lambda x: prescription_unit(x.SUBJECT_ID.unique()[0], x)
         )
 
         # summarize the results
@@ -328,31 +328,31 @@ def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
             The updated patients dict.
         """
         table = "LABEVENTS"
-        self.code_vocs["labs"] = "MIMIC3_itemid"
+        self.code_vocs["labs"] = "MIMIC3_ITEMID"
         # read table
         df = pd.read_csv(
             os.path.join(self.root, f"{table}.csv"),
-            dtype={"subject_id": str, "hadm_id": str, "itemid": str, "valuenum": float},
+            dtype={"SUBJECT_ID": str, "HADM_ID": str, "ITEMID": str, "VALUENUM": float},
         )
         # drop records of the other patients
-        df = df[df["subject_id"].isin(patients.keys())]
+        df = df[df["SUBJECT_ID"].isin(patients.keys())]
         # drop rows with missing values
-        # df = df.dropna(subset=["subject_id", "hadm_id"])
-        df = df.dropna(subset=["subject_id", "hadm_id", "itemid", "valuenum"])
-        # sort by charttime
-        df = df.sort_values(["subject_id", "hadm_id", "charttime"], ascending=True)
+        # df = df.dropna(subset=["SUBJECT_ID", "HADM_ID"])
+        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "ITEMID", "VALUENUM"])
+        # sort by CHARTTIME
+        df = df.sort_values(["SUBJECT_ID", "HADM_ID", "CHARTTIME"], ascending=True)
         # group by patient and visit
-        group_df = df.groupby("subject_id")
+        group_df = df.groupby("SUBJECT_ID")
 
         # parallel unit for lab (per patient)
         def lab_unit(p_id, p_info):
             events = []
-            for v_id, v_info in p_info.groupby("hadm_id"):
-                for timestamp, code, valuenum in zip(v_info["charttime"], v_info["itemid"], v_info["valuenum"]):
+            for v_id, v_info in p_info.groupby("HADM_ID"):
+                for timestamp, code, valuenum in zip(v_info["CHARTTIME"], v_info["ITEMID"], v_info["VALUENUM"]):
                     event = Event(
                         code=code,
                         table=table,
-                        vocabulary="MIMIC3_itemid",
+                        vocabulary="MIMIC3_ITEMID",
                         visit_id=v_id,
                         patient_id=p_id,
                         timestamp=strptime(timestamp),
@@ -363,7 +363,7 @@ def lab_unit(p_id, p_info):
 
         # parallel apply
         group_df = group_df.parallel_apply(
-            lambda x: lab_unit(x.subject_id.unique()[0], x)
+            lambda x: lab_unit(x.SUBJECT_ID.unique()[0], x)
        )
 
         # summarize the results
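
The renames above track the uppercase column headers of the raw MIMIC-III v1.4 CSV files. A quick header sanity check before parsing (a sketch; assumes a local copy of PATIENTS.csv):

    import pandas as pd

    # Read only the header row; MIMIC-III ships uppercase column names
    # (SUBJECT_ID, DOB, DOD_HOSP, ...), which is what the parsers above expect.
    cols = pd.read_csv("PATIENTS.csv", nrows=0).columns.tolist()
    print(cols)
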
diff --git a/pyhealth/tasks/mortality_prediction_v2.py b/pyhealth/tasks/mortality_prediction_v2.py
index 239c2e99..61055f8c 100644
--- a/pyhealth/tasks/mortality_prediction_v2.py
+++ b/pyhealth/tasks/mortality_prediction_v2.py
@@ -151,12 +151,12 @@ def transform(self, X, channel, timespan=None):
         # self._empty_bins_sum += empty_bins / (N_bins + eps)
         # self._unused_data_sum += unused_data / (total_data + eps)
 
-        return (data, mask)
+        return (data.tolist(), mask.tolist())
 
 
 @dataclass(frozen=True)
 class MIMIC3_48_IHM(TaskTemplate):
     task_name: str = "48_IHM"
-    input_schema: Dict[str, str] = field(default_factory=lambda: {"diagnoses": "sequence", "procedures": "sequence"})
+    input_schema: Dict[str, str] = field(default_factory=lambda: {"discretized_feature": "sequence"})
     output_schema: Dict[str, str] = field(default_factory=lambda: {"mortality": "label"})
     __name__: str = task_name
     selected_labitem_ids = {'51221': 0, '50971': 1, '50983': 2, '50912': 3, '50902': 4}  # TODO: use LABEVENTS.csv group statistics to optimize this
@@ -174,11 +174,11 @@ def __call__(self, patient):
             mortality_label = int(visit.discharge_status)
 
             # exclude events that happened after the 48-hour window from hospital admission
-            # import pdb
-            # pdb.set_trace()
             labevents = visit.get_event_list(table="LABEVENTS")
+            end_timestamp = visit.encounter_time + timedelta(days=2)
             # exclude: visits without lab events
-            if len(labevents) == 0:
+            if len(labevents) == 0 or labevents[0].timestamp > end_timestamp or labevents[-1].timestamp < end_timestamp:
+                # skip this visit if no lab event falls within the first 48 hrs or the visit is shorter than 48 hrs (2 days)
                 continue
 
             Xs = [[] for _ in range(len(self.selected_labitem_ids))]
@@ -194,18 +194,15 @@ def __call__(self, patient):
                 l = self.selected_labitem_ids[code]
                 x_ts, mask_ts = self.discretizer.transform(Xs[l], code, timespan=48)  # TODO: add normalizer later
                 discretized_X.append(x_ts)
-                discretized_mask.append(mask_ts)
-            discretized_X = np.array(discretized_X)
-            discretized_mask = np.array(discretized_mask)
+                discretized_mask.append(mask_ts)  # not used so far
 
             # TODO: should also exclude visits with age < 18
             samples.append(
                 {
-                    "visit_id": visit.visit_id,
                     "patient_id": patient.patient_id,
-                    "X": discretized_X,
-                    "mask": discretized_mask,
-                    "label": mortality_label,
+                    "visit_id": visit.visit_id,
+                    "discretized_feature": discretized_X,
+                    "mortality": mortality_label,
                 }
             )
         # no cohort selection
         return samples
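
The new eligibility check above keeps a visit only if its lab events bracket the full 48-hour window after admission. The same condition in isolation (a sketch with hypothetical timestamps):

    from datetime import datetime, timedelta

    encounter_time = datetime(2024, 1, 1, 0, 0)  # hypothetical admission time
    end_timestamp = encounter_time + timedelta(days=2)
    first_lab = datetime(2024, 1, 1, 2, 0)       # earliest lab event
    last_lab = datetime(2024, 1, 4, 12, 0)       # latest lab event

    # keep the visit only if the earliest lab event is inside the window
    # and the lab series extends to the end of the window
    keep = first_lab <= end_timestamp and last_lab >= end_timestamp
    print(keep)  # True
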
From ac334ef2124b581eded469a6fc20ccca1e8d21de Mon Sep 17 00:00:00 2001
From: Hang Yu
Date: Thu, 14 Nov 2024 06:18:09 +0000
Subject: [PATCH 5/6] working on utde

---
 examples/MISTS_by_Hang.py |   5 +-
 pyhealth/models/utde.py   | 277 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 280 insertions(+), 2 deletions(-)
 create mode 100644 pyhealth/models/utde.py

diff --git a/examples/MISTS_by_Hang.py b/examples/MISTS_by_Hang.py
index 212503b4..0ede64a5 100644
--- a/examples/MISTS_by_Hang.py
+++ b/examples/MISTS_by_Hang.py
@@ -1,5 +1,6 @@
-# import sys
-# sys.path.append('./PyHealth')
+import sys
+sys.path.append('./PyHealth')
+import pyhealth
 from pyhealth.datasets import MIMIC3Dataset
 
 mimic3_ds = MIMIC3Dataset(
diff --git a/pyhealth/models/utde.py b/pyhealth/models/utde.py
new file mode 100644
index 00000000..d68eb97e
--- /dev/null
+++ b/pyhealth/models/utde.py
@@ -0,0 +1,277 @@
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.utils.rnn as rnn_utils
+
+from pyhealth.datasets import SampleEHRDataset
+from pyhealth.models import BaseModel
+
+# VALID_OPERATION_LEVEL = ["visit", "event"]
+
+
+class multiTimeAttention(nn.Module):
+    """Multi-time attention (mTAND) module.
+
+    Borrowed from https://github.com/XZhang97666/MultimodalMIMIC/blob/main/module.py.
+    """
+
+    def __init__(self, input_dim, nhidden=16, embed_time=16, num_heads=1):
+        super(multiTimeAttention, self).__init__()
+        assert embed_time % num_heads == 0
+        self.embed_time = embed_time
+        self.embed_time_k = embed_time // num_heads
+        self.h = num_heads
+        self.dim = input_dim
+        self.nhidden = nhidden
+        self.linears = nn.ModuleList([nn.Linear(embed_time, embed_time),
+                                      nn.Linear(embed_time, embed_time),
+                                      nn.Linear(input_dim * num_heads, nhidden)])
+
+    def attention(self, query, key, value, mask=None, dropout=None):
+        """Compute scaled dot-product attention."""
+        dim = value.size(-1)
+        d_k = query.size(-1)
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
+        scores = scores.unsqueeze(-1).repeat_interleave(dim, dim=-1)
+        if mask is not None:
+            if len(mask.shape) == 3:
+                mask = mask.unsqueeze(-1)
+            scores = scores.masked_fill(mask.unsqueeze(-3) == 0, -10000)
+        p_attn = F.softmax(scores, dim=-2)
+        if dropout is not None:
+            p_attn = F.dropout(p_attn, p=dropout, training=self.training)
+        return torch.sum(p_attn * value.unsqueeze(-3), -2), p_attn
+
+    def forward(self, query, key, value, mask=None, dropout=0.1):
+        """Apply multi-time attention over the value sequence."""
+        batch, seq_len, dim = value.size()
+        if mask is not None:
+            # Same mask applied to all h heads.
+            mask = mask.unsqueeze(1)
+        value = value.unsqueeze(1)
+        query, key = [l(x).view(x.size(0), -1, self.h, self.embed_time_k).transpose(1, 2)
+                      for l, x in zip(self.linears, (query, key))]
+        x, _ = self.attention(query, key, value, mask, dropout)
+        x = x.transpose(1, 2).contiguous().view(batch, -1, self.h * dim)
+        return self.linears[-1](x)
+
+
+class UTDE(nn.Module):
+    """Unified Time Discretization-based Embedding (UTDE) module.
+
+    Produces two embeddings for irregularly sampled lab time series
+    (https://arxiv.org/pdf/2210.12156): an imputation-based embedding computed
+    from the discretized features, and a multi-time attention (mTAND)
+    embedding computed from the raw values and timestamps.
+
+    Args:
+        input_size: input feature size.
+        hidden_size: hidden feature size.
+        embed_time: time embedding size.
+        rnn_type: type of rnn, one of "RNN", "LSTM", "GRU". Default is "GRU".
+        num_layers: number of recurrent layers. Default is 1.
+        dropout: dropout rate. Default is 0.5.
+        bidirectional: whether to use bidirectional recurrent layers.
+            Default is False.
+        output_dim: output size of the final projection layer. Default is 1.
+        alpha: reference time points used as mTAND queries. Default is None.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        embed_time: int,
+        rnn_type: str = "GRU",
+        num_layers: int = 1,
+        dropout: float = 0.5,
+        bidirectional: bool = False,
+        output_dim: int = 1,
+        alpha: torch.tensor = None,
+    ):
+        super(UTDE, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.embed_time = embed_time
+        self.dropout = dropout
+
+        self.periodic = nn.Linear(1, embed_time - 1)
+        self.linear = nn.Linear(1, 1)
+        self.time_attn_ts = multiTimeAttention(embed_time * 2, embed_time, embed_time, 8)
+        self.proj1 = nn.Linear(embed_time, embed_time)
+        self.proj2 = nn.Linear(embed_time, embed_time)
+        self.out_layer = nn.Linear(embed_time, output_dim)
+
+    def learn_time_embedding(
+        self,
+        T: torch.tensor,
+    )
+        """Function to compute the time embedding for each feature
+
+        Args:
+            T: a tensor of shape [batch size, sequence len] from the discretization step
+
+        Returns:
+            time_embedding: a tensor of shape [batch size, sequence len, time embedding size],
+                the time embedding of each timestamp of the lab events for mTAND to use
+        """
+        T = T.unsqueeze(-1)
+        out1 = self.linear(T)
+        out2 = torch.sin(self.periodic(T))
+        time_embedding = torch.cat([out1, out2], dim=-1)
+        return time_embedding
+
+    def imputation(
+        self,
+        discretized_feature: torch.tensor,
+    ) -> torch.tensor:
+        """Compute the imputation embedding.
+
+        Args:
+            discretized_feature: a tensor of shape [batch size, input channels,
+                input size] from the discretization step
+
+        Returns:
+            imputation_embedding: a tensor of shape [batch size, output channels, hidden size]
+        """
+        if len(discretized_feature.shape) == 2:
+            discretized_feature = discretized_feature.unsqueeze(1)
+        return self.imputation_conv1d(discretized_feature)
+
+    def mTAND(
+        self,
+        X: torch.tensor,
+        X_mask: torch.tensor,
+        T: torch.tensor,
+        alpha: torch.tensor,
+    ) -> torch.tensor:
+        """Compute the discretized multi-time attention embedding.
+
+        Args:
+            X: a tensor of shape [batch size, sequence length],
+                original values for different lab events at different timestamps
+            X_mask: a tensor of shape [batch size, sequence length],
+                observation mask for X
+            T: a tensor of shape [batch size, sequence length],
+                timestamps of the different lab events
+            alpha: reference time points used as mTAND queries
+
+        Returns:
+            mtand_embedding: a tensor of shape [batch size, output channels, hidden size]
+        """
+        keys = self.learn_time_embedding(T)
+        query = self.learn_time_embedding(alpha)
+
+        X_irg = torch.cat((X, X_mask), 2)
+        X_mask = torch.cat((X_mask, X_mask), 2)
+
+        proj_X = self.time_attn_ts(query, keys, X_irg, X_mask).transpose(0, 1)
+        last_hs = proj_X  # WIP: a sequence model over proj_X is still to be added
+
+        last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.dropout, training=self.training))
+        last_hs_proj += last_hs
+        output = self.out_layer(last_hs_proj)
+        return output
+
+    def forward(
+        self,
+        X: torch.tensor,
+        X_mask: torch.tensor,
+        T: torch.tensor,
+        alpha: torch.tensor,
+        discretized_feature: torch.tensor,
+    ) -> Tuple[torch.tensor, torch.tensor]:
+        """Forward propagation.
+
+        Args:
+            X: a tensor of shape [batch size, sequence length],
+                original values for different lab events at different timestamps
+            X_mask: a tensor of shape [batch size, sequence length],
+                observation mask for X
+            T: a tensor of shape [batch size, sequence length],
+                timestamps of the different lab events
+            alpha: reference time points used as mTAND queries
+            discretized_feature: a tensor of shape [batch size, input channels,
+                input size] from the discretization step
+
+        Returns:
+            imputation_embedding: a tensor of shape [batch size, output channels, hidden size]
+            mtand_embedding: a tensor of shape [batch size, output channels, hidden size]
+        """
+        imputation_embedding = self.imputation(discretized_feature)
+        mtand_embedding = self.mTAND(X, X_mask, T, alpha)
+        return imputation_embedding, mtand_embedding
+
+
+if __name__ == "__main__":
+    # from pyhealth.datasets import SampleEHRDataset
+    #
+    # samples = [
+    #     {
+    #         "patient_id": "patient-0",
+    #         "visit_id": "visit-0",
+    #         # "single_vector": [1, 2, 3],
+    #         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
+    #         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
+    #         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
+    #         "list_list_vectors": [
+    #             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
+    #             [[7.7, 8.5, 9.4]],
+    #         ],
+    #         "label": 1,
+    #     },
+    #     {
+    #         "patient_id": "patient-0",
+    #         "visit_id": "visit-1",
+    #         # "single_vector": [1, 5, 8],
+    #         "list_codes": [
+    #             "55154191800",
+    #             "551541928",
+    #             "55154192800",
+    #             "705182798",
+    #             "70518279800",
+    #         ],
+    #         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
+    #         "list_list_codes": [["A04A", "B035", "C129"]],
+    #         "list_list_vectors": [
+    #             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
+    #         ],
+    #         "label": 0,
+    #     },
+    # ]
+    #
+    # # dataset
+    # dataset = SampleEHRDataset(samples=samples, dataset_name="test")
+
+    from pyhealth.datasets import MIMIC4Dataset
+    from pyhealth.tasks import Mortality30DaysMIMIC4
+
+    dataset = MIMIC4Dataset(
+        root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
+        tables=["procedures_icd"],
+        dev=True,
+    )
+    task = Mortality30DaysMIMIC4()
+    samples = dataset.set_task(task)
+
+    # data loader
+    from pyhealth.datasets import get_dataloader
+
+    train_loader = get_dataloader(samples, batch_size=2, shuffle=True)
+
+    # model
+    from pyhealth.models import RNN
+
+    model = RNN(
+        dataset=samples,
+        feature_keys=[
+            "procedures",
+        ],
+        label_key="mortality",
+        mode="binary",
+    )
+
+    # data batch
+    data_batch = next(iter(train_loader))
+
+    # try the model
+    ret = model(**data_batch)
+    print(ret)
+
+    # try loss backward
+    ret["loss"].backward()
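
learn_time_embedding above follows the mTAND time representation: one linear component plus sinusoidal components over learned frequencies. A standalone sketch of the same computation (assumes embed_time=16; the layers here are stand-ins for self.linear and self.periodic):

    import torch
    import torch.nn as nn

    embed_time = 16
    linear = nn.Linear(1, 1)                 # linear ("trend") term
    periodic = nn.Linear(1, embed_time - 1)  # learned frequencies for the sin terms

    T = torch.rand(2, 60)                    # [batch, seq_len] hypothetical timestamps
    T = T.unsqueeze(-1)                      # [batch, seq_len, 1]
    emb = torch.cat([linear(T), torch.sin(periodic(T))], dim=-1)
    print(emb.shape)                         # torch.Size([2, 60, 16])
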
From 88e23f7308f859b6b9dc41e7d47236fee27b8e4b Mon Sep 17 00:00:00 2001
From: Hang Yu
Date: Fri, 22 Nov 2024 00:02:21 +0000
Subject: [PATCH 6/6] add cnn 1d to utde; add original x_ts, t_ts to the
 dataset dictionary

---
 pyhealth/datasets/sample_dataset.py       |  2 +-
 pyhealth/datasets/tuab.py                 |  2 +-
 pyhealth/datasets/tuev.py                 |  3 +--
 pyhealth/models/__init__.py               |  1 +
 pyhealth/models/utde.py                   |  5 ++--
 pyhealth/tasks/mortality_prediction_v2.py | 32 ++++++++++++++---------
 6 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py
index e7c388ab..e269b8ed 100644
--- a/pyhealth/datasets/sample_dataset.py
+++ b/pyhealth/datasets/sample_dataset.py
@@ -379,7 +379,7 @@ def _validate(self) -> Dict:
             # a list of vectors or a list of list of codes
             if type_ in [float, int]:
                 lens = set([len(i) for s in self.samples for i in s[key]])
-                assert len(lens) == 1, f"Key {key} has vectors of different lengths"
+                # assert len(lens) == 1, f"Key {key} has vectors of different lengths"  # TODO: restore later; disabled for now to skip the dataset length check
                 input_info[key] = {"type": type_, "dim": 2, "len": lens.pop()}
             else:
                 # a list of list of codes
diff --git a/pyhealth/datasets/tuab.py b/pyhealth/datasets/tuab.py
index 33ea2807..8a7b874e 100644
--- a/pyhealth/datasets/tuab.py
+++ b/pyhealth/datasets/tuab.py
@@ -84,7 +84,7 @@ def process_EEG_data(self):
         patient_ids = list(set(patient_ids))
 
         if self.dev:
-            patient_ids = patient_ids[:20]
+            patient_ids = patient_ids[:5]
 
         # get patient to record maps
         # - key: pid:
diff --git a/pyhealth/datasets/tuev.py b/pyhealth/datasets/tuev.py
index 77649aca..4009a019 100644
--- a/pyhealth/datasets/tuev.py
+++ b/pyhealth/datasets/tuev.py
@@ -66,8 +66,7 @@ def process_EEG_data(self):
         patient_ids = list(set(list(all_files.keys())))
 
         if self.dev:
-            patient_ids = patient_ids[:20]
-            # print(patient_ids)
+            patient_ids = patient_ids[:30]
 
         # get patient to record maps
         # - key: pid:
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 780db43a..4e94abeb 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -24,3 +24,4 @@
 from .gan import GAN
 from .torchvision_model import TorchvisionModel
 from .transformers_model import TransformersModel
+from .utde import UTDE
\ No newline at end of file
diff --git a/pyhealth/models/utde.py b/pyhealth/models/utde.py
index d68eb97e..e26a9b1b 100644
--- a/pyhealth/models/utde.py
+++ b/pyhealth/models/utde.py
@@ -5,7 +5,7 @@
 import torch.nn.utils.rnn as rnn_utils
 
 from pyhealth.datasets import SampleEHRDataset
-from pyhealth.models import BaseModel
+from pyhealth.models import BaseModel, CNNLayer
 
 # VALID_OPERATION_LEVEL = ["visit", "event"]
 
@@ -112,11 +112,12 @@ def __init__(
         self.proj1 = nn.Linear(embed_time, embed_time)
         self.proj2 = nn.Linear(embed_time, embed_time)
         self.out_layer = nn.Linear(embed_time, output_dim)
+        self.imputation_conv1d = CNNLayer(5, 64)  # TODO: match the dimension with the actual imputation
 
     def learn_time_embedding(
         self,
         T: torch.tensor,
-    )
+    ):
         """Function to compute the time embedding for each feature
 
         Args:
diff --git a/pyhealth/tasks/mortality_prediction_v2.py b/pyhealth/tasks/mortality_prediction_v2.py
index 61055f8c..af9e5886 100644
--- a/pyhealth/tasks/mortality_prediction_v2.py
+++ b/pyhealth/tasks/mortality_prediction_v2.py
@@ -67,18 +67,19 @@ def __init__(self, selected_channel_ids, normal_values, timestep=0.8, impute_str
         self._empty_bins_sum = 0
         self._unused_data_sum = 0
 
-    def transform(self, X, channel, timespan=None):
+    def transform(self, X, T, channel, timespan=None):
         '''
         Args:
-            X: list of [timestamp, valuenum] pairs
+            X: list of valuenum readings
+            T: list of timestamps (hours since admission)
             channel: the code of the lab item
             timespan: the timespan (in hours) of the data we use
         '''
         eps = 1e-6
 
-        t_ts, x_ts = zip(*X)
+        t_ts, x_ts = T, X
         for i in range(len(t_ts) - 1):
-            assert t_ts[i] < t_ts[i+1] + timedelta(hours=eps)
+            assert t_ts[i] < t_ts[i+1] + eps
 
         if self._start_time == 'relative':
             first_time = t_ts[0]
@@ -88,7 +89,7 @@ def transform(self, X, channel, timespan=None):
             raise ValueError("start_time is invalid")
 
         if timespan is None:
-            max_hours = (max(t_ts) - first_time).total_seconds() / 3600
+            max_hours = max(t_ts)
         else:
             max_hours = timespan
 
@@ -100,8 +101,8 @@ def transform(self, X, channel, timespan=None):
         total_data = 0
         unused_data = 0
 
-        for row in X:
-            t = (row[0] - first_time).total_seconds() / 3600
+        for i in range(len(t_ts)):
+            t = t_ts[i]
             if t > max_hours + eps:
                 continue
             bin_id = int(t / self._timestep - eps)
@@ -111,8 +112,8 @@ def transform(self, X, channel, timespan=None):
             if mask[bin_id] == 1:
                 unused_data += 1
             mask[bin_id] = 1
-            data[bin_id] = row[1]
-            original_value[bin_id] = row[1]
+            data[bin_id] = x_ts[i]
+            original_value[bin_id] = x_ts[i]
 
         if self._impute_strategy not in ['zero', 'normal_value', 'previous', 'next']:
             raise ValueError("impute strategy is invalid")
@@ -181,18 +182,23 @@ def __call__(self, patient):
                 # skip this visit if no lab event falls within the first 48 hrs or the visit is shorter than 48 hrs (2 days)
                 continue
 
-            Xs = [[] for _ in range(len(self.selected_labitem_ids))]
+            X_ts = [[] for _ in range(len(self.selected_labitem_ids))]
+            T_ts = [[] for _ in range(len(self.selected_labitem_ids))]
             for event in labevents:
+                if event.timestamp < visit.encounter_time:
+                    # TODO: discuss with Zhenbang whether this is desired; skip lab events before the hospital admission
+                    continue
                 if event.timestamp > visit.encounter_time + timedelta(days=2):
                     break
                 if event.code in self.selected_labitem_ids:
                     l = self.selected_labitem_ids[event.code]
-                    Xs[l].append([event.timestamp, event.attr_dict['valuenum']])
+                    X_ts[l].append(event.attr_dict['valuenum'])
+                    T_ts[l].append((event.timestamp - visit.encounter_time).total_seconds() / 3600)
 
             discretized_X, discretized_mask = [], []
             for code in self.selected_labitem_ids:
                 l = self.selected_labitem_ids[code]
-                x_ts, mask_ts = self.discretizer.transform(Xs[l], code, timespan=48)  # TODO: add normalizer later
+                x_ts, mask_ts = self.discretizer.transform(X_ts[l], T_ts[l], code, timespan=48)  # TODO: add normalizer later
                 discretized_X.append(x_ts)
                 discretized_mask.append(mask_ts)  # not used so far
 
@@ -202,6 +208,8 @@ def __call__(self, patient):
                     "patient_id": patient.patient_id,
                     "visit_id": visit.visit_id,
                     "discretized_feature": discretized_X,
+                    "x_ts": X_ts,
+                    "t_ts": T_ts,
                     "mortality": mortality_label,
                 }
             )
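
After this patch the Discretizer consumes parallel value and hour lists instead of [timestamp, valuenum] pairs, and returns plain lists. A minimal sketch of the new call (hypothetical values; hours are measured from admission, as in __call__ above):

    selected = {'51221': 0}
    normals = {'51221': 42.4}
    disc = Discretizer(selected, normals, timestep=0.8, impute_strategy='previous')

    x_ts = [41.0, 43.2]  # valuenum readings for one channel
    t_ts = [0.5, 4.5]    # hours since admission
    data, mask = disc.transform(x_ts, t_ts, channel='51221', timespan=48)
    # data and mask are Python lists of length 60 (48 hours at 0.8 hours per bin)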