Skip to content

Commit

Permalink
Really leave them alone
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jan 6, 2025
1 parent 51adc46 commit bfccfd8
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 18 deletions.
14 changes: 0 additions & 14 deletions tests/unit/models/gpflow/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,20 +471,6 @@ def test_batch_reparametrization_sampler_reset_sampler(qmc: bool, qmc_skip: bool
npt.assert_array_less(1e-9, tf.abs(samples2 - samples1))


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("dtype", [tf.float32, tf.float64])
def test_batch_reparametrization_sampler_sample_ensures_positive_variance(
qmc: bool, dtype: tf.DType
) -> None:
model = QuadraticMeanAndRBFKernel(kernel_amplitude=tf.constant(0, dtype=dtype))
sampler = BatchReparametrizationSampler(100, model, qmc=qmc)
x = tf.constant([[1.0]], dtype=dtype)
variance = tf.math.reduce_variance(sampler.sample(x)) # default jitter
assert variance > 1e-7
variance = tf.math.reduce_variance(sampler.sample(x, jitter=0.0)) # explicit jitter
assert variance > (1e-7 if dtype is tf.float32 else 1e-17)


@pytest.mark.parametrize("num_features", [0, -2])
def test_rff_trajectory_sampler_raises_for_invalid_number_of_features(
num_features: int,
Expand Down
3 changes: 1 addition & 2 deletions trieste/models/gpflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ def sample_eps() -> tf.Tensor:
)

identity = tf.eye(batch_size, dtype=cov.dtype) # [B, B]
cov = ensure_positive(cov + jitter * identity)
cov_cholesky = tf.linalg.cholesky(cov) # [..., L, B, B]
cov_cholesky = tf.linalg.cholesky(cov + jitter * identity) # [..., L, B, B]

variance_contribution = cov_cholesky @ self._eps # [..., L, B, S]

Expand Down
3 changes: 1 addition & 2 deletions trieste/models/gpflux/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

from ...types import TensorType
from ...utils import DEFAULTS, flatten_leading_dims
from ...utils.misc import ensure_positive
from ..interfaces import (
ReparametrizationSampler,
TrajectoryFunction,
Expand Down Expand Up @@ -109,7 +108,7 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
continue

mean, var = layer.predict(samples, full_cov=False, full_output_cov=False)
var = ensure_positive(var + jitter)
var = var + jitter

if not self._initialized:
self._eps_list[i].assign(
Expand Down

0 comments on commit bfccfd8

Please sign in to comment.