diff --git a/README.md b/README.md
index ed324489..64586220 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@
Gurobi Machine Learning is an [open-source](https://gurobi-machinelearning.readthedocs.io/en/latest/meta/license.html) python package to formulate trained regression models in a [`gurobipy`](https://pypi.org/project/gurobipy/) model to be solved with the Gurobi solver.
-The package currently supports various [scikit-learn](https://scikit-learn.org/stable/) objects. It has limited support for the [Keras](https://keras.io/) API of [TensorFlow](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/) and [XGBoost](https://www.xgboost.ai). Only neural networks with ReLU activation can be used with Keras and PyTorch.
+The package currently supports various [scikit-learn](https://scikit-learn.org/stable/) objects. It has limited support for [Keras](https://keras.io/), [PyTorch](https://pytorch.org/) and [XGBoost](https://www.xgboost.ai). Only neural networks with ReLU activation can be used with Keras and PyTorch.
# Documentation
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 73bc5900..063d76d6 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -81,7 +81,7 @@ def get_versions(file: Path):
.. |PandasVersion| replace:: {dep_versions["pandas"]}
.. |TorchVersion| replace:: {dep_versions["torch"]}
.. |SklearnVersion| replace:: {dep_versions["scikit-learn"]}
-.. |TensorflowVersion| replace:: {dep_versions["tensorflow"]}
+.. |KerasVersion| replace:: {dep_versions["keras"]}
.. |XGBoostVersion| replace:: {dep_versions["xgboost"]}
.. |LightGBMVersion| replace:: {dep_versions["lightgbm"]}
.. |VariablesDimensionsWarn| replace:: {VARS_SHAPE}
@@ -162,7 +162,7 @@ def get_versions(file: Path):
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
autodoc_member_order = "groupwise"
-autodoc_mock_imports = ["torch", "tensorflow", "xgboost"]
+autodoc_mock_imports = ["torch", "keras", "tensorflow", "xgboost"]
html_css_files = [
"gurobi_ml.css",
]
diff --git a/docs/source/user/start.rst b/docs/source/user/start.rst
index 0c14e9ce..d3b62461 100644
--- a/docs/source/user/start.rst
+++ b/docs/source/user/start.rst
@@ -29,9 +29,8 @@ The package currently supports various `scikit-learn
`_ objects. It can also formulate
gradient boosting regression models from `XGboost `_
and `LightGBM `.
-Finally, it has limited support for the
-`Keras `_ API of `TensorFlow `_
-and `PyTorch `_. Only neural networks with ReLU activation
+Finally, it has limited support for
+`Keras `_. Only neural networks with ReLU activation
can be used with these two packages.
The package is actively developed and users are encouraged to :doc:`contact us
@@ -87,8 +86,8 @@ We encourage to install the package via pip (or add it to your
- |TorchVersion|
* - :pypi:`scikit-learn`
- |SklearnVersion|
- * - :pypi:`tensorflow`
- - |TensorflowVersion|
+ * - :pypi:`keras`
+ - |KerasVersion|
* - :pypi:`xgboost`
- |XGBoostVersion|
* - :pypi:`lightgbm`
@@ -96,7 +95,7 @@ We encourage to install the package via pip (or add it to your
Installing any of the machine learning packages is only required if the
predictor you want to insert uses them (i.e. to insert a Keras based predictor
- you need to have :pypi:`tensorflow` installed).
+ you need to have :pypi:`keras` installed).
Usage
diff --git a/notebooks/dev_requirements.txt b/notebooks/dev_requirements.txt
index 53967667..83a9f2c7 100644
--- a/notebooks/dev_requirements.txt
+++ b/notebooks/dev_requirements.txt
@@ -1,14 +1,9 @@
ipywidgets
matplotlib
-myst_nb
notebook
pandas
seaborn
-Sphinx
-sphinx-copybutton
-sphinx-pyproject
-sphinx-rtd-theme
-tensorflow
+keras
torch
skorch
../.
diff --git a/requirements.keras.txt b/requirements.keras.txt
index 15147c7a..17cd20dd 100644
--- a/requirements.keras.txt
+++ b/requirements.keras.txt
@@ -1 +1,2 @@
tensorflow==2.18.0
+keras==3.7.0
diff --git a/src/gurobi_ml/keras/keras.py b/src/gurobi_ml/keras/keras.py
index ef43092d..372de4e0 100644
--- a/src/gurobi_ml/keras/keras.py
+++ b/src/gurobi_ml/keras/keras.py
@@ -16,7 +16,7 @@
"""Module for formulating a Keras model into a :external+gurobi:py:class:`Model`."""
import numpy as np
-from tensorflow import keras
+import keras
from ..exceptions import NoModel, NoSolution
from ..modeling.neuralnet import BaseNNConstr
diff --git a/src/gurobi_ml/registered_predictors.py b/src/gurobi_ml/registered_predictors.py
index 30d9f2cd..50d02f52 100644
--- a/src/gurobi_ml/registered_predictors.py
+++ b/src/gurobi_ml/registered_predictors.py
@@ -90,8 +90,8 @@ def lightgbm_convertors():
def keras_convertors():
"""Collect known Keras objects that can be embedded and the conversion class."""
- if "tensorflow" in sys.modules:
- from tensorflow import keras # pylint: disable=import-outside-toplevel
+ if "keras" in sys.modules:
+ import keras # pylint: disable=import-outside-toplevel
from .keras import add_keras_constr # pylint: disable=import-outside-toplevel
diff --git a/tests/test_keras/test_keras_exceptions.py b/tests/test_keras/test_keras_exceptions.py
index 2eace2c2..b067532c 100644
--- a/tests/test_keras/test_keras_exceptions.py
+++ b/tests/test_keras/test_keras_exceptions.py
@@ -2,8 +2,7 @@
import gurobipy as gp
import numpy as np
-import tensorflow as tf
-from tensorflow import keras
+import keras
from gurobi_ml import add_predictor_constr
from gurobi_ml.exceptions import NoModel
@@ -22,9 +21,9 @@ def setUp(self) -> None:
def do_test(self, nn):
nn.compile(
- optimizer=tf.keras.optimizers.Adam(0.001),
- loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
- metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
+ optimizer=keras.optimizers.Adam(0.001),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
nn.fit(
@@ -49,12 +48,12 @@ def test_keras_bad_activation(self):
self.x_train = np.reshape(self.x_train, (-1, 28 * 28))
self.x_test = np.reshape(self.x_test, (-1, 28 * 28))
- nn = tf.keras.models.Sequential(
+ nn = keras.models.Sequential(
[
- tf.keras.layers.InputLayer((28 * 28,)),
- tf.keras.layers.Dense(50, activation="sigmoid"),
- tf.keras.layers.Dense(50, activation="relu"),
- tf.keras.layers.Dense(10),
+ keras.layers.InputLayer((28 * 28,)),
+ keras.layers.Dense(50, activation="sigmoid"),
+ keras.layers.Dense(50, activation="relu"),
+ keras.layers.Dense(10),
]
)
self.do_test(nn)
@@ -65,19 +64,19 @@ def test_keras_layers(self):
self.x_train = np.reshape(self.x_train, (-1, 28, 28, 1))
self.x_test = np.reshape(self.x_test, (-1, 28, 28, 1))
- nn = tf.keras.models.Sequential(
+ nn = keras.models.Sequential(
[
- tf.keras.layers.InputLayer((28, 28, 1)),
- tf.keras.layers.BatchNormalization(),
- tf.keras.layers.Conv2D(32, (3, 3), padding="same"),
- tf.keras.layers.ReLU(),
- tf.keras.layers.MaxPooling2D((2, 2)),
- tf.keras.layers.Conv2D(64, (3, 3), padding="same"),
- tf.keras.layers.ReLU(),
- tf.keras.layers.MaxPooling2D((2, 2)),
- tf.keras.layers.Flatten(),
- tf.keras.layers.Dense(100, activation="relu"),
- tf.keras.layers.Dense(10, activation="softmax"),
+ keras.layers.InputLayer((28, 28, 1)),
+ keras.layers.BatchNormalization(),
+ keras.layers.Conv2D(32, (3, 3), padding="same"),
+ keras.layers.ReLU(),
+ keras.layers.MaxPooling2D((2, 2)),
+ keras.layers.Conv2D(64, (3, 3), padding="same"),
+ keras.layers.ReLU(),
+ keras.layers.MaxPooling2D((2, 2)),
+ keras.layers.Flatten(),
+ keras.layers.Dense(100, activation="relu"),
+ keras.layers.Dense(10, activation="softmax"),
]
)
self.do_test(nn)
@@ -87,12 +86,12 @@ def do_relu_tests(self, **kwargs):
self.x_train = np.reshape(self.x_train, (-1, 28 * 28))
self.x_test = np.reshape(self.x_test, (-1, 28 * 28))
- nn = tf.keras.models.Sequential(
+ nn = keras.models.Sequential(
[
- tf.keras.layers.InputLayer((28 * 28,)),
- tf.keras.layers.Dense(50),
- tf.keras.layers.ReLU(**kwargs),
- tf.keras.layers.Dense(10),
+ keras.layers.InputLayer((28 * 28,)),
+ keras.layers.Dense(50),
+ keras.layers.ReLU(**kwargs),
+ keras.layers.Dense(10),
]
)
self.do_test(nn)
diff --git a/tests/test_keras/test_keras_formulations.py b/tests/test_keras/test_keras_formulations.py
index 0beeb30b..375318e9 100644
--- a/tests/test_keras/test_keras_formulations.py
+++ b/tests/test_keras/test_keras_formulations.py
@@ -1,6 +1,6 @@
import os
-import tensorflow as tf
+import keras
from joblib import load
from ..fixed_formulation import FixedRegressionModel
@@ -18,7 +18,7 @@ def test_diabetes_keras(self):
X = load(os.path.join(self.basedir, "examples_diabetes.joblib"))
filename = os.path.join(self.basedir, "diabetes.keras")
- regressor = tf.keras.models.load_model(filename)
+ regressor = keras.saving.load_model(filename)
onecase = {"predictor": regressor, "nonconvex": 0}
self.do_one_case(onecase, X, 5, "all")
self.do_one_case(onecase, X, 6, "pairs")
@@ -30,7 +30,7 @@ def test_diabetes_keras_alt(self):
os.path.dirname(__file__), "..", "predictors", "diabetes_v2.keras"
)
print(filename)
- regressor = tf.keras.models.load_model(filename)
+ regressor = keras.saving.load_model(filename)
onecase = {"predictor": regressor, "nonconvex": 0}
self.do_one_case(onecase, X, 5, "all")
self.do_one_case(onecase, X, 6, "pairs")