Skip to content

Commit

Permalink
Fix RISE algorithm for explain function (#1263)
Browse files Browse the repository at this point in the history
<!-- Contributing guide:
https://github.com/openvinotoolkit/datumaro/blob/develop/CONTRIBUTING.md
-->

### Summary

<!--
Resolves #111 and #222.
Depends on #1000 (for series of dependent commits).

This PR introduces this capability to make the project better in this
and that.

- Added this feature
- Removed that feature
- Fixed the problem #1234
-->

### How to test
<!-- Describe the testing procedure for reviewers, if changes are
not fully covered by unit tests or manual testing can be complicated.
-->

### Checklist
<!-- Put an 'x' in all the boxes that apply -->
- [x] I have added unit tests to cover my changes.​
- [x] I have added integration tests to cover my changes.​
- [x] I have added the description of my changes into
[CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md).​
- [ ] I have updated the
[documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs)
accordingly

### License

- [ ] I submit _my code changes_ under the same [MIT
License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE)
that covers the project.
  Feel free to contact the maintainers if that's a concern.
- [ ] I have updated the license header for each file (see an example
below).

```python
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
```
  • Loading branch information
wonjuleee authored Feb 8, 2024
1 parent 882bb1e commit 817b44e
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 247 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1245>)
- Enable image backend and color channel format to be selectable
(<https://github.com/openvinotoolkit/datumaro/pull/1246>)
- Enhance Datumaro data format detect() to be memory-bounded and performant
(<https://github.com/openvinotoolkit/datumaro/pull/1229>)
- Enhance RISE algortihm for explainable AI
(<https://github.com/openvinotoolkit/datumaro/pull/1263>)

### Bug fixes
- Fix wrong example of Datumaro dataset creation in document
Expand Down
10 changes: 3 additions & 7 deletions src/datumaro/cli/commands/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
rise_parser.add_argument(
"-s",
"--max-samples",
default=None,
default=100,
type=int,
help="Number of algorithm iterations (default: mask size ^ 2)",
)
Expand Down Expand Up @@ -203,13 +203,9 @@ def explain_command(args):

rise = RISE(
model,
max_samples=args.max_samples,
mask_width=args.mask_width,
mask_height=args.mask_height,
num_masks=args.max_samples,
mask_size=args.mask_width,
prob=args.prob,
iou_thresh=args.iou_thresh,
nms_thresh=args.nms_iou_thresh,
det_conf_thresh=args.det_conf_thresh,
batch_size=args.batch_size,
)

Expand Down
273 changes: 91 additions & 182 deletions src/datumaro/components/algorithms/rise.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,20 @@
# Copyright (C) 2019-2020 Intel Corporation
# Copyright (C) 2019-2024 Intel Corporation
#
# SPDX-License-Identifier: MIT

# pylint: disable=unused-variable

from math import ceil

import cv2
import numpy as np

from datumaro.components.annotation import AnnotationType
from datumaro.util.annotation_util import nms
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image
from datumaro.util import take_by

__all__ = ["RISE"]


def _flatmatvec(mat):
return np.reshape(mat, (len(mat), -1))


def _expand(array, axis=None):
if axis is None:
axis = len(array.shape)
return np.expand_dims(array, axis=axis)


class RISE:
"""
Implements RISE: Randomized Input Sampling for
Expand All @@ -34,186 +25,104 @@ class RISE:
def __init__(
self,
model,
max_samples=None,
mask_width=7,
mask_height=7,
prob=0.5,
iou_thresh=0.9,
nms_thresh=0.0,
det_conf_thresh=0.0,
batch_size=1,
num_masks: int = 100,
mask_size: int = 7,
prob: float = 0.5,
batch_size: int = 1,
):
assert prob >= 0 and prob <= 1
self.model = model
self.max_samples = max_samples
self.mask_height = mask_height
self.mask_width = mask_width
self.num_masks = num_masks
self.mask_size = mask_size
self.prob = prob
self.iou_thresh = iou_thresh
self.nms_thresh = nms_thresh
self.det_conf_thresh = det_conf_thresh
self.batch_size = batch_size

@staticmethod
def split_outputs(annotations):
labels = []
bboxes = []
for r in annotations:
if r.type is AnnotationType.label:
labels.append(r)
elif r.type is AnnotationType.bbox:
bboxes.append(r)
return labels, bboxes

def normalize_hmaps(self, heatmaps, counts):
eps = np.finfo(heatmaps.dtype).eps
mhmaps = _flatmatvec(heatmaps)
mhmaps /= _expand(counts * self.prob + eps)
mhmaps -= _expand(np.min(mhmaps, axis=1))
mhmaps /= _expand(np.max(mhmaps, axis=1) + eps)
return np.reshape(mhmaps, heatmaps.shape)
def normalize_saliency(self, saliency):
normalized_saliency = np.empty_like(saliency)
for idx, sal in enumerate(saliency):
normalized_saliency[idx, ...] = (sal - np.min(sal)) / (np.max(sal) - np.min(sal))
return normalized_saliency

def apply(self, image, progressive=False):
import cv2
def generate_masks(self, image_size):
cell_size = np.ceil(np.array(image_size) / self.mask_size).astype(np.int8)
up_size = tuple([(self.mask_size + 1) * cs for cs in cell_size])

grid = np.random.rand(self.num_masks, self.mask_size, self.mask_size) < self.prob
grid = grid.astype("float32")

masks = np.empty((self.num_masks, *image_size))
for i in range(self.num_masks):
# Random shifts
x = np.random.randint(0, cell_size[0])
y = np.random.randint(0, cell_size[1])

# Linear upsampling and cropping
masks[i, ...] = cv2.resize(grid[i], up_size, interpolation=cv2.INTER_LINEAR)[
x : x + image_size[0], y : y + image_size[1]
]

return masks

def generate_masked_dataset(self, image, image_size, masks):
input_image = cv2.resize(image, image_size, interpolation=cv2.INTER_LINEAR)

items = []
for id, mask in enumerate(masks):
masked_image = np.expand_dims(mask, axis=-1) * input_image
items.append(
DatasetItem(
id=id,
media=Image.from_numpy(masked_image),
)
)
return Dataset.from_iterable(items)

def apply(self, image, progressive=False):
assert len(image.shape) in [2, 3], "Expected an input image in (H, W, C) format"
if len(image.shape) == 3:
assert image.shape[2] in [3, 4], "Expected BGR or BGRA input"
image = image[:, :, :3].astype(np.float32)

model = self.model
iou_thresh = self.iou_thresh

image_size = np.array((image.shape[:2]))
mask_size = np.array((self.mask_height, self.mask_width))
cell_size = np.ceil(image_size / mask_size)
upsampled_size = np.ceil((mask_size + 1) * cell_size)

rng = lambda shape=None: np.random.rand(*shape)
samples = np.prod(image_size)
if self.max_samples is not None:
samples = min(self.max_samples, samples)
batch_size = self.batch_size

# model is expected to get NxCxHxW shaped input tensor
pred = next(iter(model.infer(_expand(np.transpose(image, (2, 0, 1)), 0))))
result = model.postprocess(pred, None)
result_labels, result_bboxes = self.split_outputs(result)
if 0 < self.det_conf_thresh:
result_bboxes = [
b for b in result_bboxes if self.det_conf_thresh <= b.attributes["score"]
]
if 0 < self.nms_thresh:
result_bboxes = nms(result_bboxes, self.nms_thresh)

predicted_labels = set()
if len(result_labels) != 0:
predicted_label = max(result_labels, key=lambda r: r.attributes["score"]).label
predicted_labels.add(predicted_label)
if len(result_bboxes) != 0:
for bbox in result_bboxes:
predicted_labels.add(bbox.label)
predicted_labels = {label: idx for idx, label in enumerate(predicted_labels)}

predicted_bboxes = result_bboxes

heatmaps_count = len(predicted_labels) + len(predicted_bboxes)
heatmaps = np.zeros((heatmaps_count, *image_size), dtype=np.float32)
total_counts = np.zeros(heatmaps_count, dtype=np.int32)
confs = np.zeros(heatmaps_count, dtype=np.float32)

heatmap_id = 0

# label_heatmaps = None
label_total_counts = None
label_confs = None
if len(predicted_labels) != 0:
step = len(predicted_labels)
# label_heatmaps = heatmaps[heatmap_id : heatmap_id + step]
label_total_counts = total_counts[heatmap_id : heatmap_id + step]
label_confs = confs[heatmap_id : heatmap_id + step]
heatmap_id += step

# bbox_heatmaps = None
bbox_total_counts = None
bbox_confs = None
if len(predicted_bboxes) != 0:
step = len(predicted_bboxes)
# bbox_heatmaps = heatmaps[heatmap_id : heatmap_id + step]
bbox_total_counts = total_counts[heatmap_id : heatmap_id + step]
bbox_confs = confs[heatmap_id : heatmap_id + step]
heatmap_id += step

ups_mask = np.empty(upsampled_size.astype(int), dtype=np.float32)
masks = np.empty((batch_size, *image_size), dtype=np.float32)

full_batch_inputs = np.empty((batch_size, *image.shape), dtype=np.float32)
current_heatmaps = np.empty_like(heatmaps)
for b in range(ceil(samples / batch_size)):
batch_pos = b * batch_size
current_batch_size = min(samples - batch_pos, batch_size)

batch_masks = masks[:current_batch_size]
for i in range(current_batch_size):
mask = (rng(mask_size) < self.prob).astype(np.float32)
cv2.resize(mask, (int(upsampled_size[1]), int(upsampled_size[0])), ups_mask)

offsets = np.round(rng((2,)) * cell_size)
mask = ups_mask[
int(offsets[0]) : int(image_size[0] + offsets[0]),
int(offsets[1]) : int(image_size[1] + offsets[1]),
]
batch_masks[i] = mask

batch_inputs = full_batch_inputs[:current_batch_size]
np.multiply(_expand(batch_masks), _expand(image, 0), out=batch_inputs)

preds = model.infer(np.transpose(batch_inputs, (0, 3, 1, 2)))
results = [model.postprocess(pred, None) for pred in preds]
for mask, result in zip(batch_masks, results):
result_labels, result_bboxes = self.split_outputs(result)

confs.fill(0)
if len(predicted_labels) != 0:
for r in result_labels:
idx = predicted_labels.get(r.label, None)
if idx is not None:
label_total_counts[idx] += 1
label_confs[idx] += r.attributes["score"]
for r in result_bboxes:
idx = predicted_labels.get(r.label, None)
if idx is not None:
label_total_counts[idx] += 1
label_confs[idx] += r.attributes["score"]

if len(predicted_bboxes) != 0 and len(result_bboxes) != 0:
if 0 < self.det_conf_thresh:
result_bboxes = [
b
for b in result_bboxes
if self.det_conf_thresh <= b.attributes["score"]
]
if 0 < self.nms_thresh:
result_bboxes = nms(result_bboxes, self.nms_thresh)

for detection in result_bboxes:
for pred_idx, pred in enumerate(predicted_bboxes):
if pred.label != detection.label:
continue

iou = pred.iou(detection)
assert iou == -1 or 0 <= iou and iou <= 1
if iou < iou_thresh:
continue

bbox_total_counts[pred_idx] += 1

conf = detection.attributes["score"]
bbox_confs[pred_idx] += conf

np.multiply.outer(confs, mask, out=current_heatmaps)
heatmaps += current_heatmaps

image_size = model.inputs[0].shape
logit_size = model.outputs[0].shape

batch_size = image_size[0]
if image_size[1] in [1, 3]: # for CxHxW
image_size = (image_size[2], image_size[3])
elif image_size[3] in [1, 3]: # for HxWxC
image_size = (image_size[1], image_size[2])

masks = self.generate_masks(image_size=image_size)
masked_dataset = self.generate_masked_dataset(image, image_size, masks)

saliency = np.zeros((logit_size[1], *image_size), dtype=np.float32)
for batch_id, batch in enumerate(take_by(masked_dataset, batch_size)):
outputs = model.launch(batch)

for sample_id in range(len(batch)):
mask = masks[batch_size * batch_id + sample_id]
for class_idx in range(logit_size[1]):
score = outputs[sample_id][class_idx].attributes["score"]
saliency[class_idx, ...] += score * mask

# [TODO] wonjuleee: support DRISE for detection model explainability
# if isinstance(self.target, Label):
# logits = outputs[sample_id][0].vector
# max_score = logits[self.target.label]
# elif isinstance(self.target, Bbox):
# preds = outputs[sample_id][0]
# max_score = 0
# for box in preds:
# if box[0] == self.target.label:
# confidence, box = box[1], box[2]
# score = iou(self.target.get_bbox, box) * confidence
# if score > max_score:
# max_score = score
# saliency += max_score * mask

if progressive:
yield self.normalize_hmaps(heatmaps.copy(), total_counts)
yield self.normalize_saliency(saliency)

yield self.normalize_hmaps(heatmaps, total_counts)
yield self.normalize_saliency(saliency)
13 changes: 7 additions & 6 deletions src/datumaro/components/shift_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

# ruff: noqa: E501

import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional

Expand Down Expand Up @@ -79,8 +78,9 @@ def get_activation_stats(self, dataset: IDataset) -> RunningStats1D:
running_stats = RunningStats1D()

for batch in take_by(dataset, self._batch_size):
features = self.model.launch(batch)
running_stats.add(list(itertools.chain(*features)))
outputs = self.model.launch(batch)[0]
features = [outputs[-1]] # extracted feature vector of googlenet-v4
running_stats.add(features)

return running_stats

Expand All @@ -99,10 +99,11 @@ def get_activation_stats(self, dataset: IDataset) -> Dict[int, RunningStats1D]:
inputs.append(np.atleast_3d(item.media.data))
targets.append(ann.label)

features = self.model.launch(batch)
outputs = self.model.launch(batch)[0]
features = [outputs[-1]] # extracted feature vector of googlenet-v4

for feat, target in zip(features, targets):
running_stats[target].add(feat)
for target in targets:
running_stats[target].add(features)

return running_stats

Expand Down
8 changes: 8 additions & 0 deletions src/datumaro/plugins/openvino_plugin/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,14 @@ def __init__(
self._check_model_support(self._network, self._device)
self._load_executable_net()

@property
def inputs(self):
return self._network.inputs

@property
def outputs(self):
return self._network.outputs

def _check_model_support(self, net, device):
not_supported_layers = set(
name for name, dev in self._core.query_model(net, device).items() if not dev
Expand Down
Loading

0 comments on commit 817b44e

Please sign in to comment.