Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Greyscale image support #176

Open
wants to merge 12 commits into
base: v1.0.0
Choose a base branch
from
3 changes: 2 additions & 1 deletion examples/cifar/train_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def make_dataloaders(train_dataset=None, val_dataset=None, batch_size=None, num_
Convert(ch.float16),
torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

ordering = OrderOption.RANDOM if name == 'train' else OrderOption.SEQUENTIAL

loaders[name] = Loader(paths[name], batch_size=batch_size, num_workers=num_workers,
Expand Down Expand Up @@ -145,6 +145,7 @@ def construct_model():
model = model.to(memory_format=ch.channels_last).cuda()
return model


@param('training.lr')
@param('training.epochs')
@param('training.momentum')
Expand Down
71 changes: 46 additions & 25 deletions ffcv/fields/rgb_image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from functools import partial
from dataclasses import replace
from typing import Optional, Callable, TYPE_CHECKING, Tuple, Type

Expand All @@ -23,8 +24,11 @@
IMAGE_MODES['raw'] = 1


def encode_jpeg(numpy_image, quality):
numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
def encode_jpeg(numpy_image, quality, is_rgb):
if is_rgb:
numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)

# TODO: the rest of this function still assumes an RGB image
success, result = cv2.imencode('.jpg', numpy_image,
[int(cv2.IMWRITE_JPEG_QUALITY), quality])

Expand Down Expand Up @@ -86,7 +90,9 @@ class SimpleRGBImageDecoder(Operation):

It only supports datasets with constant image resolution and will simply read (potentially decompress) and pass the images as is.
"""
def __init__(self):
def __init__(self, is_rgb):
self.is_rgb = is_rgb
self.channels = 3 if is_rgb else 1
super().__init__()

def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]:
Expand All @@ -102,7 +108,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
instead."""
raise TypeError(msg)

biggest_shape = (max_height, max_width, 3)
biggest_shape = (max_height, max_width, self.channels)
my_dtype = np.dtype('<u1')

return (
Expand All @@ -119,6 +125,7 @@ def generate_code(self) -> Callable:
raw = IMAGE_MODES['raw']
my_range = Compiler.get_iterator()
my_memcpy = Compiler.compile(memcpy)
is_rgb = self.is_rgb

def decode(batch_indices, destination, metadata, storage_state):
for dst_ix in my_range(len(batch_indices)):
Expand All @@ -128,8 +135,9 @@ def decode(batch_indices, destination, metadata, storage_state):
height, width = field['height'], field['width']

if field['mode'] == jpg:
imdecode_c(image_data, destination[dst_ix],
height, width, height, width, 0, 0, 1, 1, False, False)
imdecode_c(image_data, destination[dst_ix], height, width,
height, width, 0, 0, 1, 1, False, False,
is_rgb)
else:
my_memcpy(image_data, destination[dst_ix])

Expand All @@ -138,14 +146,17 @@ def decode(batch_indices, destination, metadata, storage_state):
decode.is_parallel = True
return decode

class SimpleGrayscaleImageDecoder(SimpleRGBImageDecoder):
    """Variant of :class:`SimpleRGBImageDecoder` built with ``is_rgb=False``.

    The constructor takes no arguments; it only forwards the fixed
    ``is_rgb=False`` flag to the parent initializer.
    """

    def __init__(self):
        super().__init__(is_rgb=False)

class ResizedCropRGBImageDecoder(SimpleRGBImageDecoder, metaclass=ABCMeta):
"""Abstract decoder for :class:`~ffcv.fields.RGBImageField` that performs a crop and a resize operation.

It supports both variable and constant resolution datasets.
"""
def __init__(self, output_size):
super().__init__()
def __init__(self, output_size, is_rgb):
super().__init__(is_rgb)
self.output_size = output_size

def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]:
Expand All @@ -154,19 +165,19 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
# We convert to uint64 to avoid overflows
self.max_width = np.uint64(widths.max())
self.max_height = np.uint64(heights.max())
output_shape = (self.output_size[0], self.output_size[1], 3)
output_shape = (self.output_size[0], self.output_size[1], self.channels)
my_dtype = np.dtype('<u1')

channels = np.uint64(self.channels)
return (
replace(previous_state, jit_mode=True,
shape=output_shape, dtype=my_dtype),
(AllocationQuery(output_shape, my_dtype),
AllocationQuery((self.max_height * self.max_width * np.uint64(3),), my_dtype),
AllocationQuery((self.max_height * self.max_width * channels,), my_dtype),
)
)

def generate_code(self) -> Callable:

jpg = IMAGE_MODES['jpg']

mem_read = self.memory_read
Expand All @@ -177,6 +188,8 @@ def generate_code(self) -> Callable:

scale = self.scale
ratio = self.ratio
is_rgb = self.is_rgb
channels = self.channels
if isinstance(scale, tuple):
scale = np.array(scale)
if isinstance(ratio, tuple):
Expand All @@ -193,21 +206,21 @@ def decode(batch_indices, my_storage, metadata, storage_state):

if field['mode'] == jpg:
temp_buffer = temp_storage[dst_ix]
imdecode_c(image_data, temp_buffer,
height, width, height, width, 0, 0, 1, 1, False, False)
selected_size = 3 * height * width
imdecode_c(image_data, temp_buffer, height, width, height,
width, 0, 0, 1, 1, False, False, is_rgb)
selected_size = channels * height * width
temp_buffer = temp_buffer.reshape(-1)[:selected_size]
temp_buffer = temp_buffer.reshape(height, width, 3)

temp_buffer = temp_buffer.reshape(height, width, channels)
else:
temp_buffer = image_data.reshape(height, width, 3)
temp_buffer = image_data.reshape(height, width, channels)

i, j, h, w = get_crop_c(height, width, scale, ratio)

resize_crop_c(temp_buffer, i, i + h, j, j + w,
destination[dst_ix])
destination[dst_ix], is_rgb)

return destination[:len(batch_indices)]

decode.is_parallel = True
return decode

Expand All @@ -231,8 +244,9 @@ class RandomResizedCropRGBImageDecoder(ResizedCropRGBImageDecoder):
ratio : Tuple[float]
The range of potential aspect ratios that can be randomly sampled
"""
def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3)):
super().__init__(output_size)
def __init__(self, output_size, scale=(0.08, 1.0), ratio=(0.75, 4/3),
is_rgb=True):
super().__init__(output_size, is_rgb=is_rgb)
self.scale = scale
self.ratio = ratio
self.output_size = output_size
Expand All @@ -255,8 +269,8 @@ class CenterCropRGBImageDecoder(ResizedCropRGBImageDecoder):
ratio of (crop size) / (min side length)
"""
# output size: resize crop size -> output size
def __init__(self, output_size, ratio):
super().__init__(output_size)
def __init__(self, output_size, ratio, is_rgb=True):
super().__init__(output_size, is_rgb=is_rgb)
self.scale = None
self.ratio = ratio

Expand Down Expand Up @@ -291,12 +305,13 @@ class RGBImageField(Field):
"""
def __init__(self, write_mode='raw', max_resolution: int = None,
smart_threshold: int = None, jpeg_quality: int = 90,
compress_probability: float = 0.5) -> None:
compress_probability: float = 0.5, is_rgb: bool = True) -> None:
self.write_mode = write_mode
self.smart_threshold = smart_threshold
self.max_resolution = max_resolution
self.jpeg_quality = int(jpeg_quality)
self.proportion = compress_probability
self.is_rgb = is_rgb

@property
def metadata_type(self) -> np.dtype:
Expand All @@ -308,7 +323,10 @@ def metadata_type(self) -> np.dtype:
])

def get_decoder_class(self) -> Type[Operation]:
return SimpleRGBImageDecoder
if self.is_rgb:
return SimpleRGBImageDecoder # TODO
else:
return SimpleGrayscaleImageDecoder # TODO

@staticmethod
def from_binary(binary: ARG_TYPE) -> Field:
Expand All @@ -327,7 +345,10 @@ def encode(self, destination, image, malloc):
if image.dtype != np.uint8:
raise ValueError("Image type has to be uint8")

if image.shape[2] != 3:
shape = image.shape
is_ok_grayscale = len(shape) == 2 and not self.is_rgb
is_ok_rgb = len(shape) == 3 and shape[2] == 3 and self.is_rgb
if not (is_ok_rgb or is_ok_grayscale):
raise ValueError(f"Invalid shape for rgb image: {image.shape}")

assert image.dtype == np.uint8
Expand Down
13 changes: 7 additions & 6 deletions ffcv/libffcv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,33 @@ def read(fileno:int, destination:np.ndarray, offset:int):


ctypes_resize = lib.resize
ctypes_resize.argtypes = 11 * [c_int64]
ctypes_resize.argtypes = (11 * [c_int64]) + [c_bool]

def resize_crop(source, start_row, end_row, start_col, end_col, destination):
def resize_crop(source, start_row, end_row, start_col, end_col, destination,
is_rgb):
ctypes_resize(0,
source.ctypes.data,
source.shape[0], source.shape[1],
start_row, end_row, start_col, end_col,
destination.ctypes.data,
destination.shape[0], destination.shape[1])
destination.shape[0], destination.shape[1], is_rgb)

# Extract and define the interface of imdecode
ctypes_imdecode = lib.imdecode
ctypes_imdecode.argtypes = [
c_void_p, c_uint64, c_uint32, c_uint32, c_void_p, c_uint32, c_uint32,
c_uint32, c_uint32, c_uint32, c_uint32, c_bool, c_bool
c_uint32, c_uint32, c_uint32, c_uint32, c_bool, c_bool, c_bool
]

def imdecode(source: np.ndarray, dst: np.ndarray,
source_height: int, source_width: int,
crop_height=None, crop_width=None,
offset_x=0, offset_y=0, scale_factor_num=1, scale_factor_denom=1,
enable_crop=False, do_flip=False):
enable_crop=False, do_flip=False, is_rgb=True):
return ctypes_imdecode(source.ctypes.data, source.size,
source_height, source_width, dst.ctypes.data,
crop_height, crop_width, offset_x, offset_y, scale_factor_num, scale_factor_denom,
enable_crop, do_flip)
enable_crop, do_flip, is_rgb)


ctypes_memcopy = lib.my_memcpy
Expand Down
1 change: 1 addition & 0 deletions ffcv/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __init__(self,
# We check if the user disabled this field
if operations is None:
continue

if not isinstance(operations[0], DecoderClass):
msg = "The first operation of the pipeline for "
msg += f"'{field_name}' has to be a subclass of "
Expand Down
1 change: 1 addition & 0 deletions ffcv/memory_managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, reader:Reader):
self.ptrs = self.ptrs[order]
self.sizes = self.sizes[order]

print('initi memory manager', len(self.ptrs), len(self.sizes))
self.ptr_to_size = dict(zip(self.ptrs, self.sizes))

# We extract the page number by shifting the address corresponding
Expand Down
3 changes: 2 additions & 1 deletion ffcv/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from .translate import RandomTranslate
from .mixup import ImageMixup, LabelMixup, MixupToOneHot
from .module import ModuleWrapper
from .colorjitter import ColorJitter

__all__ = ['ToTensor', 'ToDevice',
'ToTorchImage', 'NormalizeImage',
'Convert', 'Squeeze', 'View',
'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate',
'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot',
'Poison', 'ReplaceLabel',
'ModuleWrapper']
'ModuleWrapper', 'ColorJitter']
Loading