Skip to content

Commit

Permalink
Vectorized implementation of RandomColorDegeneration and Equalization…
Browse files Browse the repository at this point in the history
… preprocessing layers (#2214)

* Vectorize RandomColorDegeneration

* Vectorize Equalization

* Vectorize Equalization

* Fix Equalization for ragged input
  • Loading branch information
sup3rgiu authored Dec 8, 2023
1 parent 7da0c1e commit 25cb3a1
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 43 deletions.
116 changes: 86 additions & 30 deletions keras_cv/layers/preprocessing/equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501
VectorizedBaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing


@keras_cv_export("keras_cv.layers.Equalization")
class Equalization(BaseImageAugmentationLayer):
class Equalization(VectorizedBaseImageAugmentationLayer):
"""Equalization performs histogram equalization on a channel-wise basis.
Args:
Expand Down Expand Up @@ -52,17 +54,33 @@ def __init__(self, value_range, bins=256, **kwargs):
self.bins = bins
self.value_range = value_range

def equalize_channel(self, image, channel_index):
def equalize_channel(self, images, channel_index):
"""equalize_channel performs histogram equalization on a single channel.
Args:
image: int Tensor with pixels in range [0, 255], RGB format,
with channels last
channel_index: channel to equalize
"""
image = image[..., channel_index]
is_single_image = tf.rank(images) == 4 and tf.shape(images)[0] == 1

images = images[..., channel_index]
# Compute the histogram of the image channel.
histogram = tf.histogram_fixed_width(image, [0, 255], nbins=self.bins)

# If the input is not a batch of images, directly using
# tf.histogram_fixed_width is much faster than using tf.vectorized_map
if is_single_image:
histogram = tf.histogram_fixed_width(
images, [0, 255], nbins=self.bins
)
histogram = tf.expand_dims(histogram, axis=0)
else:
partial_hist = partial(
tf.histogram_fixed_width, value_range=[0, 255], nbins=self.bins
)
histogram = tf.vectorized_map(
partial_hist, images, fallback_to_while_loop=True, warn=True
)

# For the purposes of computing the step, filter out the non-zeros.
# Zeroes are replaced by a big number while calculating min to keep
Expand All @@ -77,56 +95,94 @@ def equalize_channel(self, image, channel_index):
)

step = (
tf.reduce_sum(histogram) - tf.reduce_min(histogram_without_zeroes)
tf.reduce_sum(histogram, axis=-1)
- tf.reduce_min(histogram_without_zeroes, axis=-1)
) // (self.bins - 1)

def build_mapping(histogram, step):
bacth_size = tf.shape(histogram)[0]

# Replace where step is 0 with 1 to avoid division by 0.
# This doesn't change the result, because where step==0 the
# original image is returned
_step = tf.where(
tf.equal(step, 0),
1,
step,
)
_step = tf.expand_dims(_step, -1)

# Compute the cumulative sum, shifting by step // 2
# and then normalization by step.
lookup_table = (tf.cumsum(histogram) + (step // 2)) // step
lookup_table = (
tf.cumsum(histogram, axis=-1) + (_step // 2)
) // _step

# Shift lookup_table, prepending with 0.
lookup_table = tf.concat([[0], lookup_table[:-1]], 0)
lookup_table = tf.concat(
[tf.tile([[0]], [bacth_size, 1]), lookup_table[..., :-1]],
axis=1,
)

# Clip the counts to be in range. This is done
# in the C code for image.point.
return tf.clip_by_value(lookup_table, 0, 255)

# If step is zero, return the original image. Otherwise, build
# lookup table from the full histogram and step and then index from it.
result = tf.cond(
tf.equal(step, 0),
lambda: image,
lambda: tf.gather(build_mapping(histogram, step), image),
# The lookup table is built for all images,
# regardless of the corresponding value of step.
result = tf.where(
tf.reshape(tf.equal(step, 0), (-1, 1, 1)),
images,
tf.gather(
build_mapping(histogram, step), images, batch_dims=1, axis=1
),
)

return result

def augment_image(self, image, **kwargs):
image = preprocessing.transform_value_range(
image, self.value_range, (0, 255), dtype=self.compute_dtype
def augment_images(self, images, transformations=None, **kwargs):
images = preprocessing.transform_value_range(
images, self.value_range, (0, 255), dtype=self.compute_dtype
)
image = tf.cast(image, tf.int32)
image = tf.map_fn(
lambda channel: self.equalize_channel(image, channel),
tf.range(tf.shape(image)[-1]),
images = tf.cast(images, tf.int32)

images = tf.map_fn(
lambda channel: self.equalize_channel(images, channel),
tf.range(tf.shape(images)[-1]),
)
images = tf.transpose(images, [1, 2, 3, 0])

image = tf.transpose(image, [1, 2, 0])
image = tf.cast(image, self.compute_dtype)
image = preprocessing.transform_value_range(
image, (0, 255), self.value_range, dtype=self.compute_dtype
images = tf.cast(images, self.compute_dtype)
images = preprocessing.transform_value_range(
images, (0, 255), self.value_range, dtype=self.compute_dtype
)
return image
return images

def augment_bounding_boxes(self, bounding_boxes, **kwargs):
return bounding_boxes

def augment_label(self, label, transformation=None, **kwargs):
return label
def augment_labels(self, labels, transformations=None, **kwargs):
return labels

def augment_segmentation_mask(
self, segmentation_mask, transformation, **kwargs
def augment_segmentation_masks(
self, segmentation_masks, transformations, **kwargs
):
return segmentation_mask
return segmentation_masks

def augment_keypoints(self, keypoints, transformations, **kwargs):
return keypoints

def augment_targets(self, targets, transformations, **kwargs):
return targets

def augment_ragged_image(self, image, transformation, **kwargs):
image = tf.expand_dims(image, axis=0)
image = self.augment_images(
images=image, transformations=transformation, **kwargs
)
return tf.squeeze(image, axis=0)

def get_config(self):
config = super().get_config()
Expand Down
41 changes: 28 additions & 13 deletions keras_cv/layers/preprocessing/random_color_degeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import ( # noqa: E501
VectorizedBaseImageAugmentationLayer,
)
from keras_cv.utils import preprocessing


@keras_cv_export("keras_cv.layers.RandomColorDegeneration")
class RandomColorDegeneration(BaseImageAugmentationLayer):
class RandomColorDegeneration(VectorizedBaseImageAugmentationLayer):
"""Randomly performs the color degeneration operation on given images.
The sharpness operation first converts an image to gray scale, then back to
Expand Down Expand Up @@ -57,24 +57,39 @@ def __init__(
)
self.seed = seed

def get_random_transformation(self, **kwargs):
return self.factor(dtype=self.compute_dtype)
def get_random_transformation_batch(self, batch_size, **kwargs):
return self.factor(
shape=(batch_size, 1, 1, 1), dtype=self.compute_dtype
)

def augment_image(self, image, transformation=None, **kwargs):
degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
result = preprocessing.blend(image, degenerate, transformation)
def augment_images(self, images, transformations=None, **kwargs):
degenerates = tf.image.grayscale_to_rgb(
tf.image.rgb_to_grayscale(images)
)
result = preprocessing.blend(images, degenerates, transformations)
return result

def augment_bounding_boxes(self, bounding_boxes, **kwargs):
return bounding_boxes

def augment_label(self, label, transformation=None, **kwargs):
return label
def augment_labels(self, labels, transformations=None, **kwargs):
return labels

def augment_segmentation_mask(
self, segmentation_mask, transformation, **kwargs
def augment_segmentation_masks(
self, segmentation_masks, transformations, **kwargs
):
return segmentation_mask
return segmentation_masks

def augment_keypoints(self, keypoints, transformations, **kwargs):
return keypoints

def augment_targets(self, targets, transformations, **kwargs):
return targets

def augment_ragged_image(self, image, transformation, **kwargs):
return self.augment_images(
image, transformations=transformation, **kwargs
)

def get_config(self):
config = super().get_config()
Expand Down

0 comments on commit 25cb3a1

Please sign in to comment.