Skip to content

Commit

Permalink
avoid use of python bool in qmc_normal_samples which would otherwise …
Browse files Browse the repository at this point in the history
…prevent it being saved as a tf.module
  • Loading branch information
chris committed Oct 27, 2023
1 parent aa94c86 commit 9ff8b5a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
22 changes: 22 additions & 0 deletions tests/unit/models/gpflow/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import math
import unittest
from pathlib import Path
from typing import Any, Callable, List, Tuple, Type
from unittest.mock import MagicMock

Expand Down Expand Up @@ -940,6 +941,27 @@ def test_qmc_samples_shapes(num_samples: int, n_sample_dim: int) -> None:
assert samples.shape == expected_samples_shape


def test_qmc_samples__save_as_tf_function(tmp_path: Path) -> None:
def get_samples():
return qmc_normal_samples(
num_samples=tf.constant(5),
n_sample_dim=tf.constant(2),
)

module = tf.Module()
module.get_samples = tf.function(
get_samples,
input_signature=[],
autograph=False,
)

save_path = Path(tmp_path / "qmc_sampler")
tf.saved_model.save(module, str(save_path))
loaded_module = tf.saved_model.load(str(save_path))
samples = loaded_module.get_samples()
assert samples.shape == (5, 2)


@pytest.mark.parametrize(
("num_samples", "n_sample_dim", "skip", "expected_error_type"),
(
Expand Down
29 changes: 16 additions & 13 deletions trieste/models/gpflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,24 @@ def qmc_normal_samples(
sample has dimension `n_sample_dim`.
"""

if num_samples == 0 or n_sample_dim == 0:
return tf.zeros(shape=(num_samples, n_sample_dim), dtype=tf.float64)

sobol_samples = tf.math.sobol_sample(
dim=n_sample_dim,
num_results=num_samples,
dtype=tf.float64,
skip=skip,
)
def _qmc_normal_samples() -> tf.Tensor:
sobol_samples = tf.math.sobol_sample(
dim=n_sample_dim,
num_results=num_samples,
dtype=tf.float64,
skip=skip,
)
dist = tfp.distributions.Normal(
loc=tf.constant(0.0, dtype=tf.float64),
scale=tf.constant(1.0, dtype=tf.float64),
)
return dist.quantile(sobol_samples)

dist = tfp.distributions.Normal(
loc=tf.constant(0.0, dtype=tf.float64),
scale=tf.constant(1.0, dtype=tf.float64),
normal_samples = tf.cond(
tf.logical_or(num_samples == 0, n_sample_dim == 0),
true_fn=lambda: tf.zeros(shape=(num_samples, n_sample_dim), dtype=tf.float64),
false_fn=_qmc_normal_samples,
)
normal_samples = dist.quantile(sobol_samples)
return normal_samples


Expand Down

0 comments on commit 9ff8b5a

Please sign in to comment.