Include custom ScaleKernel (#221)
This PR introduces a custom `ScaleKernel` class for decorating base kernels with an output scale.
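
For context, a minimal usage sketch of the new class based on the API introduced in this diff (the `GammaPrior` parameters are illustrative, not recommended defaults):

    from baybe.kernels import MaternKernel, ScaleKernel
    from baybe.kernels.priors import GammaPrior

    # Decorate a Matern base kernel with a learnable output scale; priors and
    # initial values are optional.
    kernel = ScaleKernel(
        base_kernel=MaternKernel(nu=2.5, lengthscale_prior=GammaPrior(1.0, 2.0)),
        outputscale_prior=GammaPrior(2.0, 0.15),
        outputscale_initial_value=1.0,
    )

    # Convert to the gpytorch representation used during GP fitting.
    gpytorch_kernel = kernel.to_gpytorch(ard_num_dims=3)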
AdrianSosic authored May 14, 2024
2 parents fad3324 + 80ae210 commit 09243ce
Showing 10 changed files with 235 additions and 67 deletions.
7 changes: 3 additions & 4 deletions CHANGELOG.md
@@ -9,10 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `mypy` for search space and objectives
 - Class hierarchy for objectives
 - Deserialization is now also possible from optional class name abbreviations
-- `Kernel` base class allowing to specify kernels
-- `MaternKernel` class can be chosen for GP surrogates
-- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives, priors
-  and acquisition functions
+- `Kernel`, `MaternKernel`, and `ScaleKernel` classes for specifying kernels
+- `hypothesis` strategies and roundtrip test for kernels, constraints, objectives,
+  priors and acquisition functions
 - New acquisition functions: `qSR`, `qNEI`, `LogEI`, `qLogEI`, `qLogNEI`
 - Serialization user guide
 - Basic deserialization tests using different class type specifiers
7 changes: 5 additions & 2 deletions baybe/kernels/__init__.py
@@ -1,5 +1,8 @@
 """Kernels for Gaussian process surrogate models."""
 
-from baybe.kernels.basic import MaternKernel
+from baybe.kernels.basic import MaternKernel, ScaleKernel
 
-__all__ = ["MaternKernel"]
+__all__ = [
+    "MaternKernel",
+    "ScaleKernel",
+]
75 changes: 61 additions & 14 deletions baybe/kernels/base.py
@@ -1,9 +1,11 @@
 """Base classes for all kernels."""
 
+from __future__ import annotations
+
 from abc import ABC
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
-from attrs import define, field
+from attrs import define
 
 from baybe.kernels.priors.base import Prior
 from baybe.serialization.core import (
@@ -12,31 +14,76 @@
     unstructure_base,
 )
 from baybe.serialization.mixin import SerialMixin
-from baybe.utils.basic import filter_attributes
+from baybe.utils.basic import filter_attributes, get_baseclasses
 
+if TYPE_CHECKING:
+    import torch
+
 
 @define(frozen=True)
 class Kernel(ABC, SerialMixin):
     """Abstract base class for all kernels."""
 
-    lengthscale_prior: Optional[Prior] = field(default=None, kw_only=True)
-    """An optional prior on the kernel lengthscale."""
-
-    def to_gpytorch(self, *args, **kwargs):
+    def to_gpytorch(
+        self,
+        *,
+        ard_num_dims: Optional[int] = None,
+        batch_shape: Optional[torch.Size] = None,
+        active_dims: Optional[tuple[int, ...]] = None,
+    ):
         """Create the gpytorch representation of the kernel."""
         import gpytorch.kernels
 
+        # Fetch the necessary gpytorch constructor parameters of the kernel.
+        # NOTE: In gpytorch, some attributes (like the kernel lengthscale) are handled
+        # via the `gpytorch.kernels.Kernel` base class. Hence, it is not sufficient to
+        # just check the fields of the actual class, but also those of the base class.
         kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
-        fields_dict = filter_attributes(object=self, callable_=kernel_cls.__init__)
+        base_classes = get_baseclasses(kernel_cls, abstract=True)
+        fields_dict = {}
+        for cls in [kernel_cls, *base_classes]:
+            fields_dict.update(filter_attributes(object=self, callable_=cls.__init__))
 
-        # If a lengthscale prior was chosen, we manually add it to the dictionary
-        if self.lengthscale_prior is not None:
-            fields_dict["lengthscale_prior"] = self.lengthscale_prior.to_gpytorch()
+        # Convert specified priors to gpytorch, if provided
+        prior_dict = {
+            key: value.to_gpytorch()
+            for key, value in fields_dict.items()
+            if isinstance(value, Prior)
+        }
 
-        # Update kwargs to contain class-specific attributes
-        kwargs.update(fields_dict)
+        # Convert specified inner kernels to gpytorch, if provided
+        kernel_dict = {
+            key: value.to_gpytorch(
+                ard_num_dims=ard_num_dims,
+                batch_shape=batch_shape,
+                active_dims=active_dims,
+            )
+            for key, value in fields_dict.items()
+            if isinstance(value, Kernel)
+        }
+
+        # Create the kernel with all its inner gpytorch objects
+        fields_dict.update(kernel_dict)
+        fields_dict.update(prior_dict)
+        gpytorch_kernel = kernel_cls(**fields_dict)
+
+        # If the kernel has a lengthscale, set its initial value
+        if kernel_cls.has_lengthscale:
+            import torch
+
+            from baybe.utils.torch import DTypeFloatTorch
+
+            # We can ignore mypy here and simply assume that the corresponding BayBE
+            # kernel class has the necessary lengthscale attribute defined. This is
+            # safer than using a `hasattr` check in the above if-condition since for
+            # the latter the code would silently fail when forgetting to add the
+            # attribute to a new kernel class / misspelling it.
+            if (initial_value := self.lengthscale_initial_value) is not None:  # type: ignore[attr-defined]
+                gpytorch_kernel.lengthscale = torch.tensor(
+                    initial_value, dtype=DTypeFloatTorch
+                )
 
-        return kernel_cls(*args, **kwargs)
+        return gpytorch_kernel
 
 
 # Register de-/serialization hooks
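
To illustrate the recursion above: for a `ScaleKernel`, the `base_kernel` field passes the `isinstance(value, Kernel)` check and is converted with the same keyword arguments, so nested kernels translate in a single call. A rough equivalence sketch (the gpytorch classes are standard gpytorch API; the `ard_num_dims` value is arbitrary):

    import gpytorch.kernels

    from baybe.kernels import MaternKernel, ScaleKernel

    auto = ScaleKernel(base_kernel=MaternKernel(nu=1.5)).to_gpytorch(ard_num_dims=2)

    # What the generic conversion effectively builds for this example:
    manual = gpytorch.kernels.ScaleKernel(
        gpytorch.kernels.MaternKernel(nu=1.5, ard_num_dims=2)
    )
    assert type(auto) is type(manual)
    assert type(auto.base_kernel) is type(manual.base_kernel)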
76 changes: 49 additions & 27 deletions baybe/kernels/basic.py
@@ -1,45 +1,67 @@
 """Collection of kernels."""
 
-from fractions import Fraction
-from typing import Union
+from typing import Optional
 
 from attrs import define, field
-from attrs.validators import in_
+from attrs.converters import optional as optional_c
+from attrs.validators import in_, instance_of
+from attrs.validators import optional as optional_v
 
 from baybe.kernels.base import Kernel
-
-
-def _convert_fraction(value: Union[str, float, Fraction], /) -> float:
-    """Convert the provided value into a float.
-
-    Args:
-        value: The parameter that should be converted.
-
-    Returns:
-        The float representation of the given input.
-
-    Raises:
-        ValueError: If the input was provided as string but could not be interpreted as
-            fraction.
-    """
-    if isinstance(value, str):
-        try:
-            value = Fraction(value)
-        except ValueError as err:
-            raise ValueError(
-                f"The provided input '{value}' could not be interpreted as a fraction."
-            ) from err
-    return float(value)
+from baybe.kernels.priors.base import Prior
+from baybe.utils.conversion import fraction_to_float
+from baybe.utils.validation import finite_float
 
 
 @define(frozen=True)
 class MaternKernel(Kernel):
     """A Matern kernel using a smoothness parameter."""
 
     nu: float = field(
-        converter=_convert_fraction, validator=in_([0.5, 1.5, 2.5]), default=2.5
+        converter=fraction_to_float, validator=in_([0.5, 1.5, 2.5]), default=2.5
     )
     """A smoothness parameter.
 
     Only takes the values 0.5, 1.5 or 2.5. Larger values yield smoother interpolations.
     """
+
+    lengthscale_prior: Optional[Prior] = field(
+        default=None, validator=optional_v(instance_of(Prior))
+    )
+    """An optional prior on the kernel lengthscale."""
+
+    lengthscale_initial_value: Optional[float] = field(
+        default=None, converter=optional_c(float), validator=optional_v(finite_float)
+    )
+    """An optional initial value for the kernel lengthscale."""
+
+
+@define(frozen=True)
+class ScaleKernel(Kernel):
+    """A kernel for decorating existing kernels with an outputscale."""
+
+    base_kernel: Kernel = field(validator=instance_of(Kernel))
+    """The base kernel that is being decorated."""
+
+    outputscale_prior: Optional[Prior] = field(
+        default=None, validator=optional_v(instance_of(Prior))
+    )
+    """An optional prior on the output scale."""
+
+    outputscale_initial_value: Optional[float] = field(
+        default=None, converter=optional_c(float), validator=optional_v(finite_float)
+    )
+    """An optional initial value for the output scale."""
+
+    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
+        # See base class.
+        import torch
+
+        from baybe.utils.torch import DTypeFloatTorch
+
+        gpytorch_kernel = super().to_gpytorch(*args, **kwargs)
+        if (initial_value := self.outputscale_initial_value) is not None:
+            gpytorch_kernel.outputscale = torch.tensor(
+                initial_value, dtype=DTypeFloatTorch
+            )
+        return gpytorch_kernel
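
A short behavior sketch for the converters and validators declared above (attrs runs `fraction_to_float` before the `in_` check; `in_` raises `ValueError`, `instance_of` raises `TypeError`):

    from baybe.kernels import MaternKernel, ScaleKernel

    # Fraction strings are accepted for `nu` thanks to the converter.
    assert MaternKernel(nu="5/2").nu == 2.5

    # Initial values are coerced to float and must be finite.
    assert MaternKernel(lengthscale_initial_value=1).lengthscale_initial_value == 1.0

    # Values outside {0.5, 1.5, 2.5} are rejected.
    try:
        MaternKernel(nu=3.0)
    except ValueError:
        pass

    # `base_kernel` must be a BayBE kernel instance.
    try:
        ScaleKernel(base_kernel="not a kernel")
    except TypeError:
        pass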
24 changes: 10 additions & 14 deletions baybe/surrogates/gaussian_process.py
@@ -6,7 +6,7 @@
 
 from attr import define, field
 
-from baybe.kernels import MaternKernel
+from baybe.kernels import MaternKernel, ScaleKernel
 from baybe.kernels.base import Kernel
 from baybe.kernels.priors import GammaPrior
 from baybe.searchspace import SearchSpace
@@ -108,25 +108,21 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
 
         # If no kernel is provided, we construct one from our priors
         if self.kernel is None:
-            self.kernel = MaternKernel(lengthscale_prior=lengthscale_prior[0])
+            self.kernel = ScaleKernel(
+                base_kernel=MaternKernel(
+                    lengthscale_prior=lengthscale_prior[0],
+                    lengthscale_initial_value=lengthscale_prior[1],
+                ),
+                outputscale_prior=outputscale_prior[0],
+                outputscale_initial_value=outputscale_prior[1],
+            )
 
         # define the covariance module for the numeric dimensions
-        gpytorch_kernel = self.kernel.to_gpytorch(
+        base_covar_module = self.kernel.to_gpytorch(
             ard_num_dims=train_x.shape[-1] - n_task_params,
             active_dims=numeric_idxs,
             batch_shape=batch_shape,
         )
-        base_covar_module = gpytorch.kernels.ScaleKernel(
-            gpytorch_kernel,
-            batch_shape=batch_shape,
-            outputscale_prior=outputscale_prior[0].to_gpytorch(),
-        )
-        if outputscale_prior[1] is not None:
-            base_covar_module.outputscale = torch.tensor([outputscale_prior[1]])
-        if lengthscale_prior[1] is not None:
-            base_covar_module.base_kernel.lengthscale = torch.tensor(
-                [lengthscale_prior[1]]
-            )
 
         # create GP covariance
        if task_idx is None:
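
With this change, a user-supplied kernel replaces the default construction above. A sketch, assuming `GaussianProcessSurrogate` is exposed via `baybe.surrogates` and keeps the optional `kernel` field used in `_fit`:

    from baybe.kernels import MaternKernel, ScaleKernel
    from baybe.surrogates import GaussianProcessSurrogate

    # If `kernel` is left as None, `_fit` builds the ScaleKernel/MaternKernel
    # default with the Gamma priors referenced above.
    surrogate = GaussianProcessSurrogate(
        kernel=ScaleKernel(base_kernel=MaternKernel(nu=1.5))
    )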
31 changes: 31 additions & 0 deletions baybe/utils/basic.py
@@ -54,6 +54,37 @@ def get_subclasses(cls: _C, recursive: bool = True, abstract: bool = False) -> list[_C]:
     return subclasses
 
 
+def get_baseclasses(
+    cls: type,
+    recursive: bool = True,
+    abstract: bool = False,
+) -> list[type]:
+    """Return a list of base classes for the given class.
+
+    Args:
+        cls: The class to retrieve base classes for.
+        recursive: If ``True``, indirect base classes (i.e., base classes of base
+            classes) are included.
+        abstract: If ``True``, abstract base classes are included.
+
+    Returns:
+        A list of base classes for the given class.
+    """
+    from baybe.utils.boolean import is_abstract
+
+    classes = []
+
+    for baseclass in cls.__bases__:
+        if baseclass not in classes:
+            if abstract or not is_abstract(baseclass):
+                classes.append(baseclass)
+
+            if recursive:
+                classes.extend(get_baseclasses(baseclass, abstract=abstract))
+
+    return classes
+
+
 def set_random_seed(seed: int):
     """Set the global random seed.
 
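
A quick sketch of `get_baseclasses` on a hypothetical hierarchy (the classes here are illustrative, not part of BayBE):

    from abc import ABC, abstractmethod

    from baybe.utils.basic import get_baseclasses


    class Base(ABC):
        @abstractmethod
        def f(self): ...


    class Mid(Base):
        def f(self):
            return 0


    class Leaf(Mid):
        pass


    bases = get_baseclasses(Leaf)
    assert Mid in bases and Base not in bases  # abstract bases skipped by default
    assert Base in get_baseclasses(Leaf, abstract=True)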
27 changes: 27 additions & 0 deletions baybe/utils/conversion.py
@@ -0,0 +1,27 @@
+"""Conversion utilities."""
+
+from fractions import Fraction
+from typing import Union
+
+
+def fraction_to_float(value: Union[str, float, Fraction], /) -> float:
+    """Convert the provided input representing a fraction into a float.
+
+    Args:
+        value: The input to be converted.
+
+    Returns:
+        The float representation of the given input.
+
+    Raises:
+        ValueError: If the input was provided as string but could not be interpreted as
+            fraction.
+    """
+    if isinstance(value, str):
+        try:
+            value = Fraction(value)
+        except ValueError as err:
+            raise ValueError(
+                f"The provided input '{value}' could not be interpreted as a fraction."
+            ) from err
+    return float(value)
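
Behavior sketch for the relocated helper:

    from fractions import Fraction

    from baybe.utils.conversion import fraction_to_float

    assert fraction_to_float("1/2") == 0.5
    assert fraction_to_float(Fraction(5, 2)) == 2.5
    assert fraction_to_float(0.5) == 0.5

    # Non-fraction strings raise the documented ValueError.
    try:
        fraction_to_float("not-a-fraction")
    except ValueError:
        pass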
25 changes: 24 additions & 1 deletion tests/hypothesis_strategies/kernels.py
@@ -2,13 +2,36 @@
 
 import hypothesis.strategies as st
 
-from baybe.kernels import MaternKernel
+from baybe.kernels import MaternKernel, ScaleKernel
 
+from ..hypothesis_strategies.basic import finite_floats
 from ..hypothesis_strategies.priors import priors
 
 matern_kernels = st.builds(
     MaternKernel,
     nu=st.sampled_from((0.5, 1.5, 2.5)),
     lengthscale_prior=st.one_of(st.none(), priors),
+    lengthscale_initial_value=st.one_of(st.none(), finite_floats()),
 )
 """A strategy that generates Matern kernels."""
+
+
+base_kernels = st.one_of([matern_kernels])
+"""A strategy that generates base kernels to be used within more complex kernels."""
+
+
+@st.composite
+def kernels(draw: st.DrawFn):
+    """Generate :class:`baybe.kernels.basic.Kernel`."""
+    base_kernel = draw(base_kernels)
+    add_scale = draw(st.booleans())
+    if add_scale:
+        return ScaleKernel(
+            base_kernel=base_kernel,
+            outputscale_prior=draw(st.one_of(st.none(), priors)),
+            outputscale_initial_value=draw(
+                st.one_of(st.none(), finite_floats()),
+            ),
+        )
+    else:
+        return base_kernel
10 changes: 5 additions & 5 deletions tests/serialization/test_kernel_serialization.py
@@ -2,12 +2,12 @@
 
 from hypothesis import given
 
-from baybe.kernels import MaternKernel
-from tests.hypothesis_strategies.kernels import matern_kernels
+from baybe.kernels.base import Kernel
+from tests.hypothesis_strategies.kernels import kernels
 
 
-@given(matern_kernels)
-def test_matern_kernel_roundtrip(kernel: MaternKernel):
+@given(kernels())
+def test_kernel_roundtrip(kernel: Kernel):
     string = kernel.to_json()
-    kernel2 = MaternKernel.from_json(string)
+    kernel2 = Kernel.from_json(string)
     assert kernel == kernel2, (kernel, kernel2)
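
The switch from `MaternKernel.from_json` to `Kernel.from_json` exercises polymorphic deserialization: the concrete class is recovered from the serialized type information. A manual sketch of the same roundtrip:

    from baybe.kernels import MaternKernel, ScaleKernel
    from baybe.kernels.base import Kernel

    kernel = ScaleKernel(base_kernel=MaternKernel(nu=0.5))
    restored = Kernel.from_json(kernel.to_json())
    assert kernel == restored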