Skip to content

Commit

Permalink
Move advanced design space handling in separate `smt-design-space-ext…
Browse files Browse the repository at this point in the history
…` module (#651)

* add notebook

* update design space utils

* ruff'

* add adsg 1.1.1

* fix setup
  • Loading branch information
Paul-Saves authored Oct 2, 2024
1 parent 0e8a348 commit 3646af2
Show file tree
Hide file tree
Showing 33 changed files with 475 additions and 925 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ omit =
*/tests/*
*/examples/*
*/__init__.py
*/setup.py
*/setup.py
2 changes: 1 addition & 1 deletion .github/workflows/tests_coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
pip install --upgrade pip
pip install -r requirements.txt numpy==1.26.4 ConfigSpace==0.6.1
pip install -r requirements.txt numpy==1.26.4 ConfigSpace==0.6.1 adsg-core==1.1.1 git+https://github.com/SMTorg/smt-design-space-ext
pip list
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ pytest-xdist # allows running parallel testing with pytest -n <num_workers>
pytest-cov # allows to get coverage report
ruff # format and lint code
jenn >= 1.0.2, <2.0
egobox ~= 0.20.0
egobox ~= 0.20.0
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@
"smt.sampling_methods",
"smt.utils",
"smt.applications",
"smt.design_space",
"smt.kernels",
],
install_requires=[
"scikit-learn",
Expand All @@ -118,9 +120,6 @@
"numba": [ # pip install smt[numba]
"numba~=0.56.4",
],
"cs": [ # pip install smt[cs]
"ConfigSpace~=0.6.1",
],
"gpx": ["egobox~=0.22"], # pip install smt[gpx]
},
python_requires=">=3.9",
Expand Down
13 changes: 13 additions & 0 deletions smt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
__version__ = "2.7.0"

__all__ = [
"surrogate_models",
"kernels",
"design_space",
"applications",
"examples",
"sampling_methods",
"utils",
"tests",
"src",
"problems",
]
5 changes: 1 addition & 4 deletions smt/applications/ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
)
from smt.sampling_methods import LHS
from smt.surrogate_models import GEKPLS, GPX, KPLS, KPLSK, KRG, MGP
from smt.utils.design_space import (
BaseDesignSpace,
DesignSpace,
)
from smt.design_space import BaseDesignSpace, DesignSpace


class Evaluator(object):
Expand Down
2 changes: 1 addition & 1 deletion smt/applications/mfk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
MixIntKernelType,
compute_n_param,
)
from smt.utils.design_space import ensure_design_space
from smt.design_space import ensure_design_space
from smt.utils.kriging import (
componentwise_distance,
compute_X_cont,
Expand Down
3 changes: 2 additions & 1 deletion smt/applications/mixed_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from smt.surrogate_models.krg_based import KrgBased, MixIntKernelType
from smt.surrogate_models.surrogate_model import SurrogateModel
from smt.utils.checks import ensure_2d_array
from smt.utils.design_space import (

from smt.design_space import (
BaseDesignSpace,
CategoricalVariable,
ensure_design_space,
Expand Down
2 changes: 1 addition & 1 deletion smt/applications/tests/test_ego.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np

import smt.utils.design_space as ds
import smt.design_space as ds
from smt.applications import EGO
from smt.applications.ego import Evaluator
from smt.applications.mixed_integer import (
Expand Down
4 changes: 2 additions & 2 deletions smt/applications/tests/test_mfk_mfkpls_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
KRG,
MixIntKernelType,
)
from smt.utils.design_space import (
CategoricalVariable,
from smt.design_space import (
DesignSpace,
CategoricalVariable,
FloatVariable,
IntegerVariable,
)
Expand Down
43 changes: 22 additions & 21 deletions smt/applications/tests/test_mixed_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
except ImportError:
NO_MATPLOTLIB = True

import smt.utils.design_space as ds
import smt.design_space as ds
from smt.design_space import (
HAS_CONFIG_SPACE,
DesignSpace,
CategoricalVariable,
FloatVariable,
IntegerVariable,
OrdinalVariable,
)

from smt.applications.mixed_integer import (
MixedIntegerContext,
MixedIntegerKrigingModel,
Expand All @@ -33,14 +42,6 @@
MixHrcKernelType,
MixIntKernelType,
)
from smt.utils.design_space import (
HAS_CONFIG_SPACE,
CategoricalVariable,
DesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
)


class TestMixedInteger(unittest.TestCase):
Expand Down Expand Up @@ -474,10 +475,10 @@ def run_mixed_integer_lhs_example(self):

from smt.applications.mixed_integer import MixedIntegerSamplingMethod
from smt.sampling_methods import LHS
from smt.utils.design_space import (
CategoricalVariable,
DesignSpace,
from smt.design_space import (
FloatVariable,
DesignSpace,
CategoricalVariable,
)

float_var = FloatVariable(0, 4)
Expand Down Expand Up @@ -507,7 +508,7 @@ def run_mixed_integer_qp_example(self):

from smt.applications.mixed_integer import MixedIntegerSurrogateModel
from smt.surrogate_models import QP
from smt.utils.design_space import DesignSpace, IntegerVariable
from smt.design_space import DesignSpace, IntegerVariable

xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
yt = np.array([0.0, 1.0, 1.5, 0.5, 1.0])
Expand Down Expand Up @@ -539,7 +540,7 @@ def run_mixed_integer_context_example(self):

from smt.applications.mixed_integer import MixedIntegerContext
from smt.surrogate_models import KRG
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down Expand Up @@ -747,7 +748,7 @@ def run_mixed_discrete_design_space_example(self):

from smt.applications.mixed_integer import MixedIntegerSamplingMethod
from smt.sampling_methods import LHS
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down Expand Up @@ -798,7 +799,7 @@ def run_hierarchical_design_space_example(self):
)
from smt.sampling_methods import LHS
from smt.surrogate_models import KRG, MixHrcKernelType, MixIntKernelType
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down Expand Up @@ -1099,7 +1100,7 @@ def run_hierarchical_variables_Goldstein(self):
)
from smt.sampling_methods import LHS
from smt.surrogate_models import KRG, MixHrcKernelType, MixIntKernelType
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down Expand Up @@ -2005,7 +2006,7 @@ def run_mixed_gower_example(self):
MixedIntegerKrigingModel,
)
from smt.surrogate_models import KRG, MixIntKernelType
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down Expand Up @@ -2136,7 +2137,7 @@ def run_mixed_cs_example(self):
MixedIntegerKrigingModel,
)
from smt.surrogate_models import KRG, MixIntKernelType
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down Expand Up @@ -2265,7 +2266,7 @@ def run_mixed_homo_gaussian_example(self):

from smt.applications.mixed_integer import MixedIntegerKrigingModel
from smt.surrogate_models import KRG, MixIntKernelType
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down Expand Up @@ -2394,7 +2395,7 @@ def run_mixed_homo_hyp_example(self):

from smt.applications.mixed_integer import MixedIntegerKrigingModel
from smt.surrogate_models import KRG, MixIntKernelType
from smt.utils.design_space import (
from smt.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
Expand Down
55 changes: 55 additions & 0 deletions smt/design_space/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import importlib

spec = importlib.util.find_spec("smt_design_space")
if spec:
HAS_DESIGN_SPACE_EXT = True
HAS_CONFIG_SPACE = True
HAS_ADSG = True
else:
HAS_DESIGN_SPACE_EXT = False
HAS_CONFIG_SPACE = False
HAS_ADSG = False


if HAS_DESIGN_SPACE_EXT:
from smt_design_space.design_space import (
CategoricalVariable,
DesignSpace,
BaseDesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
ensure_design_space,
)

else:
from smt.design_space.design_space import (
CategoricalVariable,
DesignSpace,
FloatVariable,
IntegerVariable,
OrdinalVariable,
ensure_design_space,
BaseDesignSpace,
)

if HAS_DESIGN_SPACE_EXT:
from smt_design_space.design_space import DesignSpaceGraph
else:

class DesignSpaceGraph:
pass


__all__ = [
"HAS_DESIGN_SPACE_EXT",
"HAS_CONFIG_SPACE",
"HAS_ADSG",
"BaseDesignSpace",
"DesignSpace",
"FloatVariable",
"IntegerVariable",
"OrdinalVariable",
"CategoricalVariable",
"ensure_design_space",
]
Loading

0 comments on commit 3646af2

Please sign in to comment.