Skip to content

Commit

Permalink
Merge pull request #19 from imatge-upc/delete_nans
Browse files Browse the repository at this point in the history
Delete line that imputed data when loading
  • Loading branch information
CarlosHernandezP authored Feb 15, 2023
2 parents f2a8ad6 + 2072a69 commit 1076aed
Showing 1 changed file with 1 addition and 45 deletions.
46 changes: 1 addition & 45 deletions survlimepy/load_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,6 @@ def __init__(self, dataset_name: str = "veterans"):
self.feature_columns = ["bili", "stage", "riskscore", "trt"]
self.categorical_columns = []
self.df = pd.read_csv(udca_path)
elif dataset_name == "pbc":
self.feature_columns = [
"age",
"bili",
"chol",
"albumin",
"ast",
"ascites",
"copper",
"alk.phos",
"trig",
"platelet",
"protime",
"trt",
"sex",
"hepato",
"spiders",
"edema",
"stage",
]
self.categorical_columns = ["edema", "stage"]
self.df = pd.read_csv(pbc_path)
self.df["sex"] = [1 if x == "f" else 0 for x in self.df["sex"]]
elif dataset_name == "lung":
self.feature_columns = [
"inst",
Expand All @@ -80,28 +57,9 @@ def __init__(self, dataset_name: str = "veterans"):
# substract 1 to each value of the status column
self.df["status"] = [x - 1 for x in self.df["status"]]

elif dataset_name == "synthetic":
## TODO
pass
elif dataset_name == "heart":
self.feature_columns = [
"age",
"anaemia",
"creatinine_phosphokinase",
"diabetes",
"ejection_fraction",
"high_blood_pressure",
"platelets",
"serum_creatinine",
"serum_sodium",
"sex",
"smoking",
]
self.categorical_columns = []
self.df = pd.read_csv(heart_path)
else:
raise AssertionError(
f"The give name {dataset_name} was not found in [veterans, udca, pbc, lung]."
f"The give name {dataset_name} was not found in [veterans, udca, lung]."
)

def load_data(self) -> list([pd.DataFrame, np.ndarray]):
Expand All @@ -121,8 +79,6 @@ def load_data(self) -> list([pd.DataFrame, np.ndarray]):
times = [x[1] for x in y]
x = self.df[self.feature_columns]

x = x.fillna(value=x.median(numeric_only=True))

return x, events, times

def preprocess_datasets(
Expand Down

0 comments on commit 1076aed

Please sign in to comment.