From 96b76fc8105a8372126282ab03a4c0959c26aec9 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 3 Jun 2024 15:26:44 +0200 Subject: [PATCH 1/6] Move adapter module to surrogate package --- baybe/{acquisition => surrogates}/_adapter.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename baybe/{acquisition => surrogates}/_adapter.py (100%) diff --git a/baybe/acquisition/_adapter.py b/baybe/surrogates/_adapter.py similarity index 100% rename from baybe/acquisition/_adapter.py rename to baybe/surrogates/_adapter.py From c24d8ee40d7ab3356308e291fb98552d6e62db3a Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 3 Jun 2024 15:33:36 +0200 Subject: [PATCH 2/6] Add to_botorch method to surrogate base class --- baybe/acquisition/base.py | 4 +--- baybe/surrogates/base.py | 7 +++++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/baybe/acquisition/base.py b/baybe/acquisition/base.py index 68ce96824..98fdd01de 100644 --- a/baybe/acquisition/base.py +++ b/baybe/acquisition/base.py @@ -43,15 +43,13 @@ def to_botorch( """Create the botorch-ready representation of the function.""" import botorch.acquisition as botorch_analytical_acqf - from baybe.acquisition._adapter import AdapterModel - acqf_cls = getattr(botorch_analytical_acqf, self.__class__.__name__) params_dict = filter_attributes(object=self, callable_=acqf_cls.__init__) additional_params = { p: v for p, v in { - "model": AdapterModel(surrogate), + "model": surrogate.to_botorch(), "best_f": train_y.max().item(), "X_baseline": to_tensor(train_x), }.items() diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index da01e487d..999bd7c2f 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -25,6 +25,7 @@ from baybe.surrogates.utils import _prepare_inputs, _prepare_targets if TYPE_CHECKING: + from botorch.models.model import Model from torch import Tensor # Define constants @@ -55,6 +56,12 @@ class Surrogate(ABC, SerialMixin): """Class variable encoding whether or not the surrogate supports transfer learning.""" + def to_botorch(self) -> Model: + """Create the botorch-ready representation of the model.""" + from baybe.surrogates._adapter import AdapterModel + + return AdapterModel(self) + def posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: """Evaluate the surrogate model at the given candidate points. From e7737c8471eeea34e8a7b057e4f6c6530acd6de0 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 3 Jun 2024 15:35:48 +0200 Subject: [PATCH 3/6] Avoid model wrapping for GPs --- baybe/surrogates/gaussian_process/core.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index bcb7eb361..f22591294 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -22,6 +22,7 @@ ) if TYPE_CHECKING: + from botorch.models.model import Model from torch import Tensor @@ -59,6 +60,11 @@ def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate: """Create a Gaussian process surrogate from one of the defined presets.""" return make_gp_from_preset(preset) + def to_botorch(self) -> Model: # noqa: D102 + # See base class. + + return self._model + def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: # See base class. posterior = self._model.posterior(candidates) From bd014567da9fde6d22789c69fce3ae223efc70b0 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 3 Jun 2024 15:55:36 +0200 Subject: [PATCH 4/6] Adjust module docstring --- baybe/surrogates/_adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/baybe/surrogates/_adapter.py b/baybe/surrogates/_adapter.py index 5f922f304..6241ac37f 100644 --- a/baybe/surrogates/_adapter.py +++ b/baybe/surrogates/_adapter.py @@ -1,4 +1,4 @@ -"""Adapter for making BoTorch's acquisition functions work with BayBE models.""" +"""Adapter functionality for making BayBE surrogates BoTorch-ready.""" from collections.abc import Callable from typing import Any @@ -19,7 +19,7 @@ class AdapterModel(Model): surrogate model usable in conjunction with BoTorch acquisition functions. Args: - surrogate: The internal surrogate model + surrogate: The internal surrogate model. """ def __init__(self, surrogate: Surrogate): From b5266b2376ad216b450576a29e1efdd652e2cb93 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 3 Jun 2024 15:55:53 +0200 Subject: [PATCH 5/6] Fix import whitelist --- tests/test_imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_imports.py b/tests/test_imports.py index 1ce6970c0..227410297 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -58,8 +58,8 @@ def test_imports(module: str): WHITELISTS = { "torch": [ - "baybe.acquisition._adapter", "baybe.acquisition.partial", + "baybe.surrogates._adapter", "baybe.utils.botorch_wrapper", "baybe.utils.torch", ], From dcaad6e423e8b4ae54576ae05c46e7cb86c312b9 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Thu, 6 Jun 2024 08:23:19 +0200 Subject: [PATCH 6/6] Update CHANGELOG.md --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bf3f1989..671b77b77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,8 +5,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `Surrogate` base class now exposes a `to_botorch` method + ### Changed - Passing an `Objective` to `Campaign` is now optional +- `GaussianProcessSurrogate` models are no longer wrapped when cast to BoTorch ### Removed - Support for Python 3.9 removed due to new [BoTorch requirements](https://github.com/pytorch/botorch/pull/2293)