Skip to content

Commit

Permalink
Merge branch 'SMTorg:master' into dev-castano
Browse files Browse the repository at this point in the history
  • Loading branch information
mcastanoUQ authored Oct 21, 2024
2 parents 614e235 + ce7cbbb commit 5fe7142
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 102 deletions.
18 changes: 5 additions & 13 deletions smt/applications/tests/test_ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import numpy as np

import smt.design_space as ds
from smt.applications import EGO
from smt.applications.ego import Evaluator
from smt.applications.mixed_integer import (
Expand Down Expand Up @@ -1120,10 +1119,7 @@ def f_obj(X):
LHS, design_space, criterion="ese", random_state=random_state
)
Xt = sampling(n_doe)
if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl
self.assertAlmostEqual(np.sum(Xt), 24.811925491708156, delta=1e-4)
else:
self.assertAlmostEqual(np.sum(Xt), 28.568852027679586, delta=1e-4)
self.assertAlmostEqual(np.sum(Xt), 28.568852027679586, delta=1e-4)
Xt = np.array(
[
[0.37454012, 1.0],
Expand Down Expand Up @@ -1155,12 +1151,8 @@ def f_obj(X):
n_start=25,
)
x_opt, y_opt, dnk, x_data, y_data = ego.optimize(fun=f_obj)
if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl
self.assertAlmostEqual(np.sum(y_data), 8.846225704750577, delta=1e-4)
self.assertAlmostEqual(np.sum(x_data), 41.811925504901374, delta=1e-4)
else:
self.assertAlmostEqual(np.sum(y_data), 7.8471910288712, delta=1e-4)
self.assertAlmostEqual(np.sum(x_data), 34.81192549, delta=1e-4)
self.assertAlmostEqual(np.sum(y_data), 7.8471910288712, delta=1e-4)
self.assertAlmostEqual(np.sum(x_data), 34.81192549, delta=1e-4)

def test_ego_gek(self):
ego, fun = self.initialize_ego_gek()
Expand Down Expand Up @@ -1237,7 +1229,7 @@ def run_ego_example():

from smt.applications import EGO
from smt.surrogate_models import KRG
from smt.utils.design_space import DesignSpace
from smt.design_space import DesignSpace

def function_test_1d(x):
# function xsinx
Expand Down Expand Up @@ -1328,7 +1320,7 @@ def run_ego_mixed_integer_example():
from smt.applications import EGO
from smt.applications.mixed_integer import MixedIntegerContext
from smt.surrogate_models import KRG, MixIntKernelType
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down
20 changes: 11 additions & 9 deletions smt/applications/tests/test_mixed_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
except ImportError:
NO_MATPLOTLIB = True

import smt.design_space as ds
from smt.design_space import (
HAS_CONFIG_SPACE,
DesignSpace,
CategoricalVariable,
FloatVariable,
Expand Down Expand Up @@ -464,9 +462,11 @@ def test_examples(self):
self.run_mixed_gower_example()
self.run_mixed_homo_gaussian_example()
self.run_mixed_homo_hyp_example()
if ds.HAS_CONFIG_SPACE:
self.run_mixed_cs_example()
self.run_hierarchical_design_space_example() # works only with config space impl
# FIXME: this test should belong to smt_design_space_ext
# but at the moment run_* code is used here to generate doc here in smt
# if HAS_DESIGN_SPACE_EXT:
# self.run_mixed_cs_example()
# self.run_hierarchical_design_space_example() # works only with config space impl

def run_mixed_integer_lhs_example(self):
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -917,8 +917,8 @@ def run_hierarchical_design_space_example(self):

self._sm = sm # to be ignored: just used for automated test

@unittest.skipIf(
not HAS_CONFIG_SPACE, "Hierarchy ConfigSpace dependency not installed"
@unittest.skip(
"Use smt design space extension capability, this test should be moved in smt_design_space_ext"
)
def test_hierarchical_design_space_example(self):
self.run_hierarchical_design_space_example()
Expand Down Expand Up @@ -950,8 +950,8 @@ def test_hierarchical_design_space_example(self):
> 1e-8
)

@unittest.skipIf(
not HAS_CONFIG_SPACE, "Hierarchy ConfigSpace dependency not installed"
@unittest.skip(
"Use smt design space extension capability, this test should be moved in smt_design_space_ext"
)
def test_hierarchical_design_space_example_all_categorical_decreed(self):
ds = DesignSpace(
Expand Down Expand Up @@ -2129,6 +2129,7 @@ def run_mixed_gower_example(self):
plt.tight_layout()
plt.show()

# FIXME: Used in SMT documentation but belongs to smt_design_space_ext domain
def run_mixed_cs_example(self):
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -2260,6 +2261,7 @@ def run_mixed_cs_example(self):
plt.tight_layout()
plt.show()

# FIXME: Used in SMT documentation but belongs to smt_design_space_ext domain
def run_mixed_homo_gaussian_example(self):
import matplotlib.pyplot as plt
import numpy as np
Expand Down
58 changes: 12 additions & 46 deletions smt/design_space/__init__.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,21 @@
import importlib

spec = importlib.util.find_spec("smt_design_space")
if spec:
HAS_DESIGN_SPACE_EXT = True
HAS_CONFIG_SPACE = True
HAS_ADSG = True
else:
HAS_DESIGN_SPACE_EXT = False
HAS_CONFIG_SPACE = False
HAS_ADSG = False


if HAS_DESIGN_SPACE_EXT:
from smt_design_space.design_space import (
CategoricalVariable,
DesignSpace,
BaseDesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
ensure_design_space,
)

else:
from smt.design_space.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
ensure_design_space,
BaseDesignSpace,
)

if HAS_DESIGN_SPACE_EXT:
from smt_design_space.design_space import DesignSpaceGraph
else:

class DesignSpaceGraph:
pass

from smt.design_space.design_space import (
CategoricalVariable,
BaseDesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
DesignSpace,
DesignVariable,
ensure_design_space,
)

__all__ = [
"HAS_DESIGN_SPACE_EXT",
"HAS_CONFIG_SPACE",
"HAS_ADSG",
"BaseDesignSpace",
"DesignSpace",
"FloatVariable",
"IntegerVariable",
"OrdinalVariable",
"CategoricalVariable",
"DesignSpace",
"DesignVariable",
"ensure_design_space",
]
5 changes: 0 additions & 5 deletions smt/design_space/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
from typing import List, Optional, Sequence, Tuple, Union


HAS_DESIGN_SPACE_EXT = False
HAS_CONFIG_SPACE = False
HAS_ADSG = False


class Configuration:
pass

Expand Down
42 changes: 13 additions & 29 deletions smt/design_space/tests/test_design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Author: Jasper Bussemaker <[email protected]>
"""

import contextlib
import itertools
import unittest

Expand All @@ -11,7 +10,6 @@
from smt.sampling_methods import LHS


import smt.design_space.design_space as ds
from smt.design_space.design_space import (
BaseDesignSpace,
CategoricalVariable,
Expand All @@ -22,16 +20,6 @@
)


@contextlib.contextmanager
def simulate_no_config_space(do_simulate=True):
if ds.HAS_CONFIG_SPACE and do_simulate:
ds.HAS_CONFIG_SPACE = False
yield
ds.HAS_CONFIG_SPACE = True
else:
yield


class Test(unittest.TestCase):
def test_design_variables(self):
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -193,8 +181,6 @@ def test_base_design_space(self):

def test_create_design_space(self):
DesignSpace([FloatVariable(0, 1)])
with simulate_no_config_space():
DesignSpace([FloatVariable(0, 1)])

def test_design_space(self):
ds = DesignSpace(
Expand Down Expand Up @@ -421,22 +407,20 @@ def test_design_space_hierarchical(self):
assert len(seen_is_acting) == 2

def test_check_conditionally_acting_2(self):
for simulate_no_cs in [True, False]:
with simulate_no_config_space(simulate_no_cs):
ds = DesignSpace(
[
CategoricalVariable(["A", "B", "C"]), # x0
CategoricalVariable(["E", "F"]), # x1
IntegerVariable(0, 1), # x2
FloatVariable(0, 1), # x3
],
random_state=42,
)
ds.declare_decreed_var(
decreed_var=0, meta_var=1, meta_value="E"
) # Activate x3 if x0 == A
ds = DesignSpace(
[
CategoricalVariable(["A", "B", "C"]), # x0
CategoricalVariable(["E", "F"]), # x1
IntegerVariable(0, 1), # x2
FloatVariable(0, 1), # x3
],
random_state=42,
)
ds.declare_decreed_var(
decreed_var=0, meta_var=1, meta_value="E"
) # Activate x3 if x0 == A

ds.sample_valid_x(10, random_state=42)
ds.sample_valid_x(10, random_state=42)


if __name__ == "__main__":
Expand Down

0 comments on commit 5fe7142

Please sign in to comment.