diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bf3f1989..671b77b77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,8 +5,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `Surrogate` base class now exposes a `to_botorch` method + ### Changed - Passing an `Objective` to `Campaign` is now optional +- `GaussianProcessSurrogate` models are no longer wrapped when cast to BoTorch ### Removed - Support for Python 3.9 removed due to new [BoTorch requirements](https://github.com/pytorch/botorch/pull/2293) diff --git a/baybe/acquisition/base.py b/baybe/acquisition/base.py index 68ce96824..98fdd01de 100644 --- a/baybe/acquisition/base.py +++ b/baybe/acquisition/base.py @@ -43,15 +43,13 @@ def to_botorch( """Create the botorch-ready representation of the function.""" import botorch.acquisition as botorch_analytical_acqf - from baybe.acquisition._adapter import AdapterModel - acqf_cls = getattr(botorch_analytical_acqf, self.__class__.__name__) params_dict = filter_attributes(object=self, callable_=acqf_cls.__init__) additional_params = { p: v for p, v in { - "model": AdapterModel(surrogate), + "model": surrogate.to_botorch(), "best_f": train_y.max().item(), "X_baseline": to_tensor(train_x), }.items() diff --git a/baybe/acquisition/_adapter.py b/baybe/surrogates/_adapter.py similarity index 91% rename from baybe/acquisition/_adapter.py rename to baybe/surrogates/_adapter.py index 5f922f304..6241ac37f 100644 --- a/baybe/acquisition/_adapter.py +++ b/baybe/surrogates/_adapter.py @@ -1,4 +1,4 @@ -"""Adapter for making BoTorch's acquisition functions work with BayBE models.""" +"""Adapter functionality for making BayBE surrogates BoTorch-ready.""" from collections.abc import Callable from typing import Any @@ -19,7 +19,7 @@ class AdapterModel(Model): surrogate model usable in conjunction with BoTorch acquisition functions. Args: - surrogate: The internal surrogate model + surrogate: The internal surrogate model. """ def __init__(self, surrogate: Surrogate): diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index da01e487d..999bd7c2f 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -25,6 +25,7 @@ from baybe.surrogates.utils import _prepare_inputs, _prepare_targets if TYPE_CHECKING: + from botorch.models.model import Model from torch import Tensor # Define constants @@ -55,6 +56,12 @@ class Surrogate(ABC, SerialMixin): """Class variable encoding whether or not the surrogate supports transfer learning.""" + def to_botorch(self) -> Model: + """Create the botorch-ready representation of the model.""" + from baybe.surrogates._adapter import AdapterModel + + return AdapterModel(self) + def posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: """Evaluate the surrogate model at the given candidate points. diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index bcb7eb361..f22591294 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -22,6 +22,7 @@ ) if TYPE_CHECKING: + from botorch.models.model import Model from torch import Tensor @@ -59,6 +60,11 @@ def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate: """Create a Gaussian process surrogate from one of the defined presets.""" return make_gp_from_preset(preset) + def to_botorch(self) -> Model: # noqa: D102 + # See base class. + + return self._model + def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: # See base class. posterior = self._model.posterior(candidates) diff --git a/tests/test_imports.py b/tests/test_imports.py index 1ce6970c0..227410297 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -58,8 +58,8 @@ def test_imports(module: str): WHITELISTS = { "torch": [ - "baybe.acquisition._adapter", "baybe.acquisition.partial", + "baybe.surrogates._adapter", "baybe.utils.botorch_wrapper", "baybe.utils.torch", ],