Skip to content

Commit

Permalink
Resample RGB images in C++
Browse files Browse the repository at this point in the history
Agg already has RGB resampling with output to RGBA builtin, and we just
need to correctly wire up the corresponding templates. With this RGB
resampling mode, we save the extra copy from RGB to RGBA in NumPy land
that was required for the previous RGBA resampling.
  • Loading branch information
QuLogic committed Jan 10, 2025
1 parent 965efca commit d01ef30
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 58 deletions.
30 changes: 9 additions & 21 deletions lib/matplotlib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ def flush_images():
flush_images()


def _resample(
image_obj, data, out_shape, transform, *, resample=None, alpha=1):
def _resample(image_obj, data, out_shape, transform, *, resample=None, alpha=1):
"""
Convenience wrapper around `._image.resample` to resample *data* to
*out_shape* (with a third dimension if *data* is RGBA) that takes care of
Expand Down Expand Up @@ -204,7 +203,10 @@ def _resample(
interpolation = 'nearest'
else:
interpolation = 'hanning'
out = np.zeros(out_shape + data.shape[2:], data.dtype) # 2D->2D, 3D->3D.
if len(data.shape) == 3:
# Always output RGBA.
out_shape += (4, )
out = np.zeros(out_shape, data.dtype)
if resample is None:
resample = image_obj.get_resample()
_image.resample(data, out, transform,
Expand All @@ -216,20 +218,6 @@ def _resample(
return out


def _rgb_to_rgba(A):
"""
Convert an RGB image to RGBA, as required by the image resample C++
extension.
"""
rgba = np.zeros((A.shape[0], A.shape[1], 4), dtype=A.dtype)
rgba[:, :, :3] = A
if rgba.dtype == np.uint8:
rgba[:, :, 3] = 255
else:
rgba[:, :, 3] = 1.0
return rgba


class _ImageBase(mcolorizer.ColorizingArtist):
"""
Base class for images.
Expand Down Expand Up @@ -508,10 +496,10 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
# alpha channel below.
output_alpha = (255 * alpha) if A.dtype == np.uint8 else alpha
else:
output_alpha = _resample( # resample alpha channel
self, A[..., 3], out_shape, t, alpha=alpha)
output = _resample( # resample rgb channels
self, _rgb_to_rgba(A[..., :3]), out_shape, t, alpha=alpha)
# resample alpha channel
output_alpha = _resample(self, A[..., 3], out_shape, t, alpha=alpha)
# resample rgb channels
output = _resample(self, A[..., :3], out_shape, t, alpha=alpha)
output[..., 3] = output_alpha # recombine rgb and alpha

# output is now either a 2D array of normed (int or float) data
Expand Down
4 changes: 2 additions & 2 deletions lib/matplotlib/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,8 +1577,8 @@ def test__resample_valid_output():
resample(np.zeros((9, 9)), np.zeros((9, 9, 4)))
with pytest.raises(ValueError, match="different dimensionalities"):
resample(np.zeros((9, 9, 4)), np.zeros((9, 9)))
with pytest.raises(ValueError, match="3D input array must be RGBA"):
resample(np.zeros((9, 9, 3)), np.zeros((9, 9, 4)))
with pytest.raises(ValueError, match="3D input array must be RGB"):
resample(np.zeros((9, 9, 2)), np.zeros((9, 9, 4)))
with pytest.raises(ValueError, match="3D output array must be RGBA"):
resample(np.zeros((9, 9, 4)), np.zeros((9, 9, 3)))
with pytest.raises(ValueError, match="mismatched types"):
Expand Down
74 changes: 57 additions & 17 deletions src/_image_resample.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "agg_image_accessors.h"
#include "agg_path_storage.h"
#include "agg_pixfmt_gray.h"
#include "agg_pixfmt_rgb.h"
#include "agg_pixfmt_rgba.h"
#include "agg_renderer_base.h"
#include "agg_renderer_scanline.h"
Expand All @@ -16,6 +17,7 @@
#include "agg_span_allocator.h"
#include "agg_span_converter.h"
#include "agg_span_image_filter_gray.h"
#include "agg_span_image_filter_rgb.h"
#include "agg_span_image_filter_rgba.h"
#include "agg_span_interpolator_adaptor.h"
#include "agg_span_interpolator_linear.h"
Expand Down Expand Up @@ -496,16 +498,38 @@ typedef enum {
} interpolation_e;


// T is rgba if and only if it has an T::r field.
// T is rgb(a) if and only if it has an T::r field.
template<typename T, typename = void> struct is_grayscale : std::true_type {};
template<typename T> struct is_grayscale<T, std::void_t<decltype(T::r)>> : std::false_type {};
template<typename T> constexpr bool is_grayscale_v = is_grayscale<T>::value;


template<typename color_type>
template<typename color_type, bool input_has_alpha>
struct type_mapping
{
using blender_type = std::conditional_t<
using input_blender_type = std::conditional_t<
is_grayscale_v<color_type>,
agg::blender_gray<color_type>,
std::conditional_t<
input_has_alpha,
std::conditional_t<
std::is_same_v<color_type, agg::rgba8>,
fixed_blender_rgba_plain<color_type, agg::order_rgba>,
agg::blender_rgba_plain<color_type, agg::order_rgba>
>,
agg::blender_rgb<color_type, agg::order_rgb>
>
>;
using input_pixfmt_type = std::conditional_t<
is_grayscale_v<color_type>,
agg::pixfmt_alpha_blend_gray<input_blender_type, agg::rendering_buffer>,
std::conditional_t<
input_has_alpha,
agg::pixfmt_alpha_blend_rgba<input_blender_type, agg::rendering_buffer>,
agg::pixfmt_alpha_blend_rgb<input_blender_type, agg::rendering_buffer, 3>
>
>;
using output_blender_type = std::conditional_t<
is_grayscale_v<color_type>,
agg::blender_gray<color_type>,
std::conditional_t<
Expand All @@ -514,25 +538,37 @@ struct type_mapping
agg::blender_rgba_plain<color_type, agg::order_rgba>
>
>;
using pixfmt_type = std::conditional_t<
using output_pixfmt_type = std::conditional_t<
is_grayscale_v<color_type>,
agg::pixfmt_alpha_blend_gray<blender_type, agg::rendering_buffer>,
agg::pixfmt_alpha_blend_rgba<blender_type, agg::rendering_buffer>
agg::pixfmt_alpha_blend_gray<output_blender_type, agg::rendering_buffer>,
agg::pixfmt_alpha_blend_rgba<output_blender_type, agg::rendering_buffer>
>;
template<typename A> using span_gen_affine_type = std::conditional_t<
is_grayscale_v<color_type>,
agg::span_image_resample_gray_affine<A>,
agg::span_image_resample_rgba_affine<A>
std::conditional_t<
input_has_alpha,
agg::span_image_resample_rgba_affine<A>,
agg::span_image_resample_rgb_affine<A>
>
>;
template<typename A, typename B> using span_gen_filter_type = std::conditional_t<
is_grayscale_v<color_type>,
agg::span_image_filter_gray<A, B>,
agg::span_image_filter_rgba<A, B>
std::conditional_t<
input_has_alpha,
agg::span_image_filter_rgba<A, B>,
agg::span_image_filter_rgb<A, B>
>
>;
template<typename A, typename B> using span_gen_nn_type = std::conditional_t<
is_grayscale_v<color_type>,
agg::span_image_filter_gray_nn<A, B>,
agg::span_image_filter_rgba_nn<A, B>
std::conditional_t<
input_has_alpha,
agg::span_image_filter_rgba_nn<A, B>,
agg::span_image_filter_rgb_nn<A, B>
>
>;
};

Expand Down Expand Up @@ -686,16 +722,16 @@ static void get_filter(const resample_params_t &params,
}


template<typename color_type>
template<typename color_type, bool input_has_alpha = true>
void resample(
const void *input, int in_width, int in_height,
void *output, int out_width, int out_height,
resample_params_t &params)
{
using type_mapping_t = type_mapping<color_type>;
using type_mapping_t = type_mapping<color_type, input_has_alpha>;

using input_pixfmt_t = typename type_mapping_t::pixfmt_type;
using output_pixfmt_t = typename type_mapping_t::pixfmt_type;
using input_pixfmt_t = typename type_mapping_t::input_pixfmt_type;
using output_pixfmt_t = typename type_mapping_t::output_pixfmt_type;

using renderer_t = agg::renderer_base<output_pixfmt_t>;
using rasterizer_t = agg::rasterizer_scanline_aa<agg::rasterizer_sl_clip_dbl>;
Expand All @@ -711,9 +747,13 @@ void resample(
using arbitrary_interpolator_t =
agg::span_interpolator_adaptor<agg::span_interpolator_linear<>, lookup_distortion>;

size_t itemsize = sizeof(color_type);
size_t in_itemsize = sizeof(color_type);
size_t out_itemsize = sizeof(color_type);
if (is_grayscale<color_type>::value) {
itemsize /= 2; // agg::grayXX includes an alpha channel which we don't have.
in_itemsize /= 2; // agg::grayXX includes an alpha channel which we don't have.
out_itemsize /= 2;
} else if(!input_has_alpha) {
in_itemsize = in_itemsize / 4 * 3;
}

if (params.interpolation != NEAREST &&
Expand All @@ -733,13 +773,13 @@ void resample(

agg::rendering_buffer input_buffer;
input_buffer.attach(
(unsigned char *)input, in_width, in_height, in_width * itemsize);
(unsigned char *)input, in_width, in_height, in_width * in_itemsize);
input_pixfmt_t input_pixfmt(input_buffer);
image_accessor_t input_accessor(input_pixfmt);

agg::rendering_buffer output_buffer;
output_buffer.attach(
(unsigned char *)output, out_width, out_height, out_width * itemsize);
(unsigned char *)output, out_width, out_height, out_width * out_itemsize);
output_pixfmt_t output_pixfmt(output_buffer);
renderer_t renderer(output_pixfmt);

Expand Down
50 changes: 32 additions & 18 deletions src/_image_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,14 @@ image_resample(py::array input_array,
throw std::invalid_argument("Input array must be a 2D or 3D array");
}

if (ndim == 3 && input_array.shape(2) != 4) {
throw std::invalid_argument(
"3D input array must be RGBA with shape (M, N, 4), has trailing dimension of {}"_s.format(
input_array.shape(2)));
py::ssize_t ncomponents = 0;
if (ndim == 3) {
ncomponents = input_array.shape(2);
if (ncomponents != 3 && ncomponents != 4) {
throw std::invalid_argument(
"3D input array must be RGB with shape (M, N, 3) or RGBA with shape (M, N, 4), "
"has trailing dimension of {}"_s.format(ncomponents));
}
}

// Ensure input array is contiguous, regardless of dtype
Expand Down Expand Up @@ -173,21 +177,31 @@ image_resample(py::array input_array,

if (auto resampler =
(ndim == 2) ? (
(dtype.equal(py::dtype::of<std::uint8_t>())) ? resample<agg::gray8> :
(dtype.equal(py::dtype::of<std::int8_t>())) ? resample<agg::gray8> :
(dtype.equal(py::dtype::of<std::uint16_t>())) ? resample<agg::gray16> :
(dtype.equal(py::dtype::of<std::int16_t>())) ? resample<agg::gray16> :
(dtype.equal(py::dtype::of<float>())) ? resample<agg::gray32> :
(dtype.equal(py::dtype::of<double>())) ? resample<agg::gray64> :
dtype.equal(py::dtype::of<std::uint8_t>()) ? resample<agg::gray8> :
dtype.equal(py::dtype::of<std::int8_t>()) ? resample<agg::gray8> :
dtype.equal(py::dtype::of<std::uint16_t>()) ? resample<agg::gray16> :
dtype.equal(py::dtype::of<std::int16_t>()) ? resample<agg::gray16> :
dtype.equal(py::dtype::of<float>()) ? resample<agg::gray32> :
dtype.equal(py::dtype::of<double>()) ? resample<agg::gray64> :
nullptr) : (
// ndim == 3
(dtype.equal(py::dtype::of<std::uint8_t>())) ? resample<agg::rgba8> :
(dtype.equal(py::dtype::of<std::int8_t>())) ? resample<agg::rgba8> :
(dtype.equal(py::dtype::of<std::uint16_t>())) ? resample<agg::rgba16> :
(dtype.equal(py::dtype::of<std::int16_t>())) ? resample<agg::rgba16> :
(dtype.equal(py::dtype::of<float>())) ? resample<agg::rgba32> :
(dtype.equal(py::dtype::of<double>())) ? resample<agg::rgba64> :
nullptr)) {
// ndim == 3
(ncomponents == 4) ? (
dtype.equal(py::dtype::of<std::uint8_t>()) ? resample<agg::rgba8, true> :
dtype.equal(py::dtype::of<std::int8_t>()) ? resample<agg::rgba8, true> :
dtype.equal(py::dtype::of<std::uint16_t>()) ? resample<agg::rgba16, true> :
dtype.equal(py::dtype::of<std::int16_t>()) ? resample<agg::rgba16, true> :
dtype.equal(py::dtype::of<float>()) ? resample<agg::rgba32, true> :
dtype.equal(py::dtype::of<double>()) ? resample<agg::rgba64, true> :
nullptr
) : (
dtype.equal(py::dtype::of<std::uint8_t>()) ? resample<agg::rgba8, false> :
dtype.equal(py::dtype::of<std::int8_t>()) ? resample<agg::rgba8, false> :
dtype.equal(py::dtype::of<std::uint16_t>()) ? resample<agg::rgba16, false> :
dtype.equal(py::dtype::of<std::int16_t>()) ? resample<agg::rgba16, false> :
dtype.equal(py::dtype::of<float>()) ? resample<agg::rgba32, false> :
dtype.equal(py::dtype::of<double>()) ? resample<agg::rgba64, false> :
nullptr)))
{
Py_BEGIN_ALLOW_THREADS
resampler(
input_array.data(), input_array.shape(1), input_array.shape(0),
Expand Down

0 comments on commit d01ef30

Please sign in to comment.