From 712b3362b4ad19141ac6db7eeb57fbc375140d84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Lafage?= Date: Fri, 18 Oct 2024 16:18:47 +0200 Subject: [PATCH 1/3] Refactor `smt.design_space` (#665) * smt use smt_design_space_ext design space impl if installed * other base classes belongs to smt * only need HAS_DESIGN_SPACE_EXT * Remove constants managed in __init__ * Cleanup tests --- smt/applications/tests/test_ego.py | 4 +- smt/applications/tests/test_mixed_integer.py | 9 ++-- smt/design_space/__init__.py | 54 ++++++-------------- smt/design_space/design_space.py | 5 -- smt/design_space/tests/test_design_space.py | 42 +++++---------- 5 files changed, 35 insertions(+), 79 deletions(-) diff --git a/smt/applications/tests/test_ego.py b/smt/applications/tests/test_ego.py index d34133d39..d14fd9bd8 100644 --- a/smt/applications/tests/test_ego.py +++ b/smt/applications/tests/test_ego.py @@ -1120,7 +1120,7 @@ def f_obj(X): LHS, design_space, criterion="ese", random_state=random_state ) Xt = sampling(n_doe) - if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl + if ds.HAS_DESIGN_SPACE_EXT: # results differs wrt config_space impl self.assertAlmostEqual(np.sum(Xt), 24.811925491708156, delta=1e-4) else: self.assertAlmostEqual(np.sum(Xt), 28.568852027679586, delta=1e-4) @@ -1155,7 +1155,7 @@ def f_obj(X): n_start=25, ) x_opt, y_opt, dnk, x_data, y_data = ego.optimize(fun=f_obj) - if ds.HAS_CONFIG_SPACE: # results differs wrt config_space impl + if ds.HAS_DESIGN_SPACE_EXT: # results differs wrt config_space impl self.assertAlmostEqual(np.sum(y_data), 8.846225704750577, delta=1e-4) self.assertAlmostEqual(np.sum(x_data), 41.811925504901374, delta=1e-4) else: diff --git a/smt/applications/tests/test_mixed_integer.py b/smt/applications/tests/test_mixed_integer.py index 5cc20d162..68fc2da24 100644 --- a/smt/applications/tests/test_mixed_integer.py +++ b/smt/applications/tests/test_mixed_integer.py @@ -17,9 +17,8 @@ except ImportError: NO_MATPLOTLIB = True -import smt.design_space as ds from smt.design_space import ( - HAS_CONFIG_SPACE, + HAS_DESIGN_SPACE_EXT, DesignSpace, CategoricalVariable, FloatVariable, @@ -464,7 +463,7 @@ def test_examples(self): self.run_mixed_gower_example() self.run_mixed_homo_gaussian_example() self.run_mixed_homo_hyp_example() - if ds.HAS_CONFIG_SPACE: + if HAS_DESIGN_SPACE_EXT: self.run_mixed_cs_example() self.run_hierarchical_design_space_example() # works only with config space impl @@ -918,7 +917,7 @@ def run_hierarchical_design_space_example(self): self._sm = sm # to be ignored: just used for automated test @unittest.skipIf( - not HAS_CONFIG_SPACE, "Hierarchy ConfigSpace dependency not installed" + not HAS_DESIGN_SPACE_EXT, "Hierarchy ConfigSpace dependency not installed" ) def test_hierarchical_design_space_example(self): self.run_hierarchical_design_space_example() @@ -951,7 +950,7 @@ def test_hierarchical_design_space_example(self): ) @unittest.skipIf( - not HAS_CONFIG_SPACE, "Hierarchy ConfigSpace dependency not installed" + not HAS_DESIGN_SPACE_EXT, "Hierarchy ConfigSpace dependency not installed" ) def test_hierarchical_design_space_example_all_categorical_decreed(self): ds = DesignSpace( diff --git a/smt/design_space/__init__.py b/smt/design_space/__init__.py index fa0ef40c8..9ce1e192b 100644 --- a/smt/design_space/__init__.py +++ b/smt/design_space/__init__.py @@ -1,55 +1,33 @@ -import importlib - -spec = importlib.util.find_spec("smt_design_space") -if spec: - HAS_DESIGN_SPACE_EXT = True - HAS_CONFIG_SPACE = True - HAS_ADSG = True -else: - HAS_DESIGN_SPACE_EXT = False - HAS_CONFIG_SPACE = False - HAS_ADSG = False - - -if HAS_DESIGN_SPACE_EXT: - from smt_design_space.design_space import ( - CategoricalVariable, +from smt.design_space.design_space import ( + CategoricalVariable, + BaseDesignSpace, + FloatVariable, + IntegerVariable, + OrdinalVariable, +) + +try: + from smt_design_space_ext import ( DesignSpace, - BaseDesignSpace, - FloatVariable, - IntegerVariable, - OrdinalVariable, ensure_design_space, ) -else: - from smt.design_space.design_space import ( - CategoricalVariable, + HAS_DESIGN_SPACE_EXT = True +except ImportError: + from .design_space import ( DesignSpace, - FloatVariable, - IntegerVariable, - OrdinalVariable, ensure_design_space, - BaseDesignSpace, ) -if HAS_DESIGN_SPACE_EXT: - from smt_design_space.design_space import DesignSpaceGraph -else: - - class DesignSpaceGraph: - pass - + HAS_DESIGN_SPACE_EXT = False __all__ = [ - "HAS_DESIGN_SPACE_EXT", - "HAS_CONFIG_SPACE", - "HAS_ADSG", "BaseDesignSpace", - "DesignSpace", "FloatVariable", "IntegerVariable", "OrdinalVariable", "CategoricalVariable", + "DesignSpace", "ensure_design_space", + "HAS_DESIGN_SPACE_EXT", ] diff --git a/smt/design_space/design_space.py b/smt/design_space/design_space.py index f2d3260b1..28b571a63 100644 --- a/smt/design_space/design_space.py +++ b/smt/design_space/design_space.py @@ -9,11 +9,6 @@ from typing import List, Optional, Sequence, Tuple, Union -HAS_DESIGN_SPACE_EXT = False -HAS_CONFIG_SPACE = False -HAS_ADSG = False - - class Configuration: pass diff --git a/smt/design_space/tests/test_design_space.py b/smt/design_space/tests/test_design_space.py index a977a12c2..b6ffab49e 100644 --- a/smt/design_space/tests/test_design_space.py +++ b/smt/design_space/tests/test_design_space.py @@ -2,7 +2,6 @@ Author: Jasper Bussemaker """ -import contextlib import itertools import unittest @@ -11,7 +10,6 @@ from smt.sampling_methods import LHS -import smt.design_space.design_space as ds from smt.design_space.design_space import ( BaseDesignSpace, CategoricalVariable, @@ -22,16 +20,6 @@ ) -@contextlib.contextmanager -def simulate_no_config_space(do_simulate=True): - if ds.HAS_CONFIG_SPACE and do_simulate: - ds.HAS_CONFIG_SPACE = False - yield - ds.HAS_CONFIG_SPACE = True - else: - yield - - class Test(unittest.TestCase): def test_design_variables(self): with self.assertRaises(ValueError): @@ -193,8 +181,6 @@ def test_base_design_space(self): def test_create_design_space(self): DesignSpace([FloatVariable(0, 1)]) - with simulate_no_config_space(): - DesignSpace([FloatVariable(0, 1)]) def test_design_space(self): ds = DesignSpace( @@ -421,22 +407,20 @@ def test_design_space_hierarchical(self): assert len(seen_is_acting) == 2 def test_check_conditionally_acting_2(self): - for simulate_no_cs in [True, False]: - with simulate_no_config_space(simulate_no_cs): - ds = DesignSpace( - [ - CategoricalVariable(["A", "B", "C"]), # x0 - CategoricalVariable(["E", "F"]), # x1 - IntegerVariable(0, 1), # x2 - FloatVariable(0, 1), # x3 - ], - random_state=42, - ) - ds.declare_decreed_var( - decreed_var=0, meta_var=1, meta_value="E" - ) # Activate x3 if x0 == A + ds = DesignSpace( + [ + CategoricalVariable(["A", "B", "C"]), # x0 + CategoricalVariable(["E", "F"]), # x1 + IntegerVariable(0, 1), # x2 + FloatVariable(0, 1), # x3 + ], + random_state=42, + ) + ds.declare_decreed_var( + decreed_var=0, meta_var=1, meta_value="E" + ) # Activate x3 if x0 == A - ds.sample_valid_x(10, random_state=42) + ds.sample_valid_x(10, random_state=42) if __name__ == "__main__": From 2758bc3ca79b7ad2726fc95b142b4ae713b9aa12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Lafage?= Date: Fri, 18 Oct 2024 18:06:57 +0200 Subject: [PATCH 2/3] Remove `smt_design_space_ext` dependency from `smt.design_space` (#666) * Remove smt_design_space_ext dependency * Comment skipped tests --- smt/applications/tests/test_ego.py | 18 +++++------------- smt/applications/tests/test_mixed_integer.py | 19 +++++++++++-------- smt/design_space/__init__.py | 17 ++--------------- 3 files changed, 18 insertions(+), 36 deletions(-) diff --git a/smt/applications/tests/test_ego.py b/smt/applications/tests/test_ego.py index d14fd9bd8..21bb01e5e 100644 --- a/smt/applications/tests/test_ego.py +++ b/smt/applications/tests/test_ego.py @@ -11,7 +11,6 @@ import numpy as np -import smt.design_space as ds from smt.applications import EGO from smt.applications.ego import Evaluator from smt.applications.mixed_integer import ( @@ -1120,10 +1119,7 @@ def f_obj(X): LHS, design_space, criterion="ese", random_state=random_state ) Xt = sampling(n_doe) - if ds.HAS_DESIGN_SPACE_EXT: # results differs wrt config_space impl - self.assertAlmostEqual(np.sum(Xt), 24.811925491708156, delta=1e-4) - else: - self.assertAlmostEqual(np.sum(Xt), 28.568852027679586, delta=1e-4) + self.assertAlmostEqual(np.sum(Xt), 28.568852027679586, delta=1e-4) Xt = np.array( [ [0.37454012, 1.0], @@ -1155,12 +1151,8 @@ def f_obj(X): n_start=25, ) x_opt, y_opt, dnk, x_data, y_data = ego.optimize(fun=f_obj) - if ds.HAS_DESIGN_SPACE_EXT: # results differs wrt config_space impl - self.assertAlmostEqual(np.sum(y_data), 8.846225704750577, delta=1e-4) - self.assertAlmostEqual(np.sum(x_data), 41.811925504901374, delta=1e-4) - else: - self.assertAlmostEqual(np.sum(y_data), 7.8471910288712, delta=1e-4) - self.assertAlmostEqual(np.sum(x_data), 34.81192549, delta=1e-4) + self.assertAlmostEqual(np.sum(y_data), 7.8471910288712, delta=1e-4) + self.assertAlmostEqual(np.sum(x_data), 34.81192549, delta=1e-4) def test_ego_gek(self): ego, fun = self.initialize_ego_gek() @@ -1237,7 +1229,7 @@ def run_ego_example(): from smt.applications import EGO from smt.surrogate_models import KRG - from smt.utils.design_space import DesignSpace + from smt.design_space import DesignSpace def function_test_1d(x): # function xsinx @@ -1328,7 +1320,7 @@ def run_ego_mixed_integer_example(): from smt.applications import EGO from smt.applications.mixed_integer import MixedIntegerContext from smt.surrogate_models import KRG, MixIntKernelType - from smt.utils.design_space import ( + from smt.design_space import ( CategoricalVariable, DesignSpace, FloatVariable, diff --git a/smt/applications/tests/test_mixed_integer.py b/smt/applications/tests/test_mixed_integer.py index 68fc2da24..a5475f58e 100644 --- a/smt/applications/tests/test_mixed_integer.py +++ b/smt/applications/tests/test_mixed_integer.py @@ -18,7 +18,6 @@ NO_MATPLOTLIB = True from smt.design_space import ( - HAS_DESIGN_SPACE_EXT, DesignSpace, CategoricalVariable, FloatVariable, @@ -463,9 +462,11 @@ def test_examples(self): self.run_mixed_gower_example() self.run_mixed_homo_gaussian_example() self.run_mixed_homo_hyp_example() - if HAS_DESIGN_SPACE_EXT: - self.run_mixed_cs_example() - self.run_hierarchical_design_space_example() # works only with config space impl + # FIXME: this test should belong to smt_design_space_ext + # but at the moment run_* code is used here to generate doc here in smt + # if HAS_DESIGN_SPACE_EXT: + # self.run_mixed_cs_example() + # self.run_hierarchical_design_space_example() # works only with config space impl def run_mixed_integer_lhs_example(self): import matplotlib.pyplot as plt @@ -916,8 +917,8 @@ def run_hierarchical_design_space_example(self): self._sm = sm # to be ignored: just used for automated test - @unittest.skipIf( - not HAS_DESIGN_SPACE_EXT, "Hierarchy ConfigSpace dependency not installed" + @unittest.skip( + "Use smt design space extension capability, this test should be moved in smt_design_space_ext" ) def test_hierarchical_design_space_example(self): self.run_hierarchical_design_space_example() @@ -949,8 +950,8 @@ def test_hierarchical_design_space_example(self): > 1e-8 ) - @unittest.skipIf( - not HAS_DESIGN_SPACE_EXT, "Hierarchy ConfigSpace dependency not installed" + @unittest.skip( + "Use smt design space extension capability, this test should be moved in smt_design_space_ext" ) def test_hierarchical_design_space_example_all_categorical_decreed(self): ds = DesignSpace( @@ -2128,6 +2129,7 @@ def run_mixed_gower_example(self): plt.tight_layout() plt.show() + # FIXME: Used in SMT documentation but belongs to smt_design_space_ext domain def run_mixed_cs_example(self): import matplotlib.pyplot as plt import numpy as np @@ -2259,6 +2261,7 @@ def run_mixed_cs_example(self): plt.tight_layout() plt.show() + # FIXME: Used in SMT documentation but belongs to smt_design_space_ext domain def run_mixed_homo_gaussian_example(self): import matplotlib.pyplot as plt import numpy as np diff --git a/smt/design_space/__init__.py b/smt/design_space/__init__.py index 9ce1e192b..77a3e8fa3 100644 --- a/smt/design_space/__init__.py +++ b/smt/design_space/__init__.py @@ -4,23 +4,10 @@ FloatVariable, IntegerVariable, OrdinalVariable, + DesignSpace, + ensure_design_space, ) -try: - from smt_design_space_ext import ( - DesignSpace, - ensure_design_space, - ) - - HAS_DESIGN_SPACE_EXT = True -except ImportError: - from .design_space import ( - DesignSpace, - ensure_design_space, - ) - - HAS_DESIGN_SPACE_EXT = False - __all__ = [ "BaseDesignSpace", "FloatVariable", From ce7cbbb05c24971372de80a07c312837c0431d45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Lafage?= Date: Fri, 18 Oct 2024 19:24:49 +0200 Subject: [PATCH 3/3] Expose DesignVariable, remove forgotten HAS_DESIGN_SPACE_EXT (#667) --- smt/design_space/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/smt/design_space/__init__.py b/smt/design_space/__init__.py index 77a3e8fa3..a88f687dc 100644 --- a/smt/design_space/__init__.py +++ b/smt/design_space/__init__.py @@ -5,6 +5,7 @@ IntegerVariable, OrdinalVariable, DesignSpace, + DesignVariable, ensure_design_space, ) @@ -15,6 +16,6 @@ "OrdinalVariable", "CategoricalVariable", "DesignSpace", + "DesignVariable", "ensure_design_space", - "HAS_DESIGN_SPACE_EXT", ]