diff --git a/survlimepy/load_datasets.py b/survlimepy/load_datasets.py index 6688e6e..65e13e2 100644 --- a/survlimepy/load_datasets.py +++ b/survlimepy/load_datasets.py @@ -37,29 +37,6 @@ def __init__(self, dataset_name: str = "veterans"): self.feature_columns = ["bili", "stage", "riskscore", "trt"] self.categorical_columns = [] self.df = pd.read_csv(udca_path) - elif dataset_name == "pbc": - self.feature_columns = [ - "age", - "bili", - "chol", - "albumin", - "ast", - "ascites", - "copper", - "alk.phos", - "trig", - "platelet", - "protime", - "trt", - "sex", - "hepato", - "spiders", - "edema", - "stage", - ] - self.categorical_columns = ["edema", "stage"] - self.df = pd.read_csv(pbc_path) - self.df["sex"] = [1 if x == "f" else 0 for x in self.df["sex"]] elif dataset_name == "lung": self.feature_columns = [ "inst", @@ -80,28 +57,9 @@ def __init__(self, dataset_name: str = "veterans"): # substract 1 to each value of the status column self.df["status"] = [x - 1 for x in self.df["status"]] - elif dataset_name == "synthetic": - ## TODO - pass - elif dataset_name == "heart": - self.feature_columns = [ - "age", - "anaemia", - "creatinine_phosphokinase", - "diabetes", - "ejection_fraction", - "high_blood_pressure", - "platelets", - "serum_creatinine", - "serum_sodium", - "sex", - "smoking", - ] - self.categorical_columns = [] - self.df = pd.read_csv(heart_path) else: raise AssertionError( - f"The give name {dataset_name} was not found in [veterans, udca, pbc, lung]." + f"The give name {dataset_name} was not found in [veterans, udca, lung]." ) def load_data(self) -> list([pd.DataFrame, np.ndarray]): @@ -121,8 +79,6 @@ def load_data(self) -> list([pd.DataFrame, np.ndarray]): times = [x[1] for x in y] x = self.df[self.feature_columns] - x = x.fillna(value=x.median(numeric_only=True)) - return x, events, times def preprocess_datasets(