Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Acquisition metadata (different approach) #767

Draft
wants to merge 11 commits into
base: develop
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support metadata in AssychronousOptimization too
Uri Granta committed Jul 12, 2023
commit 879ff47218c84c1a15bee54e9ed5d5c038e25f61
56 changes: 42 additions & 14 deletions trieste/acquisition/rule.py
Original file line number Diff line number Diff line change
@@ -119,7 +119,6 @@ def acquire(
:param search_space: The local acquisition search space for *this step*.
:param models: The model for each tag.
:param datasets: The known observer query points and observations for each tag (optional).
:param metadata: Any additional acquisition metadata (optional).
:return: A value of type `T_co`.
"""

@@ -529,7 +528,9 @@ def __init__(
self: "AsynchronousOptimization[SearchSpaceType, ProbabilisticModelType]",
builder: (
AcquisitionFunctionBuilder[ProbabilisticModelType]
| MetadataAcquisitionFunctionBuilder[ProbabilisticModelType]
| SingleModelAcquisitionBuilder[ProbabilisticModelType]
| SingleModelMetadataAcquisitionBuilder[ProbabilisticModelType]
),
optimizer: AcquisitionOptimizer[SearchSpaceType] | None = None,
num_query_points: int = 1,
@@ -540,7 +541,9 @@ def __init__(
self,
builder: Optional[
AcquisitionFunctionBuilder[ProbabilisticModelType]
| MetadataAcquisitionFunctionBuilder[ProbabilisticModelType]
| SingleModelAcquisitionBuilder[ProbabilisticModelType]
| SingleModelMetadataAcquisitionBuilder[ProbabilisticModelType]
] = None,
optimizer: AcquisitionOptimizer[SearchSpaceType] | None = None,
num_query_points: int = 1,
@@ -568,15 +571,20 @@ def __init__(
if optimizer is None:
optimizer = automatic_optimizer_selector

if isinstance(builder, SingleModelAcquisitionBuilder):
if isinstance(
builder, (SingleModelAcquisitionBuilder, SingleModelMetadataAcquisitionBuilder)
):
builder = builder.using(OBJECTIVE)

# even though we are only using batch acquisition functions
# there is no need to batchify_joint the optimizer if our batch size is 1
if num_query_points > 1:
optimizer = batchify_joint(optimizer, num_query_points)

self._builder: AcquisitionFunctionBuilder[ProbabilisticModelType] = builder
self._builder: Union[
AcquisitionFunctionBuilder[ProbabilisticModelType],
MetadataAcquisitionFunctionBuilder[ProbabilisticModelType],
] = builder
self._optimizer = optimizer
self._acquisition_function: Optional[AcquisitionFunction] = None

@@ -586,13 +594,20 @@ def __repr__(self) -> str:
{self._builder!r},
{self._optimizer!r})"""

# TODO: support metadata

def acquire(
self,
search_space: SearchSpaceType,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
) -> types.State[AsynchronousRuleState | None, TensorType]:
return self.acquire_with_metadata(search_space, models, datasets=datasets)

def acquire_with_metadata(
self,
search_space: SearchSpaceType,
models: Mapping[Tag, ProbabilisticModelType],
datasets: Optional[Mapping[Tag, Dataset]] = None,
metadata: Optional[Mapping[str, Any]] = None,
) -> types.State[AsynchronousRuleState | None, TensorType]:
"""
Constructs a function that, given ``AsynchronousRuleState``,
@@ -610,6 +625,9 @@ def acquire(
:param search_space: The local acquisition search space for *this step*.
:param models: The model of the known data. Uses the single key `OBJECTIVE`.
:param datasets: The known observer query points and observations.
:param metadata: Any additional acquisition metadata. This is passed to any
:class:`~trieste.acquisition.MetadataAcquisitionFunctionBuilder` builder, and
ignored otherwise.
:return: A function that constructs the next acquisition state and the recommended query
points from the previous acquisition state.
"""
@@ -623,16 +641,26 @@ def acquire(
)

if self._acquisition_function is None:
self._acquisition_function = self._builder.prepare_acquisition_function(
models,
datasets=datasets,
)
if isinstance(self._builder, MetadataAcquisitionFunctionBuilder):
self._acquisition_function = self._builder.prepare_acquisition_function(
models, datasets=datasets, metadata=metadata
)
else:
self._acquisition_function = self._builder.prepare_acquisition_function(
models,
datasets=datasets,
)
else:
self._acquisition_function = self._builder.update_acquisition_function(
self._acquisition_function,
models,
datasets=datasets,
)
if isinstance(self._builder, MetadataAcquisitionFunctionBuilder):
self._acquisition_function = self._builder.update_acquisition_function(
self._acquisition_function, models, datasets=datasets, metadata=metadata
)
else:
self._acquisition_function = self._builder.update_acquisition_function(
self._acquisition_function,
models,
datasets=datasets,
)

def state_func(
state: AsynchronousRuleState | None,