The docs state:
> numerics are identical across backends [...] [...] up to 1e-7 precision in float32, per function execution
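For scale, and assuming the docs' 1e-7 bound is meant as a relative error, it sits right at float32 machine epsilon (2**-23, about 1.19e-7). That is the relative gap between 1.0 and the next representable float32, so the claim amounts to agreement within roughly one unit in the last place. A quick numpy check:

```python
import numpy as np

# float32 machine epsilon: relative spacing of float32 values near 1.0.
eps32 = np.finfo(np.float32).eps
print(eps32)  # about 1.19e-07, i.e. 2**-23

# The next float32 above 1.0 is exactly 1.0 + eps32.
x = np.float32(1.0)
step = np.nextafter(x, np.float32(2.0)) - x
print(step == eps32)
```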
However, the following minimal example does not bear this out:
```python
import os.path

import numpy as np
from keras import layers, Model
from keras.saving import load_model

np.random.seed(0)
# np.random.rand returns a float64 array of shape (1, 256, 256, 1024).
data = np.random.rand(1, 256, 256, 1024)

# The model is saved on the first run and reloaded afterwards,
# so every backend uses exactly the same weights.
if os.path.isfile("model.keras"):
    model = load_model("model.keras")
else:
    inputs = layers.Input(shape=(256, 256, 1024))
    outputs = layers.Conv2D(1024, kernel_size=(4, 7), padding="same", dilation_rate=(3, 2))(inputs)
    model = Model(inputs=[inputs], outputs=outputs)
    model.save("model.keras")

print(np.sum([data]))                    # sum of the input (sanity check)
print(os.environ["KERAS_BACKEND"])       # active backend
print(np.sum(np.array(model([data]))))   # sum of the Conv2D output
```
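The script reduces each backend's 1 × 256 × 256 × 1024 output (about 67 million float32 values) to a single sum, which mixes all per-element differences together. A sharper check compares the outputs element by element. The sketch below assumes each run saves its output array (for example with `np.save`) so the two can be loaded side by side; `compare_outputs` and the synthetic arrays are hypothetical stand-ins, not part of Keras.

```python
import numpy as np


def compare_outputs(a, b):
    """Return (max absolute diff, max relative diff) between two output arrays.

    Hypothetical helper; both arrays are upcast to float64 before subtracting.
    """
    a64 = np.asarray(a, dtype=np.float64)
    b64 = np.asarray(b, dtype=np.float64)
    abs_diff = np.abs(a64 - b64)
    # Avoid dividing by zero where the reference output is exactly 0.
    denom = np.maximum(np.abs(a64), np.finfo(np.float32).tiny)
    return float(abs_diff.max()), float((abs_diff / denom).max())


# Synthetic stand-ins for the two backends' outputs. In practice each run could
# save its output with np.save(...) and the arrays would be loaded with np.load.
rng = np.random.default_rng(0)
ref = rng.standard_normal((2, 8, 8, 16)).astype(np.float32)
# Move every element one float32 ulp toward +inf to mimic a tiny backend difference.
other = np.nextafter(ref, np.float32(np.inf))

max_abs, max_rel = compare_outputs(ref, other)
print(f"max abs diff: {max_abs:.3e}, max rel diff: {max_rel:.3e}")
```

A one-ulp perturbation keeps the maximum relative difference at or below float32 epsilon, which is the scale the docs' 1e-7 figure refers to.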
Output with `tensorflow` backend:

```
$ KERAS_BACKEND=tensorflow python main.py
33550919.07926151
tensorflow
58094.56
```
Output with `jax` backend:

```
$ KERAS_BACKEND=jax python main.py
33550919.07926151
jax
58094.523
```
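For scale, the arithmetic on the two reported output sums (a sketch; the values are copied from the output as printed):

```python
# Output sums printed by the tensorflow and jax runs.
tf_sum = 58094.56
jax_sum = 58094.523

abs_diff = abs(tf_sum - jax_sum)
rel_diff = abs_diff / abs(tf_sum)
print(f"absolute difference: {abs_diff:.3f}")  # 0.037
print(f"relative difference: {rel_diff:.1e}")  # 6.4e-07
```

This is the gap between two aggregates, not a per-element error.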
Versions used:

```
$ python -c "import keras; import tensorflow; import jax; print(keras.__version__); print(tensorflow.__version__); print(jax.__version__)"
3.8.0
2.18.0
0.5.0
```