From ce83038a5c9f372fd97e4b66073dc2081304d719 Mon Sep 17 00:00:00 2001 From: Logan Engstrom Date: Sun, 13 Feb 2022 17:57:01 -0500 Subject: [PATCH 1/9] added greyscale --- ffcv/fields/rgb_image.py | 54 ++++++++++++++++++++++++---------------- libffcv/libffcv.cpp | 28 ++++++++++++++++----- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py index eb1a1511..fe5df119 100644 --- a/ffcv/fields/rgb_image.py +++ b/ffcv/fields/rgb_image.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from functools import partial from dataclasses import replace from typing import Optional, Callable, TYPE_CHECKING, Tuple, Type @@ -23,8 +24,10 @@ IMAGE_MODES['raw'] = 1 -def encode_jpeg(numpy_image, quality): - numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) +def encode_jpeg(numpy_image, quality, is_rgb): + if is_rgb: + numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) + success, result = cv2.imencode('.jpg', numpy_image, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) @@ -86,7 +89,9 @@ class SimpleRGBImageDecoder(Operation): It only supports dataset with constant image resolution and will simply read (potentially decompress) and pass the images as is. """ - def __init__(self): + def __init__(self, is_rgb): + self.is_rgb = is_rgb + self.channels = 3 if is_rgb else 1 super().__init__() def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]: @@ -102,7 +107,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca instead.""" raise TypeError(msg) - biggest_shape = (max_height, max_width, 3) + biggest_shape = (max_height, max_width, self.channels) my_dtype = np.dtype(' Callable: raw = IMAGE_MODES['raw'] my_range = Compiler.get_iterator() my_memcpy = Compiler.compile(memcpy) + is_rgb = self.is_rgb def decode(batch_indices, destination, metadata, storage_state): for dst_ix in my_range(len(batch_indices)): @@ -128,8 +134,9 @@ def decode(batch_indices, destination, metadata, storage_state): height, width = field['height'], field['width'] if field['mode'] == jpg: - imdecode_c(image_data, destination[dst_ix], - height, width, height, width, 0, 0, 1, 1, False, False) + imdecode_c(image_data, destination[dst_ix], height, width, + height, width, 0, 0, 1, 1, False, False, + is_rgb) else: my_memcpy(image_data, destination[dst_ix]) @@ -144,8 +151,8 @@ class ResizedCropRGBImageDecoder(SimpleRGBImageDecoder, metaclass=ABCMeta): It supports both variable and constant resolution datasets. """ - def __init__(self, output_size): - super().__init__() + def __init__(self, output_size, is_rgb): + super().__init__(is_rgb) self.output_size = output_size def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]: @@ -154,19 +161,19 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca # We convert to uint64 to avoid overflows self.max_width = np.uint64(widths.max()) self.max_height = np.uint64(heights.max()) - output_shape = (self.output_size[0], self.output_size[1], 3) + output_shape = (self.output_size[0], self.output_size[1], self.channels) my_dtype = np.dtype(' Callable: - jpg = IMAGE_MODES['jpg'] mem_read = self.memory_read @@ -177,6 +184,8 @@ def generate_code(self) -> Callable: scale = self.scale ratio = self.ratio + is_rgb = self.is_rgb + channels = self.channels if isinstance(scale, tuple): scale = np.array(scale) if isinstance(ratio, tuple): @@ -193,21 +202,21 @@ def decode(batch_indices, my_storage, metadata, storage_state): if field['mode'] == jpg: temp_buffer = temp_storage[dst_ix] - imdecode_c(image_data, temp_buffer, - height, width, height, width, 0, 0, 1, 1, False, False) - selected_size = 3 * height * width + imdecode_c(image_data, temp_buffer, height, width, height, + width, 0, 0, 1, 1, False, False, is_rgb) + selected_size = channels * height * width temp_buffer = temp_buffer.reshape(-1)[:selected_size] - temp_buffer = temp_buffer.reshape(height, width, 3) - + temp_buffer = temp_buffer.reshape(height, width, channels) else: - temp_buffer = image_data.reshape(height, width, 3) + temp_buffer = image_data.reshape(height, width, channels) i, j, h, w = get_crop_c(height, width, scale, ratio) resize_crop_c(temp_buffer, i, i + h, j, j + w, - destination[dst_ix]) + destination[dst_ix], is_rgb) return destination[:len(batch_indices)] + decode.is_parallel = True return decode @@ -291,12 +300,13 @@ class RGBImageField(Field): """ def __init__(self, write_mode='raw', max_resolution: int = None, smart_threshold: int = None, jpeg_quality: int = 90, - compress_probability: float = 0.5) -> None: + compress_probability: float = 0.5, is_rgb: bool = True) -> None: self.write_mode = write_mode self.smart_threshold = smart_threshold self.max_resolution = max_resolution self.jpeg_quality = int(jpeg_quality) self.proportion = compress_probability + self.is_rgb = is_rgb @property def metadata_type(self) -> np.dtype: @@ -308,7 +318,7 @@ def metadata_type(self) -> np.dtype: ]) def get_decoder_class(self) -> Type[Operation]: - return SimpleRGBImageDecoder + return partial(SimpleRGBImageDecoder, is_rgb=self.is_rgb) @staticmethod def from_binary(binary: ARG_TYPE) -> Field: @@ -327,7 +337,9 @@ def encode(self, destination, image, malloc): if image.dtype != np.uint8: raise ValueError("Image type has to be uint8") - if image.shape[2] != 3: + is_ok_rgb = image.shape[2] == 3 and self.is_rgb + is_ok_grayscale = image.shape[2] == 1 and not self.is_rgb + if not (is_ok_rgb or is_ok_grayscale): raise ValueError(f"Invalid shape for rgb image: {image.shape}") assert image.dtype == np.uint8 diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index 7bae23ba..1e710499 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -25,11 +25,16 @@ extern "C" { void resize(int64_t cresizer, int64_t source_p, int64_t sx, int64_t sy, int64_t start_row, int64_t end_row, int64_t start_col, int64_t end_col, - int64_t dest_p, int64_t tx, int64_t ty) { - // TODO use proper arguments type + int64_t dest_p, int64_t tx, int64_t ty, bool is_rgb){ + int dtype; + if (is_rgb) { + dtype = CV_8UC3; + } else { + dtype = CV_8UC1; + } - cv::Mat source_matrix(sx, sy, CV_8UC3, (uint8_t*) source_p); - cv::Mat dest_matrix(tx, ty, CV_8UC3, (uint8_t*) dest_p); + cv::Mat source_matrix(sx, sy, dtype, (uint8_t*) source_p); + cv::Mat dest_matrix(tx, ty, dtype, (uint8_t*) dest_p); cv::resize(source_matrix.colRange(start_col, end_col).rowRange(start_row, end_row), dest_matrix, dest_matrix.size(), 0, 0, cv::INTER_AREA); } @@ -51,7 +56,8 @@ extern "C" { __uint32_t offset_x, __uint32_t offset_y, __uint32_t scale_num, __uint32_t scale_denom, bool enable_crop, - bool hflip) + bool hflip, + bool is_rgb) { pthread_once(&key_once, make_keys); @@ -94,13 +100,23 @@ extern "C" { dstBuf = input_buffer; dstSize = input_size; } + + TJPF pixel_format; + + if (is_rgb) { + pixel_format = TJPF_RGB; + } else { + pixel_format = TJPF_GRAY; + } + int result = tjDecompress2(tj_decompressor, dstBuf, dstSize, output_buffer, TJSCALED(crop_width, scaling), 0, TJSCALED(crop_height, scaling), - TJPF_RGB, TJFLAG_FASTDCT | TJFLAG_NOREALLOC); + pixel_format, TJFLAG_FASTDCT | TJFLAG_NOREALLOC); if (do_transform) { tjFree(dstBuf); } + return result; } From 4984f80532ba403822f03fcee0664d48191c40c7 Mon Sep 17 00:00:00 2001 From: Logan Engstrom Date: Tue, 15 Feb 2022 02:48:07 -0500 Subject: [PATCH 2/9] grayscale works --- ffcv/fields/rgb_image.py | 23 ++++++++++++++++------- ffcv/libffcv.py | 13 +++++++------ ffcv/loader/loader.py | 1 + libffcv/libffcv.cpp | 17 ++++++++--------- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py index fe5df119..7e7ecca6 100644 --- a/ffcv/fields/rgb_image.py +++ b/ffcv/fields/rgb_image.py @@ -28,6 +28,7 @@ def encode_jpeg(numpy_image, quality, is_rgb): if is_rgb: numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR) + # TODO this def assumes rgb lol success, result = cv2.imencode('.jpg', numpy_image, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) @@ -145,6 +146,9 @@ def decode(batch_indices, destination, metadata, storage_state): decode.is_parallel = True return decode +class SimpleGrayscaleImageDecoder(SimpleRGBImageDecoder): + def __init__(self): + super().__init__(is_rgb=False) class ResizedCropRGBImageDecoder(SimpleRGBImageDecoder, metaclass=ABCMeta): """Abstract decoder for :class:`~ffcv.fields.RGBImageField` that performs a crop and and a resize operation. @@ -240,8 +244,9 @@ class RandomResizedCropRGBImageDecoder(ResizedCropRGBImageDecoder): ratio : Tuple[float] The range of potential aspect ratios that can be randomly sampled """ - def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3)): - super().__init__(output_size) + def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3), + is_rgb=True): + super().__init__(output_size, is_rgb=is_rgb) self.scale = scale self.ratio = ratio self.output_size = output_size @@ -264,8 +269,8 @@ class CenterCropRGBImageDecoder(ResizedCropRGBImageDecoder): ratio of (crop size) / (min side length) """ # output size: resize crop size -> output size - def __init__(self, output_size, ratio): - super().__init__(output_size) + def __init__(self, output_size, ratio, is_rgb=True): + super().__init__(output_size, is_rgb=is_rgb) self.scale = None self.ratio = ratio @@ -318,7 +323,10 @@ def metadata_type(self) -> np.dtype: ]) def get_decoder_class(self) -> Type[Operation]: - return partial(SimpleRGBImageDecoder, is_rgb=self.is_rgb) + if self.is_rgb: + return SimpleRGBImageDecoder # TODO + else: + return SimpleGrayscaleImageDecoder # TODO @staticmethod def from_binary(binary: ARG_TYPE) -> Field: @@ -337,8 +345,9 @@ def encode(self, destination, image, malloc): if image.dtype != np.uint8: raise ValueError("Image type has to be uint8") - is_ok_rgb = image.shape[2] == 3 and self.is_rgb - is_ok_grayscale = image.shape[2] == 1 and not self.is_rgb + shape = image.shape + is_ok_grayscale = len(shape) == 2 and not self.is_rgb + is_ok_rgb = len(shape) == 3 and shape[2] == 3 and self.is_rgb if not (is_ok_rgb or is_ok_grayscale): raise ValueError(f"Invalid shape for rgb image: {image.shape}") diff --git a/ffcv/libffcv.py b/ffcv/libffcv.py index 52219f3c..2527780e 100644 --- a/ffcv/libffcv.py +++ b/ffcv/libffcv.py @@ -15,32 +15,33 @@ def read(fileno:int, destination:np.ndarray, offset:int): ctypes_resize = lib.resize -ctypes_resize.argtypes = 11 * [c_int64] +ctypes_resize.argtypes = (11 * [c_int64]) + [c_bool] -def resize_crop(source, start_row, end_row, start_col, end_col, destination): +def resize_crop(source, start_row, end_row, start_col, end_col, destination, + is_rgb): ctypes_resize(0, source.ctypes.data, source.shape[0], source.shape[1], start_row, end_row, start_col, end_col, destination.ctypes.data, - destination.shape[0], destination.shape[1]) + destination.shape[0], destination.shape[1], is_rgb) # Extract and define the interface of imdeocde ctypes_imdecode = lib.imdecode ctypes_imdecode.argtypes = [ c_void_p, c_uint64, c_uint32, c_uint32, c_void_p, c_uint32, c_uint32, - c_uint32, c_uint32, c_uint32, c_uint32, c_bool, c_bool + c_uint32, c_uint32, c_uint32, c_uint32, c_bool, c_bool, c_bool ] def imdecode(source: np.ndarray, dst: np.ndarray, source_height: int, source_width: int, crop_height=None, crop_width=None, offset_x=0, offset_y=0, scale_factor_num=1, scale_factor_denom=1, - enable_crop=False, do_flip=False): + enable_crop=False, do_flip=False, is_rgb=True): return ctypes_imdecode(source.ctypes.data, source.size, source_height, source_width, dst.ctypes.data, crop_height, crop_width, offset_x, offset_y, scale_factor_num, scale_factor_denom, - enable_crop, do_flip) + enable_crop, do_flip, is_rgb) ctypes_memcopy = lib.my_memcpy diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index 21cdd104..fac9504a 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -170,6 +170,7 @@ def __init__(self, # We check if the user disabled this field if operations is None: continue + if not isinstance(operations[0], DecoderClass): msg = "The first operation of the pipeline for " msg += f"'{field_name}' has to be a subclass of " diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index 1e710499..87db570b 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -49,15 +49,14 @@ extern "C" { } int imdecode(unsigned char *input_buffer, __uint64_t input_size, - __uint32_t source_height, __uint32_t source_width, - - unsigned char *output_buffer, - __uint32_t crop_height, __uint32_t crop_width, - __uint32_t offset_x, __uint32_t offset_y, - __uint32_t scale_num, __uint32_t scale_denom, - bool enable_crop, - bool hflip, - bool is_rgb) + __uint32_t source_height, __uint32_t source_width, + unsigned char *output_buffer, + __uint32_t crop_height, __uint32_t crop_width, + __uint32_t offset_x, __uint32_t offset_y, + __uint32_t scale_num, __uint32_t scale_denom, + bool enable_crop, + bool hflip, + bool is_rgb) { pthread_once(&key_once, make_keys); From 64bd2b9e9c9fc3779ba13ef958ae479ecfac9c7f Mon Sep 17 00:00:00 2001 From: Logan Engstrom Date: Thu, 17 Feb 2022 17:57:34 -0500 Subject: [PATCH 3/9] "it works" checkpoint --- ffcv/transforms/normalize.py | 49 +++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/ffcv/transforms/normalize.py b/ffcv/transforms/normalize.py index a04f45e7..2a385451 100644 --- a/ffcv/transforms/normalize.py +++ b/ffcv/transforms/normalize.py @@ -3,6 +3,7 @@ """ from collections.abc import Sequence from typing import Tuple +from xmlrpc.client import boolean import numpy as np import torch as ch @@ -37,13 +38,16 @@ class NormalizeImage(Operation): """ def __init__(self, mean: np.ndarray, std: np.ndarray, - type: np.dtype): + type: np.dtype, is_rgb=True): super().__init__() + self.is_rgb = is_rgb + table = (np.arange(256)[:, None] - mean[None, :]) / std[None, :] self.original_dtype = type table = table.astype(type) if type == np.float16: type = np.int16 + self.dtype = type table = table.view(type) self.lookup_table = table @@ -56,13 +60,16 @@ def generate_code(self) -> Callable: return self.generate_code_gpu() def generate_code_gpu(self) -> Callable: - # We only import cupy if it's truly needed import cupy as cp import pytorch_pfn_extras as ppe tn = np.zeros((), dtype=self.dtype).dtype.name - kernel = cp.ElementwiseKernel(f'uint8 input, raw {tn} table', f'{tn} output', 'output = table[input * 3 + i % 3];') + if self.is_rgb: + kernel = cp.ElementwiseKernel(f'uint8 input, raw {tn} table', f'{tn} output', 'output = table[input * 3 + i % 3];') + else: + kernel = cp.ElementwiseKernel(f'uint8 input, raw {tn} table', f'{tn} output', 'output = table[input];') + final_type = ch_dtype_from_numpy(self.original_dtype) s = self def normalize_convert(images, result): @@ -91,18 +98,30 @@ def generate_code_cpu(self) -> Callable: table = self.lookup_table.view(dtype=self.dtype) my_range = Compiler.get_iterator() - def normalize_convert(images, result, indices): - result_flat = result.reshape(result.shape[0], -1, 3) - num_pixels = result_flat.shape[1] - for i in my_range(len(indices)): - image = images[i].reshape(num_pixels, 3) - for px in range(num_pixels): - # Just in case llvm forgets to unroll this one - result_flat[i, px, 0] = table[image[px, 0], 0] - result_flat[i, px, 1] = table[image[px, 1], 1] - result_flat[i, px, 2] = table[image[px, 2], 2] - - return result + if self.is_rgb: + def normalize_convert(images, result, indices): + result_flat = result.reshape(result.shape[0], -1, 3) + num_pixels = result_flat.shape[1] + for i in my_range(len(indices)): + image = images[i].reshape(num_pixels, 3) + for px in range(num_pixels): + # Just in case llvm forgets to unroll this one + result_flat[i, px, 0] = table[image[px, 0], 0] + result_flat[i, px, 1] = table[image[px, 1], 1] + result_flat[i, px, 2] = table[image[px, 2], 2] + + return result + else: + def normalize_convert(images, result, indices): + result_flat = result.reshape(result.shape[0], -1, 1) + num_pixels = result_flat.shape[1] + for i in my_range(len(indices)): + image = images[i].reshape(num_pixels, 1) + for px in range(num_pixels): + # Just in case llvm forgets to unroll this one + result_flat[i, px, 0] = table[image[px, 0], 0] + + return result normalize_convert.is_parallel = True normalize_convert.with_indices = True From c01ff448b0110751aa61e68811fea30ec2365776 Mon Sep 17 00:00:00 2001 From: Florian Bordes Date: Wed, 23 Mar 2022 06:10:30 -0700 Subject: [PATCH 4/9] Add ColorJitter transformation --- ffcv/transforms/__init__.py | 3 +- ffcv/transforms/colorjitter.py | 140 +++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 ffcv/transforms/colorjitter.py diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py index bc8fa321..56d61811 100644 --- a/ffcv/transforms/__init__.py +++ b/ffcv/transforms/__init__.py @@ -9,6 +9,7 @@ from .translate import RandomTranslate from .mixup import ImageMixup, LabelMixup, MixupToOneHot from .module import ModuleWrapper +from .colorjitter import ColorJitter __all__ = ['ToTensor', 'ToDevice', 'ToTorchImage', 'NormalizeImage', @@ -16,4 +17,4 @@ 'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate', 'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot', 'Poison', 'ReplaceLabel', - 'ModuleWrapper'] \ No newline at end of file + 'ModuleWrapper', 'ColorJitter'] \ No newline at end of file diff --git a/ffcv/transforms/colorjitter.py b/ffcv/transforms/colorjitter.py new file mode 100644 index 00000000..b407e5db --- /dev/null +++ b/ffcv/transforms/colorjitter.py @@ -0,0 +1,140 @@ +""" +ColorJitter +Code for Brightness, Contrast and Saturation adapted from +https://github.com/pytorch/vision/blob/main/torchvision/transforms/functional_tensor.py +Code for Hue adapted from: +https://sanje2v.wordpress.com/2021/01/11/accelerating-data-transforms/ +https://stackoverflow.com/questions/8507885 +""" +import numpy as np +from numpy.random import rand +from typing import Callable, Optional, Tuple +from dataclasses import replace +from ..pipeline.allocation_query import AllocationQuery +from ..pipeline.operation import Operation +from ..pipeline.state import State +from ..pipeline.compiler import Compiler +import numbers +import numba as nb + +class ColorJitter(Operation): + """Add ColorJitter with probability jitter_prob. + Operates on raw arrays (not tensors). + + Parameters + ---------- + jitter_prob : float, The probability with which to apply ColorJitter. + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__(self, jitter_prob, brightness=0, contrast=0, saturation=0, hue=0): + super().__init__() + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.jitter_prob = jitter_prob + assert self.jitter_prob >= 0 and self.jitter_prob <= 1 + + def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError(f"If {name} is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}") + else: + raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + jitter_prob = self.jitter_prob + apply_brightness = (self.brightness is not None) + if apply_brightness: + brightness_min, brightness_max = self.brightness + apply_contrast = (self.contrast is not None) + if apply_contrast: + contrast_min, contrast_max = self.contrast + apply_saturation = (self.saturation is not None) + if apply_saturation: + saturation_min, saturation_max = self.saturation + apply_hue = (self.hue is not None) + if apply_hue: + hue_min, hue_max = self.hue + + def color_jitter(images, dst): + should_jitter = rand(images.shape[0]) < jitter_prob + for i in my_range(images.shape[0]): + if should_jitter[i]: + img = images[i] + # Brightness + if apply_brightness: + # ratio_brightness = np.random.uniform(max(0, 1 - brightness), 1 + brightness) + ratio_brightness = np.random.uniform(brightness_min, brightness_max) + img = ratio_brightness * img + (1.0 - ratio_brightness) * img * 0 + img = np.clip(img, 0, 255) + + # Contrast + if apply_contrast: + # ratio_contrast = np.random.uniform(max(0, 1 - contrast), 1 + contrast) + ratio_contrast = np.random.uniform(contrast_min, contrast_max) + gray = 0.2989 * img[:,:,0:1] + 0.5870 * img[:,:,1:2] + 0.1140 * img[:,:,2:3] + img = ratio_contrast * img + (1.0 - ratio_contrast) * gray.mean() + img = np.clip(img, 0, 255) + + # Saturation + if apply_saturation: + ratio_saturation = np.random.uniform(saturation_min, saturation_max) + dst[i] = 0.2989 * img[:,:,0:1] + 0.5870 * img[:,:,1:2] + 0.1140 * img[:,:,2:3] + img = ratio_saturation * img + (1.0 - ratio_saturation) * dst[i] + img = np.clip(img, 0, 255) + + # Hue + if apply_hue: + img = img / 255. + hue_factor = np.random.uniform(hue_min, hue_max).item() + hue_factor_radians = hue_factor * 2.0 * np.pi + cosA = np.cos(hue_factor_radians) + sinA = np.sin(hue_factor_radians) + hue_rotation_matrix =\ + [[cosA + (1.0 - cosA) / 3.0, 1./3. * (1.0 - cosA) - np.sqrt(1./3.) * sinA, 1./3. * (1.0 - cosA) + np.sqrt(1./3.) * sinA], + [1./3. * (1.0 - cosA) + np.sqrt(1./3.) * sinA, cosA + 1./3.*(1.0 - cosA), 1./3. * (1.0 - cosA) - np.sqrt(1./3.) * sinA], + [1./3. * (1.0 - cosA) - np.sqrt(1./3.) * sinA, 1./3. * (1.0 - cosA) + np.sqrt(1./3.) * sinA, cosA + 1./3. * (1.0 - cosA)]] + hue_rotation_matrix = np.array(hue_rotation_matrix, dtype=img.dtype) + for row in nb.prange(img.shape[0]): + for col in nb.prange(img.shape[1]): + img[row, col, 0] = img[row, col, 0] * hue_rotation_matrix[0, 0] + img[row, col, 1] * hue_rotation_matrix[0, 1] + img[row, col, 2] * hue_rotation_matrix[0, 2] + img[row, col, 1] = img[row, col, 0] * hue_rotation_matrix[1, 0] + img[row, col, 1] * hue_rotation_matrix[1, 1] + img[row, col, 2] * hue_rotation_matrix[1, 2] + img[row, col, 2] = img[row, col, 0] * hue_rotation_matrix[2, 0] + img[row, col, 1] * hue_rotation_matrix[2, 1] + img[row, col, 2] * hue_rotation_matrix[2, 2] + img = np.asarray(np.clip(img, 0, 1)*255.,dtype=np.uint8) + dst[i] = img + else: + dst[i] = images[i] + return dst + + color_jitter.is_parallel = True + return color_jitter + + def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + return (replace(previous_state, jit_mode=True), AllocationQuery(shape=previous_state.shape, dtype=previous_state.dtype)) + From 0c9aa4175d273fcef976c7d6fbd9e9ce79541700 Mon Sep 17 00:00:00 2001 From: Florian Bordes Date: Wed, 23 Mar 2022 09:31:31 -0700 Subject: [PATCH 5/9] Correst small hue inplace issue --- ffcv/transforms/colorjitter.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ffcv/transforms/colorjitter.py b/ffcv/transforms/colorjitter.py index b407e5db..d5dcace5 100644 --- a/ffcv/transforms/colorjitter.py +++ b/ffcv/transforms/colorjitter.py @@ -16,6 +16,7 @@ from ..pipeline.compiler import Compiler import numbers import numba as nb +from math import sqrt,cos,sin,radians class ColorJitter(Operation): """Add ColorJitter with probability jitter_prob. @@ -111,8 +112,8 @@ def color_jitter(images, dst): # Hue if apply_hue: - img = img / 255. - hue_factor = np.random.uniform(hue_min, hue_max).item() + img = img / 255.0 + hue_factor = np.random.uniform(hue_min, hue_max) hue_factor_radians = hue_factor * 2.0 * np.pi cosA = np.cos(hue_factor_radians) sinA = np.sin(hue_factor_radians) @@ -123,10 +124,11 @@ def color_jitter(images, dst): hue_rotation_matrix = np.array(hue_rotation_matrix, dtype=img.dtype) for row in nb.prange(img.shape[0]): for col in nb.prange(img.shape[1]): - img[row, col, 0] = img[row, col, 0] * hue_rotation_matrix[0, 0] + img[row, col, 1] * hue_rotation_matrix[0, 1] + img[row, col, 2] * hue_rotation_matrix[0, 2] - img[row, col, 1] = img[row, col, 0] * hue_rotation_matrix[1, 0] + img[row, col, 1] * hue_rotation_matrix[1, 1] + img[row, col, 2] * hue_rotation_matrix[1, 2] - img[row, col, 2] = img[row, col, 0] * hue_rotation_matrix[2, 0] + img[row, col, 1] * hue_rotation_matrix[2, 1] + img[row, col, 2] * hue_rotation_matrix[2, 2] - img = np.asarray(np.clip(img, 0, 1)*255.,dtype=np.uint8) + r, g, b = img[row, col, :] + img[row, col, 0] = r * hue_rotation_matrix[0, 0] + g * hue_rotation_matrix[0, 1] + b * hue_rotation_matrix[0, 2] + img[row, col, 1] = r * hue_rotation_matrix[1, 0] + g * hue_rotation_matrix[1, 1] + b * hue_rotation_matrix[1, 2] + img[row, col, 2] = r * hue_rotation_matrix[2, 0] + g * hue_rotation_matrix[2, 1] + b * hue_rotation_matrix[2, 2] + img = np.asarray(np.clip(img * 255., 0, 255), dtype=np.uint8) dst[i] = img else: dst[i] = images[i] From 956c04282404932a750f8daf6cc99a8e66c2e97d Mon Sep 17 00:00:00 2001 From: Florian Bordes Date: Wed, 23 Mar 2022 09:56:06 -0700 Subject: [PATCH 6/9] Remove comment --- ffcv/transforms/colorjitter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ffcv/transforms/colorjitter.py b/ffcv/transforms/colorjitter.py index d5dcace5..548fd7e4 100644 --- a/ffcv/transforms/colorjitter.py +++ b/ffcv/transforms/colorjitter.py @@ -90,14 +90,12 @@ def color_jitter(images, dst): img = images[i] # Brightness if apply_brightness: - # ratio_brightness = np.random.uniform(max(0, 1 - brightness), 1 + brightness) ratio_brightness = np.random.uniform(brightness_min, brightness_max) img = ratio_brightness * img + (1.0 - ratio_brightness) * img * 0 img = np.clip(img, 0, 255) # Contrast if apply_contrast: - # ratio_contrast = np.random.uniform(max(0, 1 - contrast), 1 + contrast) ratio_contrast = np.random.uniform(contrast_min, contrast_max) gray = 0.2989 * img[:,:,0:1] + 0.5870 * img[:,:,1:2] + 0.1140 * img[:,:,2:3] img = ratio_contrast * img + (1.0 - ratio_contrast) * gray.mean() From 6f254366d30ae02274c83898a5b3a54aa74bb28f Mon Sep 17 00:00:00 2001 From: Florian Bordes Date: Wed, 23 Mar 2022 11:45:23 -0700 Subject: [PATCH 7/9] Remove useless import --- ffcv/transforms/colorjitter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ffcv/transforms/colorjitter.py b/ffcv/transforms/colorjitter.py index 548fd7e4..686378a9 100644 --- a/ffcv/transforms/colorjitter.py +++ b/ffcv/transforms/colorjitter.py @@ -16,7 +16,6 @@ from ..pipeline.compiler import Compiler import numbers import numba as nb -from math import sqrt,cos,sin,radians class ColorJitter(Operation): """Add ColorJitter with probability jitter_prob. From ace60c9fba509ca04b137f0f9cc64459ee0e2e11 Mon Sep 17 00:00:00 2001 From: Logan Engstrom Date: Thu, 16 Jun 2022 00:15:08 -0400 Subject: [PATCH 8/9] added --- ffcv/memory_managers/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ffcv/memory_managers/base.py b/ffcv/memory_managers/base.py index 525833a0..265e204a 100644 --- a/ffcv/memory_managers/base.py +++ b/ffcv/memory_managers/base.py @@ -44,6 +44,7 @@ def __init__(self, reader:Reader): self.ptrs = self.ptrs[order] self.sizes = self.sizes[order] + print('initi memory manager', len(self.ptrs), len(self.sizes)) self.ptr_to_size = dict(zip(self.ptrs, self.sizes)) # We extract the page number by shifting the address corresponding From 8098a6b7fe8e3a8ba2355984516b1ddb37509fb8 Mon Sep 17 00:00:00 2001 From: Logan Engstrom Date: Mon, 29 Aug 2022 15:28:08 -0400 Subject: [PATCH 9/9] aded --- examples/cifar/train_cifar.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/cifar/train_cifar.py b/examples/cifar/train_cifar.py index 3160e706..5db405cc 100644 --- a/examples/cifar/train_cifar.py +++ b/examples/cifar/train_cifar.py @@ -94,7 +94,7 @@ def make_dataloaders(train_dataset=None, val_dataset=None, batch_size=None, num_ Convert(ch.float16), torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD), ]) - + ordering = OrderOption.RANDOM if name == 'train' else OrderOption.SEQUENTIAL loaders[name] = Loader(paths[name], batch_size=batch_size, num_workers=num_workers, @@ -145,6 +145,7 @@ def construct_model(): model = model.to(memory_format=ch.channels_last).cuda() return model + @param('training.lr') @param('training.epochs') @param('training.momentum')