-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR introduces a first basic benchmarking structure, which includes: - Benchmark classes that execute a user-defined callable, which takes parametrizable settings as input. - Classes for creating and further developing the settings used in the callable. - Results containing the DataFrame and relevant metadata about the execution. - One benchmark function to demonstrate how the module's components work together.
- Loading branch information
Showing
16 changed files
with
366 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Benchmarking module for performance tracking.""" | ||
|
||
from benchmarks.definition import Benchmark | ||
from benchmarks.result import Result | ||
|
||
__all__ = [ | ||
"Result", | ||
"Benchmark", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Executes the benchmarking module.""" | ||
# Run this via 'python -m benchmarks' from the root directory. | ||
|
||
from benchmarks.domains import BENCHMARKS | ||
|
||
|
||
def main():
    """Invoke every benchmark listed in ``BENCHMARKS``, one after another.

    The value returned by each benchmark call is not used here.
    """
    for run_benchmark in BENCHMARKS:
        run_benchmark()
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Benchmark task definitions.""" | ||
|
||
from benchmarks.definition.config import ( | ||
Benchmark, | ||
BenchmarkSettings, | ||
ConvergenceExperimentSettings, | ||
) | ||
|
||
__all__ = [ | ||
"ConvergenceExperimentSettings", | ||
"Benchmark", | ||
"BenchmarkSettings", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""Benchmark configurations.""" | ||
|
||
import time | ||
from abc import ABC | ||
from collections.abc import Callable | ||
from datetime import datetime, timedelta, timezone | ||
from typing import Any, Generic, TypeVar | ||
|
||
from attrs import define, field | ||
from attrs.validators import instance_of | ||
from pandas import DataFrame | ||
|
||
from baybe.serialization.mixin import SerialMixin | ||
from baybe.utils.random import temporary_seed | ||
from benchmarks.result import Result, ResultMetadata | ||
|
||
|
||
@define(frozen=True)
class BenchmarkSettings(SerialMixin, ABC):
    """Base configuration shared by benchmarks of recommender analyses."""

    # Keyword-only, so subclasses can declare mandatory fields after this
    # defaulted one.
    random_seed: int = field(
        default=1337,
        kw_only=True,
        validator=instance_of(int),
    )
    """The random seed used for reproducibility."""
|
||
|
||
BenchmarkSettingsType = TypeVar("BenchmarkSettingsType", bound=BenchmarkSettings) | ||
|
||
|
||
@define(frozen=True)
class ConvergenceExperimentSettings(BenchmarkSettings):
    """Settings for benchmarks that analyze recommender convergence."""

    # All three fields are mandatory and must be integers.

    batch_size: int = field(validator=instance_of(int))
    """Batch size of each recommendation."""

    n_doe_iterations: int = field(validator=instance_of(int))
    """Number of Design of Experiment iterations."""

    n_mc_iterations: int = field(validator=instance_of(int))
    """Number of Monte Carlo iterations."""
|
||
|
||
@define(frozen=True)
class Benchmark(Generic[BenchmarkSettingsType]):
    """A benchmark executable: a callable bundled with the settings it runs on."""

    settings: BenchmarkSettingsType = field()
    """The configuration handed to the benchmark function."""

    function: Callable[[BenchmarkSettingsType], DataFrame] = field()
    """The callable implementing the benchmarking logic."""

    name: str = field(init=False)
    """The benchmark name, derived from the name of ``function``."""

    best_possible_result: float | None = field(default=None)
    """The best result achievable in the optimization process, if known."""

    optimal_function_inputs: list[dict[str, Any]] | None = field(default=None)
    """Inputs that produce ``best_possible_result``, if known."""

    @name.default
    def _default_name(self):
        """Use the ``__name__`` of the benchmark function as the name."""
        return self.function.__name__

    @property
    def description(self) -> str:
        """The docstring of the benchmark function, or a placeholder text."""
        doc = self.function.__doc__
        # Only a missing docstring (``None``) triggers the placeholder; an
        # empty docstring is returned unchanged.
        return "No description available." if doc is None else doc

    def __call__(self) -> Result:
        """Run the benchmark function and wrap its output in a ``Result``.

        The function is called with ``settings`` inside a ``temporary_seed``
        context built from ``settings.random_seed``. The wall-clock start time
        (UTC) and the measured duration of the call are stored as metadata.
        """
        started_at = datetime.now(timezone.utc)

        with temporary_seed(self.settings.random_seed):
            tic = time.perf_counter()
            data = self.function(self.settings)
            toc = time.perf_counter()

        metadata = ResultMetadata(
            start_datetime=started_at,
            duration=timedelta(seconds=toc - tic),
        )
        return Result(self.name, data, metadata)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""Benchmark domains.""" | ||
|
||
from benchmarks.definition.config import Benchmark | ||
from benchmarks.domains.synthetic_2C1D_1C import synthetic_2C1D_1C_benchmark | ||
|
||
BENCHMARKS: list[Benchmark] = [ | ||
synthetic_2C1D_1C_benchmark, | ||
] | ||
|
||
__all__ = ["BENCHMARKS"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
"""Synthetic function with two continuous and one discrete input.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
from numpy import pi, sin, sqrt | ||
from pandas import DataFrame | ||
|
||
from baybe.campaign import Campaign | ||
from baybe.parameters import NumericalContinuousParameter, NumericalDiscreteParameter | ||
from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender | ||
from baybe.searchspace import SearchSpace | ||
from baybe.simulation import simulate_scenarios | ||
from baybe.targets import NumericalTarget, TargetMode | ||
from benchmarks.definition import ( | ||
Benchmark, | ||
ConvergenceExperimentSettings, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from mpl_toolkits.mplot3d import Axes3D | ||
|
||
|
||
def _lookup(z: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Evaluate the synthetic test function used as the benchmark lookup.

    Args:
        z: Discrete selector; every entry must be one of 1, 2, 3 or 4.
        x: First continuous input; every entry must lie in [-2*pi, 2*pi].
        y: Second continuous input; every entry must lie in [-2*pi, 2*pi].

    Returns:
        The function values, computed elementwise from the inputs. The value
        of ``z`` selects which of the four sub-functions contributes.

    Raises:
        ValueError: If any input lies outside its valid range.
    """
    # Validate with an explicit check instead of ``assert``: assertions are
    # removed when Python runs with ``-O``, which would silently disable the
    # range validation.
    inputs_valid = (
        np.all(-2 * pi <= x)
        and np.all(x <= 2 * pi)
        and np.all(-2 * pi <= y)
        and np.all(y <= 2 * pi)
        and np.all(np.isin(z, [1, 2, 3, 4]))
    )
    if not inputs_valid:
        raise ValueError("Inputs are not in the valid ranges.")

    # Each ``(z == k)`` factor acts as a 0/1 mask selecting one sub-function.
    return (
        (z == 1) * sin(x) * (1 + sin(y))
        + (z == 2) * (x * sin(0.9 * x) + sin(x) * sin(y))
        + (z == 3) * (sqrt(x + 8) * sin(x) + sin(x) * sin(y))
        + (z == 4) * (x * sin(1.666 * sqrt(x + 8)) + sin(x) * sin(y))
    )
|
||
|
||
def synthetic_2C1D_1C(settings: ConvergenceExperimentSettings) -> DataFrame:
    """Hybrid synthetic test function.

    Inputs:
        z   discrete   {1,2,3,4}
        x   continuous [-2*pi, 2*pi]
        y   continuous [-2*pi, 2*pi]
    Output: continuous
    Objective: Maximization
    Optimal Inputs:
        {x: 1.610, y: 1.571, z: 3}
        {x: 1.610, y: -4.712, z: 3}
    Optimal Output: 4.09685
    """
    # NOTE: the docstring above is kept verbatim on purpose; it is what
    # ``Benchmark.description`` returns for this function.
    continuous_bounds = (-2 * pi, 2 * pi)
    parameters = [
        NumericalContinuousParameter("x", continuous_bounds),
        NumericalContinuousParameter("y", continuous_bounds),
        NumericalDiscreteParameter("z", (1, 2, 3, 4)),
    ]

    objective = NumericalTarget(name="target", mode=TargetMode.MAX).to_objective()
    search_space = SearchSpace.from_product(parameters=parameters)

    # Two campaigns over the same search space and objective: one with an
    # explicit random recommender, one with the campaign's default.
    random_campaign = Campaign(
        searchspace=search_space,
        recommender=RandomRecommender(),
        objective=objective,
    )
    default_campaign = Campaign(
        searchspace=search_space,
        objective=objective,
    )
    scenarios: dict[str, Campaign] = {
        "Random Recommender": random_campaign,
        "Default Recommender": default_campaign,
    }

    return simulate_scenarios(
        scenarios,
        _lookup,
        batch_size=settings.batch_size,
        n_doe_iterations=settings.n_doe_iterations,
        n_mc_iterations=settings.n_mc_iterations,
        impute_mode="error",
    )
|
||
|
||
# Settings of the convergence experiment run for this domain.
benchmark_config = ConvergenceExperimentSettings(
    n_mc_iterations=50,
    n_doe_iterations=30,
    batch_size=5,
)

# Benchmark object pairing the test function with its settings and the known
# optimum listed in the function's docstring.
synthetic_2C1D_1C_benchmark = Benchmark(
    settings=benchmark_config,
    function=synthetic_2C1D_1C,
    best_possible_result=4.09685,
    optimal_function_inputs=[
        {"x": 1.610, "y": 1.571, "z": 3},
        {"x": 1.610, "y": -4.712, "z": 3},
    ],
)
|
||
|
||
if __name__ == "__main__":
    # Visualize the domain: one 3D surface of the lookup per value of z.

    import matplotlib.pyplot as plt

    # Both continuous inputs are sampled on the same interval.
    axis_values = np.linspace(-2 * pi, 2 * pi)
    x_mesh, y_mesh = np.meshgrid(axis_values, axis_values)

    fig = plt.figure(figsize=(10, 10))
    for position, z in enumerate([1, 2, 3, 4], start=1):
        # Subplots are laid out on a 2x2 grid; positions are 1-based.
        ax: Axes3D = fig.add_subplot(2, 2, position, projection="3d")
        ax.plot_surface(x_mesh, y_mesh, _lookup(np.asarray(z), x_mesh, y_mesh))
        plt.title(f"{z=}")

    plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""Benchmark results.""" | ||
|
||
from benchmarks.result.metadata import ResultMetadata | ||
from benchmarks.result.result import Result | ||
|
||
__all__ = ["Result", "ResultMetadata"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Benchmark result metadata.""" | ||
|
||
from datetime import datetime, timedelta | ||
|
||
import git | ||
from attrs import define, field | ||
from attrs.validators import instance_of | ||
from cattrs.gen import make_dict_unstructure_fn | ||
|
||
from baybe.serialization.core import converter | ||
from baybe.serialization.mixin import SerialMixin | ||
|
||
|
||
@define(frozen=True)
class ResultMetadata(SerialMixin):
    """Metadata describing a single benchmark result."""

    start_datetime: datetime = field(validator=instance_of(datetime))
    """When the benchmark was started."""

    duration: timedelta = field(validator=instance_of(timedelta))
    """How long the benchmark took to complete."""

    # The two fields below are not constructor arguments; their values are
    # read from the git repository when the object is created.

    commit_hash: str = field(init=False, validator=instance_of(str))
    """The commit hash of the used BayBE code."""

    latest_baybe_tag: str = field(init=False, validator=instance_of(str))
    """The latest BayBE tag reachable in the ancestor commit history."""

    @commit_hash.default
    def _default_commit_hash(self) -> str:
        """Return the hex SHA of the object that the repository HEAD points to."""
        return git.Repo(search_parent_directories=True).head.object.hexsha

    @latest_baybe_tag.default
    def _default_latest_baybe_tag(self) -> str:
        """Return the output of ``git describe`` with ``tags`` and ``abbrev=0``."""
        repository = git.Repo(search_parent_directories=True)
        return repository.git.describe(tags=True, abbrev=0)
|
||
|
||
# Register un-/structure hooks
# ``_cattrs_include_init_false=True`` asks cattrs to include the
# ``init=False`` fields of ``ResultMetadata`` when unstructuring.
_result_metadata_unstructure_fn = make_dict_unstructure_fn(
    ResultMetadata,
    converter,
    _cattrs_include_init_false=True,
)
converter.register_unstructure_hook(ResultMetadata, _result_metadata_unstructure_fn)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""Basic result classes for benchmarking.""" | ||
|
||
from __future__ import annotations | ||
|
||
from attrs import define, field | ||
from attrs.validators import instance_of | ||
from pandas import DataFrame | ||
|
||
from baybe.serialization.mixin import SerialMixin | ||
from benchmarks.result import ResultMetadata | ||
|
||
|
||
@define(frozen=True)
class Result(SerialMixin):
    """One benchmarking result: the produced data plus its metadata."""

    # Every field is type-checked on construction.

    benchmark_identifier: str = field(validator=instance_of(str))
    """Identifier of the benchmark that produced this result."""

    data: DataFrame = field(validator=instance_of(DataFrame))
    """The DataFrame returned by the benchmarked callable."""

    metadata: ResultMetadata = field(validator=instance_of(ResultMetadata))
    """Metadata associated with this benchmark result."""
Oops, something went wrong.