-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR introduces custom ScaleKernels.
- Loading branch information
Showing
10 changed files
with
235 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
"""Kernels for Gaussian process surrogate models.""" | ||
|
||
from baybe.kernels.basic import MaternKernel | ||
from baybe.kernels.basic import MaternKernel, ScaleKernel | ||
|
||
__all__ = ["MaternKernel"] | ||
__all__ = [ | ||
"MaternKernel", | ||
"ScaleKernel", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,67 @@ | ||
"""Collection of kernels.""" | ||
|
||
from fractions import Fraction | ||
from typing import Union | ||
from typing import Optional | ||
|
||
from attrs import define, field | ||
from attrs.validators import in_ | ||
from attrs.converters import optional as optional_c | ||
from attrs.validators import in_, instance_of | ||
from attrs.validators import optional as optional_v | ||
|
||
from baybe.kernels.base import Kernel | ||
|
||
|
||
def _convert_fraction(value: Union[str, float, Fraction], /) -> float: | ||
"""Convert the provided value into a float. | ||
Args: | ||
value: The parameter that should be converted. | ||
Returns: | ||
The float representation of the given input. | ||
Raises: | ||
ValueError: If the input was provided as string but could not be interpreted as | ||
fraction. | ||
""" | ||
if isinstance(value, str): | ||
try: | ||
value = Fraction(value) | ||
except ValueError as err: | ||
raise ValueError( | ||
f"The provided input '{value}' could not be interpreted as a fraction." | ||
) from err | ||
return float(value) | ||
from baybe.kernels.priors.base import Prior | ||
from baybe.utils.conversion import fraction_to_float | ||
from baybe.utils.validation import finite_float | ||
|
||
|
||
@define(frozen=True) | ||
class MaternKernel(Kernel): | ||
"""A Matern kernel using a smoothness parameter.""" | ||
|
||
nu: float = field( | ||
converter=_convert_fraction, validator=in_([0.5, 1.5, 2.5]), default=2.5 | ||
converter=fraction_to_float, validator=in_([0.5, 1.5, 2.5]), default=2.5 | ||
) | ||
"""A smoothness parameter. | ||
Only takes the values 0.5, 1.5 or 2.5. Larger values yield smoother interpolations. | ||
""" | ||
|
||
lengthscale_prior: Optional[Prior] = field( | ||
default=None, validator=optional_v(instance_of(Prior)) | ||
) | ||
"""An optional prior on the kernel lengthscale.""" | ||
|
||
lengthscale_initial_value: Optional[float] = field( | ||
default=None, converter=optional_c(float), validator=optional_v(finite_float) | ||
) | ||
"""An optional initial value for the kernel lengthscale.""" | ||
|
||
|
||
@define(frozen=True) | ||
class ScaleKernel(Kernel): | ||
"""A kernel for decorating existing kernels with an outputscale.""" | ||
|
||
base_kernel: Kernel = field(validator=instance_of(Kernel)) | ||
"""The base kernel that is being decorated.""" | ||
|
||
outputscale_prior: Optional[Prior] = field( | ||
default=None, validator=optional_v(instance_of(Prior)) | ||
) | ||
"""An optional prior on the output scale.""" | ||
|
||
outputscale_initial_value: Optional[float] = field( | ||
default=None, converter=optional_c(float), validator=optional_v(finite_float) | ||
) | ||
"""An optional initial value for the output scale.""" | ||
|
||
def to_gpytorch(self, *args, **kwargs): # noqa: D102 | ||
# See base class. | ||
import torch | ||
|
||
from baybe.utils.torch import DTypeFloatTorch | ||
|
||
gpytorch_kernel = super().to_gpytorch(*args, **kwargs) | ||
if (initial_value := self.outputscale_initial_value) is not None: | ||
gpytorch_kernel.outputscale = torch.tensor( | ||
initial_value, dtype=DTypeFloatTorch | ||
) | ||
return gpytorch_kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
"""Conversion utilities.""" | ||
|
||
from fractions import Fraction | ||
from typing import Union | ||
|
||
|
||
def fraction_to_float(value: Union[str, float, Fraction], /) -> float: | ||
"""Convert the provided input representing a fraction into a float. | ||
Args: | ||
value: The input to be converted. | ||
Returns: | ||
The float representation of the given input. | ||
Raises: | ||
ValueError: If the input was provided as string but could not be interpreted as | ||
fraction. | ||
""" | ||
if isinstance(value, str): | ||
try: | ||
value = Fraction(value) | ||
except ValueError as err: | ||
raise ValueError( | ||
f"The provided input '{value}' could not be interpreted as a fraction." | ||
) from err | ||
return float(value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.