Use new ML Build CUDA images

Note the CUDA jobs fail in https://github.com/jax-ml/jax/actions/runs/13444028636/job/37565647540 but in the same way as on HEAD.

PiperOrigin-RevId: 729571786
This commit is contained in:
Nitin Srinivasan 2025-02-21 10:15:38 -08:00 committed by jax authors
parent a073639124
commit 5089fb01b2
2 changed files with 3 additions and 3 deletions

View File

@ -47,7 +47,7 @@ jobs:
# Explicitly set the shell to bash
shell: bash
runs-on: ${{ inputs.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
env:
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}

View File

@ -58,8 +58,8 @@ jobs:
shell: bash
runs-on: ${{ inputs.runner }}
# TODO: Update to the generic ML ecosystem test containers when they are ready.
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') ||
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }}
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
env: