Skip to content

Commit

Permalink
Dataviz/add slider to python api (#2)
Browse files Browse the repository at this point in the history
* feat: create a image slider dataviz

* cleanup the js code

* fix typos

* Removing JSON parse

* add GeneratedImages

* move code to show method instead

* bump version

* update notebook

* update readme with slider information

* remove bounding box from last image

* change image_slider.gif

* add a loading animation and fix a few minor issues

* fix typo

* try to display loading before image slider

* make loading enabled at initialization

* try display_id=42

* try to display both loading and slider at once

* try with update display

* Update README.md

* update notebook

Co-authored-by: André Batista <[email protected]>
  • Loading branch information
JoaoLages and andrewizbatista authored Sep 1, 2022
1 parent 34dfd41 commit bd37e42
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 130 deletions.
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,8 @@ output['sample']

![](assets/corgi_eiffel_tower.png)

You can also check the image that the diffusion process generated in the end of each step.

For example, to see the image from step 10:
```python
output['all_samples_during_generation'][10]
```
![](assets/corgi_eiffel_tower_step10.png)
You can also check all the images that the diffusion process generated at the end of each step.
![](assets/image_slider.gif)

To check how a token in the input `prompt` influenced the generation, you can check the token attribution scores:
```python
Expand Down Expand Up @@ -132,7 +127,7 @@ The token attributions are now computed only for the area specified in the image
Check other functionalities and more implementation examples in [here](https://github.com/JoaoLages/diffusers-interpret/blob/main/notebooks/).

## Future Development
- [ ] Add interactive display of all the images that were generated in the diffusion process
- [x] ~~Add interactive display of all the images that were generated in the diffusion process~~
- [ ] Add interactive bounding-box and token attributions visualization
- [ ] Add unit tests
- [ ] Add example for `diffusers_interpret.LDMTextToImagePipelineExplainer`
Expand All @@ -141,3 +136,7 @@ Check other functionalities and more implementation examples in [here](https://g

## Contributing
Feel free to open an [Issue](https://github.com/JoaoLages/diffusers-interpret/issues) or create a [Pull Request](https://github.com/JoaoLages/diffusers-interpret/pulls) and let's get started 🚀

## Credits

A special thanks to [@andrewizbatista](https://github.com/andrewizbatista) for creating a great [image slider](https://github.com/JoaoLages/diffusers-interpret/pull/1) to show all the generated images during diffusion! 💪
Binary file added assets/image_slider.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
245 changes: 147 additions & 98 deletions notebooks/stable-diffusion-example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setup(
name='diffusers-interpret',
version='0.1.0',
version='0.2.0',
description='diffusers-interpret: model explainability for 🤗 Diffusers',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down
48 changes: 25 additions & 23 deletions src/diffusers_interpret/explainer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Dict, Any, Tuple, Set

from PIL import ImageDraw

import torch
from PIL import ImageDraw
from diffusers import DiffusionPipeline
from transformers import BatchEncoding, PreTrainedTokenizerBase

from diffusers_interpret.attribution import gradient_x_inputs_attribution
from diffusers_interpret.generated_images import GeneratedImages
from diffusers_interpret.utils import clean_token_from_prefixes_and_suffixes, transform_images_to_pil_format


class BasePipelineExplainer(ABC):
def __init__(self, pipe: DiffusionPipeline, verbose: bool = True):
def __init__(self, pipe: DiffusionPipeline, verbose: bool = True) -> None:
self.pipe = pipe
self.verbose = verbose

Expand Down Expand Up @@ -117,33 +117,35 @@ def __call__(
if self.verbose:
print("Done!")

# convert to PIL Image if requested
# also draw bounding box if requested
if output_type == "pil":
images_with_bounding_box = []
all_samples = output['all_samples_during_generation'] or [output['sample']]
for list_im in transform_images_to_pil_format(all_samples, self.pipe):
batch_images = []
for im in list_im:
if explanation_2d_bounding_box:
draw = ImageDraw.Draw(im)
draw.rectangle(explanation_2d_bounding_box, outline="red")
batch_images.append(im)
images_with_bounding_box.append(batch_images)

if output['all_samples_during_generation']:
output['all_samples_during_generation'] = images_with_bounding_box
output['sample'] = output['all_samples_during_generation'][-1]
else:
output['sample'] = images_with_bounding_box[-1]

if batch_size == 1:
# squash batch dimension
for k in ['sample', 'token_attributions', 'normalized_token_attributions']:
output[k] = output[k][0]
if output['all_samples_during_generation']:
output['all_samples_during_generation'] = [b[0] for b in output['all_samples_during_generation']]

# convert to PIL Image if requested
# also draw bounding box in the last image if requested
if output['all_samples_during_generation'] or output_type == "pil":
all_samples = GeneratedImages(
all_generated_images=output['all_samples_during_generation'] or [output['sample']],
pipe=self.pipe,
remove_batch_dimension=batch_size==1,
prepare_image_slider=bool(output['all_samples_during_generation'])
)
if output['all_samples_during_generation']:
output['all_samples_during_generation'] = all_samples
sample = output['all_samples_during_generation'][-1]
else:
sample = all_samples[-1]

if explanation_2d_bounding_box:
draw = ImageDraw.Draw(sample)
draw.rectangle(explanation_2d_bounding_box, outline="red")

if output_type == "pil":
output['sample'] = sample

return output

@property
Expand Down
145 changes: 145 additions & 0 deletions src/diffusers_interpret/generated_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import base64
import json
import os
from typing import List, Union

import torch
from IPython import display as d
from PIL.Image import Image
from diffusers import DiffusionPipeline

import diffusers_interpret
from diffusers_interpret.utils import transform_images_to_pil_format


class GeneratedImages:
def __init__(
self,
all_generated_images: List[torch.Tensor],
pipe: DiffusionPipeline,
remove_batch_dimension: bool = True,
prepare_image_slider: bool = True
) -> None:

assert all_generated_images, "Can't create GeneratedImages object with empty `all_generated_images`"

# Convert images to PIL and draw box if requested
self.images = []
for list_im in transform_images_to_pil_format(all_generated_images, pipe):
batch_images = []
for im in list_im:
batch_images.append(im)

if remove_batch_dimension:
self.images.extend(batch_images)
else:
self.images.append(batch_images)

self.loading_iframe = None
self.image_slider_iframe = None
if prepare_image_slider:
self.prepare_image_slider()

def prepare_image_slider(self) -> None:
"""
Creates auxiliary HTML file to be displayed in self.__repr__
"""

# Get data dir
image_slider_dir = os.path.join(os.path.dirname(diffusers_interpret.__file__), "dataviz", "image-slider")

# Convert images to base64
json_payload = []
for i, image in enumerate(self.images):
image.save(f"{image_slider_dir}/to_delete.png")
with open(f"{image_slider_dir}/to_delete.png", "rb") as image_file:
json_payload.append(
{"image": "data:image/png;base64," + base64.b64encode(image_file.read()).decode('utf-8')}
)
os.remove(f"{image_slider_dir}/to_delete.png")

# get HTML file
with open(os.path.join(image_slider_dir, "index.html")) as fp:
html = fp.read()

# get CSS file
with open(os.path.join(image_slider_dir, "css/index.css")) as fp:
css = fp.read()

# get JS file
with open(os.path.join(image_slider_dir, "js/index.js")) as fp:
js = fp.read()

# replace CSS text in CSS file
html = html.replace("""<link href="css/index.css" rel="stylesheet" />""",
f"""<style type="text/css">\n{css}</style>""")

# replace JS text in HTML file
html = html.replace("""<script type="text/javascript" src="js/index.js"></script>""", ""
f"""<script type="text/javascript">\n{js}</script>""")

# get html with image slider JS call
index = html.find("<!-- INSERT STARTING SCRIPT HERE -->")
add = """
<script type="text/javascript">
((d) => {
const $body = d.querySelector("body");
if ($body) {
$body.addEventListener("INITIALIZE_IS_READY", ({ detail }) => {
const initialize = detail?.initialize ?? null;
if (initialize) initialize(%s);
});
}
})(document);
</script>
""" % json.dumps(json_payload)
html_with_image_slider = html[:index] + add + html[index:]

# save files and load IFrame to be displayed in self.__repr__
with open(os.path.join(image_slider_dir, "loading.html"), 'w') as fp:
fp.write(html)
with open(os.path.join(image_slider_dir, "final.html"), 'w') as fp:
fp.write(html_with_image_slider)

self.loading_iframe = d.IFrame(
os.path.relpath(
os.path.join(os.path.dirname(diffusers_interpret.__file__), "dataviz", "image-slider", "loading.html"),
'.'
),
width="100%", height="400px"
)

self.image_slider_iframe = d.IFrame(
os.path.relpath(
os.path.join(os.path.dirname(diffusers_interpret.__file__), "dataviz", "image-slider", "final.html"),
'.'
),
width="100%", height="400px"
)

def __getitem__(self, item: int) -> Union[Image, List[Image]]:
return self.images[item]

def show(self, width: Union[str, int] = "100%", height: Union[str, int] = "400px") -> None:

if len(self.images) == 0:
raise Exception("`self.images` is an empty list, can't show any images")

if isinstance(self.images[0], list):
raise NotImplementedError("GeneratedImages.show visualization is not supported "
"when `self.images` is a list of lists of images")

if self.image_slider_iframe is None:
self.prepare_image_slider()

# display loading
self.loading_iframe.width = width
self.loading_iframe.height = height
display = d.display(self.loading_iframe, display_id=42)

# display image slider
self.image_slider_iframe.width = width
self.image_slider_iframe.height = height
display.update(self.image_slider_iframe)

0 comments on commit bd37e42

Please sign in to comment.