From 6f018a9cd1b369bcb247e1d539968db8e48b2b3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20M=C3=BCller?= Date: Sat, 5 Aug 2023 07:33:35 +0200 Subject: [PATCH] Remove incorrect early exit of `Encoding::backward_impl` When there are no outputs, backward_impl should produce a zero gradient rather than not touching the gradient matrix at all. --- include/tiny-cuda-nn/encodings/empty.h | 2 +- include/tiny-cuda-nn/encodings/frequency.h | 2 +- include/tiny-cuda-nn/encodings/grid.h | 2 +- include/tiny-cuda-nn/encodings/identity.h | 2 +- include/tiny-cuda-nn/encodings/oneblob.h | 2 +- include/tiny-cuda-nn/encodings/spherical_harmonics.h | 2 +- include/tiny-cuda-nn/encodings/triangle_wave.h | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/tiny-cuda-nn/encodings/empty.h b/include/tiny-cuda-nn/encodings/empty.h index a2b21494..72fe5894 100644 --- a/include/tiny-cuda-nn/encodings/empty.h +++ b/include/tiny-cuda-nn/encodings/empty.h @@ -99,7 +99,7 @@ class EmptyEncoding : public Encoding { bool use_inference_params = false, GradientMode param_gradients_mode = GradientMode::Overwrite ) override { - if (!dL_dinput || padded_output_width() == 0) { + if (!dL_dinput) { return; } diff --git a/include/tiny-cuda-nn/encodings/frequency.h b/include/tiny-cuda-nn/encodings/frequency.h index 419e4813..90e04db1 100644 --- a/include/tiny-cuda-nn/encodings/frequency.h +++ b/include/tiny-cuda-nn/encodings/frequency.h @@ -150,7 +150,7 @@ class FrequencyEncoding : public Encoding { bool use_inference_params = false, GradientMode param_gradients_mode = GradientMode::Overwrite ) override { - if (!dL_dinput || padded_output_width() == 0) { + if (!dL_dinput) { return; } diff --git a/include/tiny-cuda-nn/encodings/grid.h b/include/tiny-cuda-nn/encodings/grid.h index 89fbd5db..cff67359 100644 --- a/include/tiny-cuda-nn/encodings/grid.h +++ b/include/tiny-cuda-nn/encodings/grid.h @@ -817,7 +817,7 @@ class GridEncodingTemplated : public GridEncoding { GradientMode param_gradients_mode = GradientMode::Overwrite ) override { const uint32_t num_elements = input.n(); - if ((!dL_dinput && param_gradients_mode == GradientMode::Ignore) || padded_output_width() == 0 || num_elements == 0) { + if ((!dL_dinput && param_gradients_mode == GradientMode::Ignore) || num_elements == 0) { return; } diff --git a/include/tiny-cuda-nn/encodings/identity.h b/include/tiny-cuda-nn/encodings/identity.h index f12c83aa..7518775d 100644 --- a/include/tiny-cuda-nn/encodings/identity.h +++ b/include/tiny-cuda-nn/encodings/identity.h @@ -126,7 +126,7 @@ class IdentityEncoding : public Encoding { bool use_inference_params = false, GradientMode param_gradients_mode = GradientMode::Overwrite ) override { - if (!dL_dinput || padded_output_width() == 0) { + if (!dL_dinput) { return; } diff --git a/include/tiny-cuda-nn/encodings/oneblob.h b/include/tiny-cuda-nn/encodings/oneblob.h index 3b461ea0..83127d7a 100644 --- a/include/tiny-cuda-nn/encodings/oneblob.h +++ b/include/tiny-cuda-nn/encodings/oneblob.h @@ -240,7 +240,7 @@ class OneBlobEncoding : public Encoding { bool use_inference_params = false, GradientMode param_gradients_mode = GradientMode::Overwrite ) override { - if (!dL_dinput || padded_output_width() == 0) { + if (!dL_dinput) { return; } diff --git a/include/tiny-cuda-nn/encodings/spherical_harmonics.h b/include/tiny-cuda-nn/encodings/spherical_harmonics.h index 66aca859..9350d5d2 100644 --- a/include/tiny-cuda-nn/encodings/spherical_harmonics.h +++ b/include/tiny-cuda-nn/encodings/spherical_harmonics.h @@ -151,7 +151,7 @@ class SphericalHarmonicsEncoding : public Encoding { bool use_inference_params = false, GradientMode param_gradients_mode = GradientMode::Overwrite ) override { - if (!dL_dinput || padded_output_width() == 0) { + if (!dL_dinput) { return; } diff --git a/include/tiny-cuda-nn/encodings/triangle_wave.h b/include/tiny-cuda-nn/encodings/triangle_wave.h index 7c3acb43..56c4ee4b 100644 --- a/include/tiny-cuda-nn/encodings/triangle_wave.h +++ b/include/tiny-cuda-nn/encodings/triangle_wave.h @@ -155,7 +155,7 @@ class TriangleWaveEncoding : public Encoding { bool use_inference_params = false, GradientMode param_gradients_mode = GradientMode::Overwrite ) override { - if (!dL_dinput || padded_output_width() == 0) { + if (!dL_dinput) { return; }