diff --git a/src/core/python/transform_v.cpp b/src/core/python/transform_v.cpp index e950b784f..54411bc02 100644 --- a/src/core/python/transform_v.cpp +++ b/src/core/python/transform_v.cpp @@ -7,12 +7,13 @@ #include #include -template +template void bind_transform3(nb::module_ &m, const char *name) { MI_IMPORT_CORE_TYPES() using ScalarType = drjit::scalar_t; using NdMatrix = nb::ndarray, nb::c_contig, nb::device::cpu>; + using Ray2f = Ray, Spectrum>; auto trans3 = nb::class_(m, name, D(Transform)) .def(nb::init<>(), "Initialize with the identity matrix") @@ -52,12 +53,18 @@ void bind_transform3(nb::module_ &m, const char *name) { .def("__matmul__", [](const Transform3f &a, const Vector2f &b) { return a * b; }, nb::is_operator()) + .def("__matmul__", [](const Transform3f &a, const Ray2f &b) { + return a * b; + }, nb::is_operator()) .def("transform_affine", [](const Transform3f &a, const Point2f &b) { return a.transform_affine(b); }, "p"_a, D(Transform, transform_affine)) .def("transform_affine", [](const Transform3f &a, const Vector2f &b) { return a.transform_affine(b); }, "v"_a, D(Transform, transform_affine)) + .def("transform_affine", [](const Transform3f &a, const Ray2f &b) { + return a.transform_affine(b); + }, "ray"_a, D(Transform, transform_affine)) /// Chain transformations .def("translate", [](const Transform3f &t, const Point2f &v) { return Transform3f(t * Transform3f::translate(v)); @@ -191,23 +198,23 @@ MI_PY_EXPORT(Transform) { using ScalarSpectrum = scalar_spectrum_t; MI_PY_CHECK_ALIAS(Transform3f, "Transform3f") { - bind_transform3(m, "Transform3f"); + bind_transform3(m, "Transform3f"); } MI_PY_CHECK_ALIAS(Transform3d, "Transform3d") { - bind_transform3(m, "Transform3d"); + bind_transform3(m, "Transform3d"); } MI_PY_CHECK_ALIAS(ScalarTransform3f, "ScalarTransform3f") { if constexpr (dr::is_dynamic_v) { - bind_transform3(m, "ScalarTransform3f"); + bind_transform3(m, "ScalarTransform3f"); nb::implicitly_convertible(); } } MI_PY_CHECK_ALIAS(ScalarTransform3d, "ScalarTransform3d") { if constexpr (dr::is_dynamic_v) { - bind_transform3(m, "ScalarTransform3d"); + bind_transform3(m, "ScalarTransform3d"); nb::implicitly_convertible(); } }