Stable Diffusion Reference Implementation (mlcommons#1519)
* Stable Diffusion Reference Implementation

* Add calibration images

* Remove refiner from pipeline

* Minor fixes for SD

* Add SD variables to mlperf.conf

* Rename calibration file

* Add accuracy coco script + minor accuracy fixes
pgmpablo157321 authored Dec 13, 2023
1 parent 330a8f3 commit 55bebbf
Showing 23 changed files with 7,430 additions and 0 deletions.
500 changes: 500 additions & 0 deletions calibration/COCO-2014/coco_cal_captions_list.txt

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mlperf.conf
@@ -12,6 +12,7 @@ bert.*.performance_sample_count_override = 10833
dlrm.*.performance_sample_count_override = 204800
dlrm-v2.*.performance_sample_count_override = 204800
rnnt.*.performance_sample_count_override = 2513
stable-diffusion-xl.*.performance_sample_count_override = 5000
# set to 0 to let entire sample set to be performance sample
3d-unet.*.performance_sample_count_override = 0

@@ -54,6 +55,7 @@ dlrm.Server.target_latency = 60
dlrm-v2.Server.target_latency = 60
rnnt.Server.target_latency = 1000
gptj.Server.target_latency = 20000
stable-diffusion-xl.Server.target_latency = 20000

*.Offline.target_latency_percentile = 90
*.Offline.min_duration = 600000
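These two entries are picked up by LoadGen when the harness loads `mlperf.conf`. A minimal sketch of how a harness might consume them, assuming the standard `mlperf_loadgen` Python bindings (the scenario choice here is just for illustration):

```python
import mlperf_loadgen as lg

# Read the stable-diffusion-xl overrides (performance sample count,
# Server target latency, ...) from the shared mlperf.conf shown above.
settings = lg.TestSettings()
settings.FromConfig("mlperf.conf", "stable-diffusion-xl", "Server")
settings.scenario = lg.TestScenario.Server
settings.mode = lg.TestMode.PerformanceOnly
```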
82 changes: 82 additions & 0 deletions text_to_image/README.md
@@ -0,0 +1,82 @@
# MLPerf™ Inference Benchmarks for Text to Image

This is the reference implementation for the MLPerf Inference text-to-image benchmark.

## Supported Models

| model | framework | accuracy | dataset | model link | model source | precision | notes |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| StableDiffusion | Torch | - | Coco2014 | - | [Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | fp16 | NCHW |

## Dataset

| Data | Description |
| ---- | ---- |
| Coco-2014 | We use a subset of 5,000 images and captions from the COCO-2014 validation set, chosen so that there is exactly one caption per image. The model takes the caption of an image as input and generates an image from it. The original and generated images are used to compute the FID score, and the captions and generated images are used to compute the CLIP score (see the sketch below the table). We provide a [script](tools/coco.py) to automatically download the dataset. |
| Coco-2014 (calibration) | We use a subset of 100 images and captions from the COCO-2014 training set, chosen so that there is exactly one caption per image. The subset was generated using this [script](tools/coco_generate_calibration.py). We provide the [caption ids](../calibration/COCO-2014/coco_cal_captions_list.txt) and a [script](tools/coco_calibration.py) to download them. |
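
The accuracy script added in this commit scores the generated images with FID and CLIP; a minimal sketch of the idea, assuming `torchmetrics` is used (the reference script may rely on different libraries), could look like:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.multimodal.clip_score import CLIPScore

# FID compares the distribution of generated images to the original
# COCO-2014 images; CLIP score measures caption/image agreement.
fid = FrechetInceptionDistance(feature=2048)
clip = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")

def update_metrics(real_images, generated_images, captions):
    # Images are uint8 tensors shaped (N, 3, H, W); captions is a list of str.
    fid.update(real_images, real=True)
    fid.update(generated_images, real=False)
    clip.update(generated_images, captions)

# After iterating over the 5000-sample subset:
# print(fid.compute(), clip.compute())
```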


## Setup
Set the following helper variables:
```bash
export ROOT=$PWD/inference
export SD_FOLDER=$PWD/inference/text_to_image
export LOADGEN_FOLDER=$PWD/inference/loadgen
```
### Clone the repository
**TEMPORARILY:**
```bash
git clone --recurse-submodules https://github.com/pgmpablo157321/inference.git --branch stable_diffusion_reference --depth 1
```
**KEEP FOR LATER:**
```bash
git clone --recurse-submodules https://github.com/mlcommons/inference.git --depth 1
```
Finally, copy the `mlperf.conf` file to the Stable Diffusion folder:
```bash
cp $ROOT/mlperf.conf $SD_FOLDER
```

### Install requirements (only needed when running without Docker)
Install requirements:
```bash
cd $SD_FOLDER
pip install -r requirements.txt
```
Install loadgen:
```bash
cd $LOADGEN_FOLDER
CFLAGS="-std=c++14" python setup.py install
```

### Download dataset
```bash
cd $SD_FOLDER/tools
./download-coco-2014.sh -n <number_of_workers>
```
For debugging, you can download only a subset of the images in the dataset:
```bash
cd $SD_FOLDER/tools
./download-coco-2014.sh -m <max_number_of_images>
```
If the script finds the file [captions.tsv](coco2014/captions/captions.tsv), it uses it to download the target subset of the dataset; otherwise, the file is generated. We recommend keeping this file for consistency.

### Run the benchmark
#### Local run
```bash
python3 main.py --dataset "coco-1024" --dataset-path coco2014 --profile stable-diffusion-xl-pytorch [--model-path <TODO: provide model weights>] [--dtype <fp32, fp16 or bf16>] [--device <cuda or cpu>] [--time 600] [--scenario SingleStream]
```
#### Run using docker
```bash
cd $SD_FOLDER
# Build the container
docker build . -t sd_mlperf_inference
# Run the container
docker run --rm -it --gpus=all -v $SD_FOLDER:/workspace sd_mlperf_inference bash
```
Inside the container, run the following:
```bash
python3 main.py --dataset "coco-1024" --dataset-path coco2014 --profile stable-diffusion-xl-pytorch [--model-path <TODO: provide model weights>] [--dtype <fp32, fp16 or bf16>] [--device <cuda or cpu>] [--time 600] [--scenario SingleStream]
```
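
In both cases `main.py` is the entry point that wires the dataset and the Stable Diffusion backend into LoadGen. A rough sketch of that flow, assuming the standard `mlperf_loadgen` Python API (the caption lookup and image generation below are illustrative stubs, not the harness's actual code):

```python
import mlperf_loadgen as lg

captions = ["a photo of an astronaut riding a horse"] * 8  # placeholder prompts

def generate_image(caption):
    # Stand-in for the SDXL pipeline call made by the real backend.
    return None

def issue_queries(query_samples):
    responses = []
    for s in query_samples:
        generate_image(captions[s.index])
        # The real harness reports a pointer/size to the generated image here.
        responses.append(lg.QuerySampleResponse(s.id, 0, 0))
    lg.QuerySamplesComplete(responses)

def flush_queries():
    pass

sut = lg.ConstructSUT(issue_queries, flush_queries)
qsl = lg.ConstructQSL(len(captions), len(captions), lambda s: None, lambda s: None)

settings = lg.TestSettings()
settings.scenario = lg.TestScenario.SingleStream
settings.mode = lg.TestMode.PerformanceOnly
lg.StartTest(sut, qsl, settings)
lg.DestroySUT(sut)
lg.DestroyQSL(qsl)
```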


21 changes: 21 additions & 0 deletions text_to_image/backend.py
@@ -0,0 +1,21 @@
"""
abstract backend class
"""


class Backend:
def __init__(self):
self.inputs = []
self.outputs = []

def version(self):
raise NotImplementedError("Backend:version")

def name(self):
raise NotImplementedError("Backend:name")

def load(self, model_path, inputs=None, outputs=None):
raise NotImplementedError("Backend:load")

def predict(self, feed):
raise NotImplementedError("Backend:predict")
28 changes: 28 additions & 0 deletions text_to_image/backend_debug.py
@@ -0,0 +1,28 @@
import torch
import backend


class BackendDebug(backend.Backend):
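    """Debug backend that skips the model and returns random noise images."""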
def __init__(self, image_size=[3, 1024, 1024], **kwargs):
super(BackendDebug, self).__init__()
self.image_size = image_size

def version(self):
return torch.__version__

def name(self):
return "debug-SUT"

def image_format(self):
return "NCHW"

def load(self):
return self

def predict(self, prompts):
images = []
with torch.no_grad():
for prompt in prompts:
image = torch.randn(self.image_size)
images.append(image)
return images
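
For comparison with the debug backend above, here is a rough sketch of what a real SDXL backend could look like on top of Hugging Face `diffusers` (an illustration only; the commit's actual PyTorch backend is presumably among the files not rendered here):

```python
import torch
from diffusers import StableDiffusionXLPipeline

import backend


class BackendDiffusers(backend.Backend):
    """Illustrative SDXL backend built on Hugging Face diffusers."""

    def __init__(self, model_id="stabilityai/stable-diffusion-xl-base-1.0",
                 device="cuda", dtype=torch.float16):
        super().__init__()
        self.model_id = model_id
        self.device = device
        self.dtype = dtype
        self.pipe = None

    def version(self):
        return torch.__version__

    def name(self):
        return "diffusers-SUT"

    def load(self):
        # Pull the base SDXL checkpoint referenced in the README table.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.model_id, torch_dtype=self.dtype
        ).to(self.device)
        return self

    def predict(self, prompts):
        # diffusers returns PIL images; a real harness would convert them
        # to tensors for the accuracy scripts.
        with torch.no_grad():
            return self.pipe(prompt=list(prompts)).images
```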