Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Stable Diffusion and MIMIC-CXR dataset for inference. #223

Open
wants to merge 3 commits into
base: zhenbang/f-image_text_support
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .medical_transriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4Dataset
from .mimiccxr_text import MIMICCXRDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
from .sample_dataset import SampleBaseDataset, SampleSignalDataset, SampleEHRDataset
Expand Down
64 changes: 64 additions & 0 deletions pyhealth/datasets/mimiccxr_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os
from collections import Counter

import pandas as pd

from pyhealth.datasets.base_dataset_v2 import BaseDataset
from pyhealth.tasks.chest_xray_generation import ChestXrayGeneration


class MIMICCXRDataset(BaseDataset):
"""MIMIC-CXR data

Args:
dataset_name: name of the dataset.
root: root directory of the raw data. *You can choose to use the path to Cassette portion or the Telemetry portion.*
dev: whether to enable dev mode (only use a small subset of the data).
Default is False.
refresh_cache: whether to refresh the cache; if true, the dataset will
be processed from scratch and the cache will be updated. Default is False.

Attributes:
root: root directory of the raw data (should contain many csv files).
dataset_name: name of the dataset. Default is the name of the class.
dev: whether to enable dev mode (only use a small subset of the data).
Default is False.
refresh_cache: whether to refresh the cache; if true, the dataset will
be processed from scratch and the cache will be updated. Default is False.

Examples:
>>> dataset = MIMICCXRDataset(
root="/home/xucao2/xucao/PIEMedApp/checkpoints/mimic_cxr",
)
>>> print(dataset[0])
>>> dataset.stat()
>>> dataset.info()
"""

def process(self):
df = pd.read_csv(f"{self.root}/mimiccxr_text.csv", index_col=0)

# create patient dict
patients = {}
for index, row in df.iterrows():
patients[index] = row.to_dict()
return patients

def stat(self):
super().stat()
print(f"Number of samples: {len(self.patients)}")

@property
def default_task(self):
return ChestXrayGeneration()


if __name__ == "__main__":
dataset = MIMICCXRDataset(
root="/home/xucao2/xucao/PIEMedApp/checkpoints/mimic_cxr",
)
print(list(dataset.patients.items())[0])
dataset.stat()
samples = dataset.set_task()
print(samples[0])

1 change: 1 addition & 0 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .molerec import MoleRec, MoleRecLayer
from .torchvision_model import TorchvisionModel
from .transformers_model import TransformersModel
from .diffusion import DiffusionModel
106 changes: 106 additions & 0 deletions pyhealth/models/diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import List, Dict

import torch
import torch.nn as nn
import torchvision
from diffusers import StableDiffusionPipeline

from pyhealth.datasets.sample_dataset_v2 import SampleDataset
from pyhealth.models import BaseModel


SUPPORTED_MODELS = [
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"IrohXu/stable-diffusion-mimic-cxr-v0.1"
]


class DiffusionModel(BaseModel):
"""Models from PyTorch's huggingface package.

This class is a wrapper for diffusion models from huggingface.

------------------------------Stable Diffusion-------------------------------------
Paper: Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B.
High-resolution image synthesis with latent diffusion models. CVPR 2022
-----------------------------------------------------------------------------------

Args:
dataset: the dataset to train the model. It is used to query certain
information such as the set of all tokens.
feature_keys: list of keys in samples to use as features, e.g., ["image"].
Only one feature is supported.
label_key: key in samples to use as label, e.g., "drugs".
mode: one of "binary", "multiclass", or "multilabel".
model_name: str, name of the model to use, e.g., "IrohXu/stable-diffusion-mimic-cxr-v0.1".
See SUPPORTED_MODELS in the source code for the full list.
model_config: dict, kwargs to pass to the model constructor,
e.g., {"weights": "DEFAULT"}. See the torchvision documentation for the
set of supported kwargs for each model.
-----------------------------------------------------------------------------------
"""

def __init__(
self,
dataset: SampleDataset,
feature_keys: List[str],
label_key: str,
mode: str,
model_name: str,
model_config: dict,
):
super(DiffusionModel, self).__init__(
dataset=dataset,
feature_keys=feature_keys,
label_key=label_key,
mode=mode
)

self.model_name = model_name
self.model_config = model_config

self.model = StableDiffusionPipeline.from_pretrained(model_name, **model_config)

def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""Forward propagation."""
# concat the info within one batch (batch, channel, length)
x = kwargs[self.feature_keys[0]]
out = self.model(x)
return out


if __name__ == "__main__":
from pyhealth.datasets.utils import get_dataloader
from pyhealth.datasets import MIMICCXRDataset

base_dataset = MIMICCXRDataset(
root="/home/xucao2/xucao/PIEMedApp/checkpoints/mimic_cxr",
)
sample_dataset = base_dataset.set_task()

def encode(sample):
return sample

sample_dataset.set_transform(encode)

train_loader = get_dataloader(sample_dataset, batch_size=1, shuffle=True)

device = "cuda"
model = DiffusionModel(
dataset=sample_dataset,
feature_keys=["text"],
label_key="text",
mode="multiclass",
model_name="IrohXu/stable-diffusion-mimic-cxr-v0.1",
model_config={"torch_dtype": torch.float16, "safety_checker": None},
)
model.model = model.model.to(device)

# data batch
data_batch = next(iter(train_loader))

# try the model
image = model(**data_batch).images[0]
image.save("result.png")

1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@
)
from .covid19_cxr_classification import COVID19CXRClassification
from .medical_transcriptions_classification import MedicalTranscriptionsClassification
from .chest_xray_generation import ChestXrayGeneration
24 changes: 24 additions & 0 deletions pyhealth/tasks/chest_xray_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dataclasses import dataclass, field
from typing import Dict
import pandas as pd

from pyhealth.tasks.task_template import TaskTemplate


@dataclass(frozen=True)
class ChestXrayGeneration(TaskTemplate):
task_name: str = "ChestXrayGeneration"
input_schema: Dict[str, str] = field(default_factory=lambda: {"text": "text"})
output_schema: Dict[str, str] = field(default_factory=lambda: {"text": "text"})

def __call__(self, patient):
sample = {
"text": patient["report"],
}
return [sample]


if __name__ == "__main__":
task = ChestXrayGeneration()
print(task)
print(type(task))
8 changes: 7 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torch>=1.8.0
torchvision>=0.14.0
rdkit>=2022.03.4
scikit-learn>=0.24.2
networkx>=2.6.3
Expand All @@ -7,4 +8,9 @@ pandarallel>=1.5.3
mne>=1.0.3
urllib3<=1.26.15
numpy
tqdm
tqdm
transformers==4.26.1
tokenizers==0.13.2
huggingface-hub==0.13.2
diffusers==0.15.0
clip==0.2.0