Use Multi GPU (Mirrored Strategy) training with XLA and AMP (Mixed Precision) #20763

Open
keshusharmamrt opened this issue Jan 15, 2025 · 4 comments

keshusharmamrt commented Jan 15, 2025

Hi, I've encountered some issues while trying to perform multi-GPU training with XLA (Accelerated Linear Algebra) and AMP (Automatic Mixed Precision).
I'm reaching out to understand whether it's possible to use multi-GPU training with XLA and AMP together.
If so, I'd like guidance on which versions of TensorFlow and Keras I should use, or how to modify my code to make this work.

Background:

In earlier versions of TensorFlow (prior to 2.11), we were able to successfully train models on multiple GPUs with both XLA and AMP enabled. However, with versions after TensorFlow 2.11, I have not been able to run training with multi-GPU + XLA + AMP.
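
For context, the kind of setup that worked for us before 2.11 looked roughly like this (a simplified sketch, not our exact code; at the time XLA was enabled via auto-clustering rather than per-model jit_compile):

import tensorflow as tf

# Enable XLA auto-clustering and mixed precision globally (TF <= 2.10 style).
tf.config.optimizer.set_jit(True)
tf.keras.mixed_precision.set_global_policy("mixed_float16")

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation="relu", input_shape=(784,)),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )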

Issues Encountered with Different Versions:
I used tf-keras==2.15 for all of these tests.

1. tensorflow=2.17.1/2.16.2 and keras=3.8.0:

Error Message
RuntimeError: Exception encountered when calling Cond.call() merge_call called while defining a new graph or a tf.function. This can often happen if the function fn passed to strategy.run() contains a nested @tf.function, and the nested @tf.function contains a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients), or if the function fn uses a control flow statement which contains a synchronization point in the body. Such behaviors are not yet supported. Instead, please avoid nested tf.functions or control flow statements that may potentially cross a synchronization boundary, for example, wrap the fn passed to strategy.run or the entire strategy.run inside a tf.function or move the control flow out of fn. If you are subclassing a tf.keras.Model, please avoid decorating overridden methods test_step and train_step in tf.function

2. tensorflow=2.17.1/2.16.2 and keras=3.5.0:

Issue: Training gets stuck after a few epochs and does not progress.

3. tensorflow=2.17.1/2.16.2 and keras=3.0.5:

Error Message:
UnimplementedError: We failed to lift variable creations out of this tf.function, so this tf.function cannot be run on XLA. A possible workaround is to move variable creation outside of the XLA compiled function.

4. tensorflow=2.18.0: Also gives similar errors with keras versions 3.0, 3.5, and 3.6.

5. Using TF_USE_LEGACY_KERAS=1:
Training gets stuck after some time, just as with keras>3. I have tried this with various TensorFlow versions but hit the same hang every time.

Code Snippet:
Here's a simplified version of the code I'm using. This example is adapted from the Keras documentation on distributed training:

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras

def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
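    # Enable AMP: the mixed_float16 policy runs compute in float16 with float32 variables.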
    keras.mixed_precision.set_global_policy("mixed_float16")
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-5),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
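        # jit_compile=True compiles the train/eval step functions with XLA.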
        jit_compile=True,
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset).
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )


# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()

    # Train the model on all available devices.
    train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=20, validation_data=val_dataset)

# Test the model on all available devices.
model.evaluate(test_dataset)

Can someone suggest what changes I should make in the code, or which versions of Keras and TensorFlow to use, to make training with multi-GPU + XLA + AMP work?
Or is it not possible to train using multi-GPU + XLA + AMP?

dhantule commented Jan 20, 2025

Hi @keshusharmamrt, thanks for reporting this.

I ran your code with Keras 3.8.0 and TensorFlow 2.17.1 and got the same error. However, I did not face any issues with Keras 3.5.0 and TensorFlow 2.17.1, nor with Keras 3.0.5 and TensorFlow 2.17.1. Attaching a gist for reference.

keshusharmamrt commented Jan 20, 2025

Hi @dhantule, thanks for your reply.
It seems you are using MirroredStrategy with only a single GPU (for keras 3.5.0 and keras 3.0.5) in Google Colab. 🤔
I modified the Colab code to use 2 devices:
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
and reached the following conclusions:

  1. keras 3.5.0: training is stuck forever.
  2. keras 3.0.5: I got the following error:

[error screenshot]

gist.

I have also tried running this code on our own clusters, where we have two GPUs, and I see similar behaviour there too. (I tried this with the official TensorFlow image tensorflow/tensorflow:2.17.0-gpu-jupyter and manually installed tensorflow==2.17.1 with keras==3.5.0 and 3.0.5.)

  1. Screenshots from the cluster with keras==3.5.0:

[screenshots]

  2. With keras==3.0.5:

[screenshot]

Interestingly, it works when we have a single GPU 😄 on Colab as well as on our cluster.

One more thing: I tried using the "JAX" backend for Keras, and it seems to work, following this tutorial.
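
Roughly what the JAX-backend version looks like (a minimal sketch based on the Keras 3 distribution API from that guide, simplified from what I actually ran):

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

# Data-parallel training over all visible GPUs (JAX backend only).
keras.distribution.set_distribution(keras.distribution.DataParallel())
keras.mixed_precision.set_global_policy("mixed_float16")  # AMP

inputs = keras.Input(shape=(784,))
x = keras.layers.Dense(256, activation="relu")(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    jit_compile=True,  # XLA (jit) compilation
)
# model.fit(...) / model.evaluate(...) as in the TensorFlow version, without MirroredStrategy.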

I was wondering: is it possible to use multi-GPU with XLA and AMP with the TensorFlow backend of Keras?

dhantule added the keras-team-review-pending label on Jan 22, 2025.
SamanehSaadat removed the keras-team-review-pending label on Jan 23, 2025.
sampathweb commented

> Issues Encountered with Different Versions:
> I use tf-keras=2.15 for all these tests,
>
>   1. tensorflow=2.17.1/2.16.2 and keras=3.8.0:
Please use the same version of tf-keras and tensorflow when using Keras 2. For example, tf-keras~=2.17.0 when using tensorflow~=2.17.0.

Then you can use it like this:

import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf
from tensorflow import keras  # To use Keras 2
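
With Keras 2, the AMP and XLA pieces map directly onto the tf.keras APIs. A minimal sketch of the model part under this setup, just to show where the flags go (not verified against the hang you are seeing):

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    keras.mixed_precision.set_global_policy("mixed_float16")  # AMP
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-5),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        jit_compile=True,  # XLA
    )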

Let me know if this works for you.

keshusharmamrt commented Jan 24, 2025

Hi @sampathweb, thanks for the reply.
I tried the approach you suggested (tf-keras~=2.17.0 and tensorflow~=2.17.0).
I still see an issue with it: training hangs indefinitely and does not progress. The same happened when I tried training on our clusters.
You can find the code here: gist.
I'm not sure if I am doing something wrong, or if it is just not possible to use XLA + AMP + multi-GPU training.
