mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge pull request #27164 from MichaelHudgins:a4-testing
PiperOrigin-RevId: 737733904
This commit is contained in:
commit
b74b16f9b9
3
.github/workflows/pytest_cuda.yml
vendored
3
.github/workflows/pytest_cuda.yml
vendored
@ -54,7 +54,8 @@ jobs:
|
|||||||
runs-on: ${{ inputs.runner }}
|
runs-on: ${{ inputs.runner }}
|
||||||
# TODO: Update to the generic ML ecosystem test containers when they are ready.
|
# TODO: Update to the generic ML ecosystem test containers when they are ready.
|
||||||
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
|
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
|
||||||
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
|
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
|
||||||
|
(contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
|
||||||
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
|
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
|
||||||
|
|
||||||
env:
|
env:
|
||||||
|
24
.github/workflows/wheel_tests_continuous.yml
vendored
24
.github/workflows/wheel_tests_continuous.yml
vendored
@ -110,18 +110,30 @@ jobs:
|
|||||||
fail-fast: false # don't cancel all jobs on failure
|
fail-fast: false # don't cancel all jobs on failure
|
||||||
matrix:
|
matrix:
|
||||||
# Python values need to match the matrix stategy in the artifact build jobs above
|
# Python values need to match the matrix stategy in the artifact build jobs above
|
||||||
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
|
# See exlusions for what is fully tested
|
||||||
|
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"]
|
||||||
python: ["3.10",]
|
python: ["3.10",]
|
||||||
cuda: ["12.3", "12.1"]
|
cuda: ["12.1","12.3","12.8"]
|
||||||
enable-x64: [1, 0]
|
enable-x64: [1, 0]
|
||||||
exclude:
|
exclude:
|
||||||
# Run only a single configuration on H100 to save resources
|
# L4 does not run on cuda 12.8 but tests other configs
|
||||||
|
- runner: "linux-x86-g2-48-l4-4gpu"
|
||||||
|
cuda: "12.8"
|
||||||
|
# H100 runs only a single config, CUDA 12.3 Enable x64 1
|
||||||
|
- runner: "linux-x86-a3-8g-h100-8gpu"
|
||||||
|
cuda: "12.8"
|
||||||
- runner: "linux-x86-a3-8g-h100-8gpu"
|
- runner: "linux-x86-a3-8g-h100-8gpu"
|
||||||
python: "3.10"
|
|
||||||
cuda: "12.1"
|
cuda: "12.1"
|
||||||
- runner: "linux-x86-a3-8g-h100-8gpu"
|
- runner: "linux-x86-a3-8g-h100-8gpu"
|
||||||
python: "3.10"
|
enable-x64: "0"
|
||||||
enable-x64: 0
|
# B200 runs only a single config, CUDA 12.8 Enable x64 1
|
||||||
|
- runner: "linux-x86-a4-224-b200-1gpu"
|
||||||
|
enable-x64: "0"
|
||||||
|
- runner: "linux-x86-a4-224-b200-1gpu"
|
||||||
|
cuda: "12.1"
|
||||||
|
- runner: "linux-x86-a4-224-b200-1gpu"
|
||||||
|
cuda: "12.3"
|
||||||
|
|
||||||
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
|
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
|
||||||
with:
|
with:
|
||||||
runner: ${{ matrix.runner }}
|
runner: ${{ matrix.runner }}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user