Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jan 5, 2025
1 parent 7ef0492 commit 97900aa
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/unit/models/gpflow/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,20 @@ def test_independent_reparametrization_sampler_reset_sampler(qmc: bool, qmc_skip
npt.assert_array_less(1e-9, tf.abs(samples2 - samples1))


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("dtype", [tf.float32, tf.float64])
def test_independent_reparametrization_sampler_sample_ensures_positive_variance(
qmc: bool, dtype: tf.DType
) -> None:
model = QuadraticMeanAndRBFKernel(kernel_amplitude=tf.constant(0, dtype=dtype))
sampler = IndependentReparametrizationSampler(100, model, qmc=qmc)
x = tf.constant([[1.0]], dtype=dtype)
variance = tf.math.reduce_variance(sampler.sample(x)) # default jitter
assert variance > (1e-7 if dtype is tf.float32 else 1e-17)
variance = tf.math.reduce_variance(sampler.sample(x, jitter=0.0)) # explicit jitter
assert variance > (1e-7 if dtype is tf.float32 else 1e-17)


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("sample_size", [0, -2])
def test_batch_reparametrization_sampler_raises_for_invalid_sample_size(
Expand Down Expand Up @@ -457,6 +471,20 @@ def test_batch_reparametrization_sampler_reset_sampler(qmc: bool, qmc_skip: bool
npt.assert_array_less(1e-9, tf.abs(samples2 - samples1))


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("dtype", [tf.float32, tf.float64])
def test_batch_reparametrization_sampler_sample_ensures_positive_variance(
qmc: bool, dtype: tf.DType
) -> None:
model = QuadraticMeanAndRBFKernel(kernel_amplitude=tf.constant(0, dtype=dtype))
sampler = BatchReparametrizationSampler(100, model, qmc=qmc)
x = tf.constant([[1.0]], dtype=dtype)
variance = tf.math.reduce_variance(sampler.sample(x)) # default jitter
assert variance > (1e-7 if dtype is tf.float32 else 1e-17)
variance = tf.math.reduce_variance(sampler.sample(x, jitter=0.0)) # explicit jitter
assert variance > (1e-7 if dtype is tf.float32 else 1e-17)


@pytest.mark.parametrize("num_features", [0, -2])
def test_rff_trajectory_sampler_raises_for_invalid_number_of_features(
num_features: int,
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LocalizedTag,
Ok,
Timer,
ensure_positive,
flatten_leading_dims,
get_value_for_tag,
jit,
Expand Down Expand Up @@ -222,3 +223,28 @@ def test_flatten_leading_dims_invalid_output_dims(output_dims: int) -> None:
x_old = tf.random.uniform([2, 3, 4, 5]) # [2, 3, 4, 5]
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
flatten_leading_dims(x_old, output_dims=output_dims)


@pytest.mark.parametrize(
"t, expected",
[
(
tf.constant(0, dtype=tf.float32),
tf.constant(1e-6, dtype=tf.float32),
),
(
tf.constant(0, dtype=tf.float64),
tf.constant(1e-16, dtype=tf.float64),
),
(
tf.constant([[-1.0, 0.0], [1e-7, 1.0]], dtype=tf.float32),
tf.constant([[1e-6, 1e-6], [1e-6, 1.0]], dtype=tf.float32),
),
(
tf.constant([[-1.0, 0.0], [1e-7, 1.0]], dtype=tf.float64),
tf.constant([[1e-16, 1e-16], [1e-7, 1.0]], dtype=tf.float64),
),
],
)
def test_ensure_positive(t: TensorType, expected: TensorType) -> None:
npt.assert_array_equal(ensure_positive(t), expected)

0 comments on commit 97900aa

Please sign in to comment.