Skip to content

Commit

Permalink
Refactor adapter model (#260)
Browse files Browse the repository at this point in the history
This PR moves the `AdapterModel` (i.e., the connecting layer between
baybe surrogates and botorch models) to the surrogate package, where it
actually belongs. Additionally, the surrogates are equipped with a
`to_botorch` method that simplifies the model translation and can be
customized per subclass. In particular, GP surrogates are no longer
wrapped using the adapter but now expose their internal botorch model
instance directly.
  • Loading branch information
AdrianSosic authored Jun 6, 2024
2 parents 914363f + dcaad6e commit 27d5903
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `Surrogate` base class now exposes a `to_botorch` method

### Changed
- Passing an `Objective` to `Campaign` is now optional
- `GaussianProcessSurrogate` models are no longer wrapped when cast to BoTorch

### Removed
- Support for Python 3.9 removed due to new [BoTorch requirements](https://github.com/pytorch/botorch/pull/2293)
Expand Down
4 changes: 1 addition & 3 deletions baybe/acquisition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ def to_botorch(
"""Create the botorch-ready representation of the function."""
import botorch.acquisition as botorch_analytical_acqf

from baybe.acquisition._adapter import AdapterModel

acqf_cls = getattr(botorch_analytical_acqf, self.__class__.__name__)
params_dict = filter_attributes(object=self, callable_=acqf_cls.__init__)

additional_params = {
p: v
for p, v in {
"model": AdapterModel(surrogate),
"model": surrogate.to_botorch(),
"best_f": train_y.max().item(),
"X_baseline": to_tensor(train_x),
}.items()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Adapter for making BoTorch's acquisition functions work with BayBE models."""
"""Adapter functionality for making BayBE surrogates BoTorch-ready."""

from collections.abc import Callable
from typing import Any
Expand All @@ -19,7 +19,7 @@ class AdapterModel(Model):
surrogate model usable in conjunction with BoTorch acquisition functions.
Args:
surrogate: The internal surrogate model
surrogate: The internal surrogate model.
"""

def __init__(self, surrogate: Surrogate):
Expand Down
7 changes: 7 additions & 0 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from baybe.surrogates.utils import _prepare_inputs, _prepare_targets

if TYPE_CHECKING:
from botorch.models.model import Model
from torch import Tensor

# Define constants
Expand Down Expand Up @@ -55,6 +56,12 @@ class Surrogate(ABC, SerialMixin):
"""Class variable encoding whether or not the surrogate supports transfer
learning."""

def to_botorch(self) -> Model:
"""Create the botorch-ready representation of the model."""
from baybe.surrogates._adapter import AdapterModel

return AdapterModel(self)

def posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
"""Evaluate the surrogate model at the given candidate points.
Expand Down
6 changes: 6 additions & 0 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

if TYPE_CHECKING:
from botorch.models.model import Model
from torch import Tensor


Expand Down Expand Up @@ -59,6 +60,11 @@ def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
"""Create a Gaussian process surrogate from one of the defined presets."""
return make_gp_from_preset(preset)

def to_botorch(self) -> Model: # noqa: D102
# See base class.

return self._model

def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
# See base class.
posterior = self._model.posterior(candidates)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_imports(module: str):

WHITELISTS = {
"torch": [
"baybe.acquisition._adapter",
"baybe.acquisition.partial",
"baybe.surrogates._adapter",
"baybe.utils.botorch_wrapper",
"baybe.utils.torch",
],
Expand Down

0 comments on commit 27d5903

Please sign in to comment.