Skip to content

Commit

Permalink
Handle strided input array in Bitmap constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
njroussel committed Apr 14, 2023
1 parent 1a9bde5 commit 5b784db
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 4 deletions.
51 changes: 47 additions & 4 deletions src/core/python/bitmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <mitsuba/core/filesystem.h>
#include <mitsuba/core/stream.h>
#include <mitsuba/core/mstream.h>
#include <pybind11/numpy.h>
#include <mitsuba/python/python.h>

MI_PY_EXPORT(Bitmap) {
Expand Down Expand Up @@ -281,11 +280,55 @@ MI_PY_EXPORT(Bitmap) {
if (!pixel_format_.is_none())
pixel_format = pixel_format_.cast<Bitmap::PixelFormat>();

obj = py::array::ensure(obj, py::array::c_style);
Vector2u size(shape[1].cast<size_t>(), shape[0].cast<size_t>());
auto bitmap = new Bitmap(pixel_format, component_format, size,
Bitmap* bitmap = new Bitmap(pixel_format, component_format, size,
channel_count, channel_names);
memcpy(bitmap->data(), ptr, bitmap->buffer_size());

bool is_contiguous = true;
if (interface.contains("strides") && !interface["strides"].is_none()) {
auto strides = interface["strides"].template cast<py::tuple>();

// Stride information might be given although it's contiguous
size_t bytes_per_value = bitmap->bytes_per_pixel() / bitmap->channel_count();
int64_t current_stride = bytes_per_value;
for (size_t i = ndim; i > 0; --i) {
if (current_stride != strides[i - 1].cast<int64_t>()) {
is_contiguous = false;
break;
}
current_stride *= shape[i - 1].cast<size_t>();
}

// Need to shuffle memory
if (!is_contiguous) {
size_t num_values = size.x() * size.y();
if (ndim == 3)
num_values *= shape[2].cast<size_t>();

std::unique_ptr<size_t[]> shape_cum = std::make_unique<size_t[]>(ndim);
shape_cum[ndim - 1] = 1;
for (size_t j = ndim - 1; j > 0; --j) {
size_t shape_j = shape[j].cast<size_t>();
shape_cum[j - 1] = shape_j * shape_cum[j];
}

for (size_t i = 0; i < num_values; ++i) {
unsigned char* src = (unsigned char*) ptr;
for (size_t j = ndim; j > 0; --j) {
int64_t stride_j = strides[j - 1].cast<int64_t>();
size_t shape_j = shape[j - 1].cast<size_t>();
int64_t offset_j = ((i / shape_cum[j - 1]) % shape_j) * stride_j;
src += offset_j;
}
unsigned char *dest = (unsigned char *) bitmap->data() + i * bytes_per_value;
memcpy(dest, src, bytes_per_value);
}
}
}

if (is_contiguous)
memcpy(bitmap->data(), ptr, bitmap->buffer_size());

return bitmap;
}),
"array"_a,
Expand Down
18 changes: 18 additions & 0 deletions src/core/tests/test_bitmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,24 @@ def test_construct_from_array(variants_all_rgb):
assert dr.allclose(b_np, np.array(b2))


def test_construct_from_non_contiguous_array(variants_all_rgb):
flat_arr = np.array([
[[0, 1], [3, 4], [6, 7]],
[[1, 2], [4, 5], [7, 8]],
[[2, 3], [5, 6], [8, 9]],
[[3, 4], [6, 7], [9, 10]],
], dtype=np.float32)

for i in range(2):
strided = np.flip(flat_arr, axis=i)

# Test assumes that `np.flip` will create a strided array
assert strided.strides[i] < 0

b = mi.Bitmap(strided)
assert dr.allclose(strided, np.array(b))


def test_construct_from_int8_array(variants_all_rgb):
# test uint8
b_np = np.reshape(np.arange(16), (4, 4)).astype(np.uint8)
Expand Down

0 comments on commit 5b784db

Please sign in to comment.