Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CheXpert Dataset and Its Associated Task #199

Draft
wants to merge 2 commits into
base: zhenbang/f-image_text_support
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pyhealth/datasets/base_dataset_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import Optional, Dict

from tqdm import tqdm

from pyhealth.datasets.sample_dataset_v2 import SampleDataset
from pyhealth.tasks.task_template import TaskTemplate
import sys
sys.path.append('.')
from sample_dataset_v2 import SampleDataset# from pyhealth.datasets.sample_dataset_v2 import SampleDataset
sys.path.append('..')
from tasks.task_template import TaskTemplate #from pyhealth.tasks.task_template import TaskTemplate

logger = logging.getLogger(__name__)

Expand Down
106 changes: 106 additions & 0 deletions pyhealth/datasets/chexpert_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os
from collections import Counter
import pandas as pd
from tqdm import tqdm

from base_dataset_v2 import BaseDataset# from pyhealth.datasets.base_dataset_v2 import BaseDataset
from tasks.chexpert_v1_classification import CheXpertV1Classification

class CheXpertV1Dataset(BaseDataset):
"""Base image dataset for CheXpert Database

Dataset is available at https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2

**CheXpert v1 data:
-----------------------
- Train: 223414 images from 64540 patients
- Validation: 902 images from 700 patients

The CheXpert dataset consists of 14 labeled observations (pathology):
- No Finding, Enlarged Cardiomediastinum, Cardiomegaly, Lung Opacity, Lung Lesion, Edema, Consolidation, Pneumonia,
Atelectasis, Pneumothorax, Pleural Effusion, Pleural Other, Fracture, Support Devices
For each observation (pathology), there are 4 status:
- positive (1), negative (0), uncertain (-1), unmentioned (2)

Please cite the follwoing articles if you are using this dataset:
- Irvin, J., Rajpurkar, P., Ko, M., Yu, Y., Ciurea-Ilcus, S., Chute, C., Marklund, H., Haghgoo, B., Ball, R.,
Shpanskaya, K. and Seekins, J., 2019, July. Chexpert: A large chest radiograph dataset with uncertainty labels
and expert comparison. In Proceedings of the AAAI conference on artificial intelligence (Vol. 33, No. 01, pp. 590-597).

Args:
dataset_name: name of the dataset.
root: root directory of the raw data (The parent directory of /CheXpert-v1.0). *You can choose to use the path to Cassette portion or the Telemetry portion.*
dev: whether to enable dev mode (only use a small subset of the data).
Default is False.
refresh_cache: whether to refresh the cache; if true, the dataset will
be processed from scratch and the cache will be updated. Default is False.

Attributes:
root: root directory of the raw data (should contain many csv files).
dataset_name: name of the dataset. Default is the name of the class.
dev: whether to enable dev mode (only use a small subset of the data).
Default is False.
refresh_cache: whether to refresh the cache; if true, the dataset will
be processed from scratch and the cache will be updated. Default is False.

Examples:

>>> dataset = CheXpertV1Dataset(
root="/home/wuzijian1231/Datasets",
)
>>> print(dataset.patients[0])
>>> dataset.stat()
>>> samples = dataset.set_task()
>>> print(samples[0])
"""

def process(self):
# process and merge raw xlsx files from the dataset
df = pd.DataFrame(
pd.read_csv(f"{self.root}/CheXpert-v1.0/train.csv")
)
df.fillna(value=2.0, inplace=True) # positive (1), negative (0), uncertain (-1), unmentioned (2)
df["Path"] = df["Path"].apply(
lambda x: f"{self.root}/{x}"
)
df = df.drop(columns=["Sex", "Age", "Frontal/Lateral", "AP/PA"])
self.pathology = [c for c in df]
del self.pathology[0]
df_list= []
for p in self.pathology:
df_list.append(df[p])
self.df_label = pd.concat(df_list, axis=1)
labels = self.df_label.values.tolist()
df.columns = [col for col in df]
for path in tqdm(df.Path):
assert os.path.isfile(path)
# create patient dict
patients = {}
for index, row in tqdm(df.iterrows()):
patients[index] = {'path':row['Path'], 'label':labels[index]}
return patients

def stat(self):
super().stat()
print(f"Number of samples: {len(self.patients)}")
print(f"Number of Pathology: {len(self.pathology)}")
count = {}
for p in self.pathology:
cn = self.df_label[p]
count[p] = Counter(cn)
for p in self.pathology:
print(f"Class distribution - {p}: {count[p]}")

@property
def default_task(self):
return CheXpertV1Classification()

if __name__ == "__main__":
dataset = CheXpertV1Dataset(
root="/home/wuzijian1231/Datasets",
)
print(dataset.patients[0])
dataset.stat()
samples = dataset.set_task()
print(samples[0])

10 changes: 6 additions & 4 deletions pyhealth/datasets/covid19_cxr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from collections import Counter

import pandas as pd
import sys
sys.path.append('.')

from pyhealth.datasets.base_dataset_v2 import BaseDataset
from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification

from base_dataset_v2 import BaseDataset
from tasks.covid19_cxr_classification import COVID19CXRClassification# from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification

class COVID19CXRDataset(BaseDataset):
"""Base image dataset for COVID-19 Radiography Database
Expand Down Expand Up @@ -120,6 +121,7 @@ def process(self):
patients = {}
for index, row in df.iterrows():
patients[index] = row.to_dict()
breakpoint()
return patients

def stat(self):
Expand All @@ -136,7 +138,7 @@ def default_task(self):

if __name__ == "__main__":
dataset = COVID19CXRDataset(
root="/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset",
root="/home/wuzijian1231/Datasets/COVID-19_Radiography_Dataset"#"/srv/local/data/zw12/raw_data/covid19-radiography-database/COVID-19_Radiography_Dataset",
)
print(list(dataset.patients.items())[0])
dataset.stat()
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/datasets/sample_dataset_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.utils.data import Dataset

from pyhealth.datasets.featurizers import ImageFeaturizer, ValueFeaturizer
from featurizers import ImageFeaturizer, ValueFeaturizer# from pyhealth.datasets.featurizers import ImageFeaturizer, ValueFeaturizer


class SampleDataset(Dataset):
Expand Down
20 changes: 20 additions & 0 deletions pyhealth/tasks/chexpert_v1_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass, field
from typing import Dict

from tasks.task_template import TaskTemplate # from pyhealth.tasks import TaskTemplate


@dataclass(frozen=True)
class CheXpertV1Classification(TaskTemplate):
task_name: str = "CheXpertV1Classification"
input_schema: Dict[str, str] = field(default_factory=lambda: {"path": "image"})
output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"})

def __call__(self, patient):
return [patient]


if __name__ == "__main__":
task = CheXpertV1Classification()
print(task)
print(type(task))
1 change: 1 addition & 0 deletions pyhealth/tasks/covid19_cxr_classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Dict


from pyhealth.tasks.task_template import TaskTemplate


Expand Down