From 97900aaa6203f43ffba51eacd7649805d24ee549 Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Sun, 5 Jan 2025 18:43:28 +0000 Subject: [PATCH] Add tests --- tests/unit/models/gpflow/test_sampler.py | 28 ++++++++++++++++++++++++ tests/unit/utils/test_misc.py | 26 ++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/tests/unit/models/gpflow/test_sampler.py b/tests/unit/models/gpflow/test_sampler.py index bad6a5ee57..608db07c22 100644 --- a/tests/unit/models/gpflow/test_sampler.py +++ b/tests/unit/models/gpflow/test_sampler.py @@ -285,6 +285,20 @@ def test_independent_reparametrization_sampler_reset_sampler(qmc: bool, qmc_skip npt.assert_array_less(1e-9, tf.abs(samples2 - samples1)) +@pytest.mark.parametrize("qmc", [True, False]) +@pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) +def test_independent_reparametrization_sampler_sample_ensures_positive_variance( + qmc: bool, dtype: tf.DType +) -> None: + model = QuadraticMeanAndRBFKernel(kernel_amplitude=tf.constant(0, dtype=dtype)) + sampler = IndependentReparametrizationSampler(100, model, qmc=qmc) + x = tf.constant([[1.0]], dtype=dtype) + variance = tf.math.reduce_variance(sampler.sample(x)) # default jitter + assert variance > (1e-7 if dtype is tf.float32 else 1e-17) + variance = tf.math.reduce_variance(sampler.sample(x, jitter=0.0)) # explicit jitter + assert variance > (1e-7 if dtype is tf.float32 else 1e-17) + + @pytest.mark.parametrize("qmc", [True, False]) @pytest.mark.parametrize("sample_size", [0, -2]) def test_batch_reparametrization_sampler_raises_for_invalid_sample_size( @@ -457,6 +471,20 @@ def test_batch_reparametrization_sampler_reset_sampler(qmc: bool, qmc_skip: bool npt.assert_array_less(1e-9, tf.abs(samples2 - samples1)) +@pytest.mark.parametrize("qmc", [True, False]) +@pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) +def test_batch_reparametrization_sampler_sample_ensures_positive_variance( + qmc: bool, dtype: tf.DType +) -> None: + model = QuadraticMeanAndRBFKernel(kernel_amplitude=tf.constant(0, dtype=dtype)) + sampler = BatchReparametrizationSampler(100, model, qmc=qmc) + x = tf.constant([[1.0]], dtype=dtype) + variance = tf.math.reduce_variance(sampler.sample(x)) # default jitter + assert variance > (1e-7 if dtype is tf.float32 else 1e-17) + variance = tf.math.reduce_variance(sampler.sample(x, jitter=0.0)) # explicit jitter + assert variance > (1e-7 if dtype is tf.float32 else 1e-17) + + @pytest.mark.parametrize("num_features", [0, -2]) def test_rff_trajectory_sampler_raises_for_invalid_number_of_features( num_features: int, diff --git a/tests/unit/utils/test_misc.py b/tests/unit/utils/test_misc.py index 528f7aabf6..c9208a37ef 100644 --- a/tests/unit/utils/test_misc.py +++ b/tests/unit/utils/test_misc.py @@ -29,6 +29,7 @@ LocalizedTag, Ok, Timer, + ensure_positive, flatten_leading_dims, get_value_for_tag, jit, @@ -222,3 +223,28 @@ def test_flatten_leading_dims_invalid_output_dims(output_dims: int) -> None: x_old = tf.random.uniform([2, 3, 4, 5]) # [2, 3, 4, 5] with pytest.raises(TF_DEBUGGING_ERROR_TYPES): flatten_leading_dims(x_old, output_dims=output_dims) + + +@pytest.mark.parametrize( + "t, expected", + [ + ( + tf.constant(0, dtype=tf.float32), + tf.constant(1e-6, dtype=tf.float32), + ), + ( + tf.constant(0, dtype=tf.float64), + tf.constant(1e-16, dtype=tf.float64), + ), + ( + tf.constant([[-1.0, 0.0], [1e-7, 1.0]], dtype=tf.float32), + tf.constant([[1e-6, 1e-6], [1e-6, 1.0]], dtype=tf.float32), + ), + ( + tf.constant([[-1.0, 0.0], [1e-7, 1.0]], dtype=tf.float64), + tf.constant([[1e-16, 1e-16], [1e-7, 1.0]], dtype=tf.float64), + ), + ], +) +def test_ensure_positive(t: TensorType, expected: TensorType) -> None: + npt.assert_array_equal(ensure_positive(t), expected)