Skip to content

Commit

Permalink
update jax version (#1888)
Browse files Browse the repository at this point in the history
* update jax version

* remove keras 2 from gpu tests

* update comment
  • Loading branch information
divyashreepathihalli authored Sep 27, 2024
1 parent 876449e commit d7b880d
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 79 deletions.
11 changes: 2 additions & 9 deletions .kokoro/github/ubuntu/gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,7 @@ nvcc --version
cd "src/github/keras-hub"
pip install -U pip setuptools psutil

if [ "${KERAS2:-0}" == "1" ]
then
echo "Keras2 detected."
pip install -r requirements-common.txt --progress-bar off --timeout 1000
pip install tensorflow-text==2.15 tensorflow[and-cuda]~=2.15 keras-core \
--timeout 1000

elif [ "$KERAS_BACKEND" == "tensorflow" ]
if [ "$KERAS_BACKEND" == "tensorflow" ]
then
echo "TensorFlow backend detected."
pip install -r requirements-tensorflow-cuda.txt --progress-bar off \
Expand Down Expand Up @@ -67,4 +60,4 @@ then
else
pytest keras_hub --check_gpu --run_large \
--cov=keras-hub
fi
fi
34 changes: 0 additions & 34 deletions .kokoro/github/ubuntu/gpu/keras2/continuous.cfg

This file was deleted.

34 changes: 0 additions & 34 deletions .kokoro/github/ubuntu/gpu/keras2/presubmit.cfg

This file was deleted.

4 changes: 2 additions & 2 deletions requirements-jax-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ torch>=2.1.0
torchvision>=0.16.0

# Jax with cuda support.
# TODO: 0.4.24 has an updated Cuda version breaks Jax CI.
# Keep same version as Keras repo.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]==0.4.23
jax[cuda12_pip]==0.4.28

-r requirements-common.txt

0 comments on commit d7b880d

Please sign in to comment.