Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Jul 18, 2023
1 parent 89697f4 commit 1b58f41
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
15 changes: 13 additions & 2 deletions smt/applications/ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from smt.sampling_methods import LHS


class Evaluator(object):
"""
An interface for evaluation of a function at x points (nsamples of dimension nx).
Expand Down Expand Up @@ -264,11 +265,21 @@ def _setup_optimizer(self, fun):
self.design_space,
work_in_folded_space=True,
)
self._sampling = self.mixint.build_sampling_method(LHS, criterion="ese",random_state=self.options['random_state'], new_sampler=True)
self._sampling = self.mixint.build_sampling_method(
LHS,
criterion="ese",
random_state=self.options["random_state"],
new_sampler=True,
)

else:
self.mixint = None
self._sampling = lambda n: self.design_space.sample_valid_x(n,criterion="ese",random_state=self.options['random_state'], new_sampler=True)[0]
self._sampling = lambda n: self.design_space.sample_valid_x(
n,
criterion="ese",
random_state=self.options["random_state"],
new_sampler=True,
)[0]
self.categorical_kernel = None

# Build DOE
Expand Down
12 changes: 7 additions & 5 deletions smt/utils/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,9 @@ class DesignSpace(BaseDesignSpace):
"""

def __init__(self, design_variables: Union[List[DesignVariable], list, np.ndarray], seed=None):
def __init__(
self, design_variables: Union[List[DesignVariable], list, np.ndarray], seed=None
):
self.sampler = None
self.new_sampler = True

Expand Down Expand Up @@ -748,12 +750,12 @@ def _sample_valid_x(self, n: int, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
x_limits_unfolded = self.get_unfolded_num_bounds()
if "random_state" in kwargs.keys():
self.seed = kwargs["random_state"]
if "new_sampler" in kwargs.keys() and kwargs["new_sampler"] :
if "new_sampler" in kwargs.keys() and kwargs["new_sampler"]:
kwargs.pop("new_sampler", None)
if self.new_sampler :
if self.new_sampler:
self.sampler = LHS(xlimits=x_limits_unfolded, **kwargs)
self.new_sampler=False
if self.sampler is None :
self.new_sampler = False
if self.sampler is None:
self.sampler = LHS(xlimits=x_limits_unfolded, **kwargs)
x = self.sampler(n)

Expand Down

0 comments on commit 1b58f41

Please sign in to comment.