Skip to content

Commit

Permalink
Unify _sample_from_polytope methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AVHopp committed Nov 20, 2024
1 parent d3887fd commit c79d2f0
Showing 1 changed file with 2 additions and 26 deletions.
28 changes: 2 additions & 26 deletions baybe/searchspace/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,6 @@ def sample_uniform(self, batch_size: int = 1) -> pd.DataFrame:
if not self.is_constrained:
return self._sample_from_bounds(batch_size, self.comp_rep_bounds.values)

if self.has_interpoint_constraints:
return self._sample_from_polytope_with_interpoint_constraints(
batch_size, self.comp_rep_bounds.values
)

# If there are neither cardinality nor interpoint constraints, we sample
# directly from the polytope
if len(self.constraints_cardinality) == 0:
Expand All @@ -426,12 +421,12 @@ def _sample_from_bounds(self, batch_size: int, bounds: np.ndarray) -> pd.DataFra

return pd.DataFrame(points, columns=self.parameter_names)

def _sample_from_polytope_with_interpoint_constraints(
def _sample_from_polytope(
self,
batch_size: int,
bounds: np.ndarray,
) -> pd.DataFrame:
"""Draw uniform random samples from a polytope with interpoint constraints."""
"""Draw uniform random samples from a polytope."""
# If the space has interpoint constraints, we need to sample from a larger
# searchspace that models the batch size via additional dimension. This is
# necessary since `get_polytope_samples` cannot handle interpoint constraints,
Expand Down Expand Up @@ -494,25 +489,6 @@ def _sample_from_polytope_with_interpoint_constraints(
points = points.reshape(batch_size, points.shape[-1] // batch_size)
return pd.DataFrame(points, columns=self.parameter_names)

def _sample_from_polytope(
self, batch_size: int, bounds: np.ndarray
) -> pd.DataFrame:
"""Draw uniform random samples from a polytope."""
import torch
from botorch.utils.sampling import get_polytope_samples

points = get_polytope_samples(
n=batch_size,
bounds=torch.from_numpy(bounds),
equality_constraints=[
c.to_botorch(self.parameters) for c in self.constraints_lin_eq
],
inequality_constraints=[
c.to_botorch(self.parameters) for c in self.constraints_lin_ineq
],
)
return pd.DataFrame(points, columns=self.parameter_names)

def _sample_from_polytope_with_cardinality_constraints(
self, batch_size: int
) -> pd.DataFrame:
Expand Down

0 comments on commit c79d2f0

Please sign in to comment.