diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml
index c3420d6b5..ae74da53e 100644
--- a/.github/workflows/pytest_cuda.yml
+++ b/.github/workflows/pytest_cuda.yml
@@ -54,7 +54,8 @@ jobs:
     runs-on: ${{ inputs.runner }}
     # TODO: Update to the generic ML ecosystem test containers when they are ready.
     container:  ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
-                (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
+                (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
+                (contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
     name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
 
     env:
diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml
index fc4321724..ecdf43b13 100644
--- a/.github/workflows/wheel_tests_continuous.yml
+++ b/.github/workflows/wheel_tests_continuous.yml
@@ -110,18 +110,30 @@ jobs:
         fail-fast: false # don't cancel all jobs on failure
         matrix:
           # Python values need to match the matrix stategy in the artifact build jobs above
-          runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
+          # See exclusions for what is fully tested
+          runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
           python: ["3.10",]
-          cuda: ["12.3", "12.1"]
+          cuda: ["12.1", "12.3", "12.8"]
           enable-x64: [1, 0]
           exclude:
-            # Run only a single configuration on H100 to save resources
+            # L4 does not run on CUDA 12.8 but tests all other configs
+            - runner: "linux-x86-g2-48-l4-4gpu"
+              cuda: "12.8"
+            # H100 runs only a single config: CUDA 12.3, enable-x64 = 1
+            - runner: "linux-x86-a3-8g-h100-8gpu"
+              cuda: "12.8"
             - runner: "linux-x86-a3-8g-h100-8gpu"
-              python: "3.10"
               cuda: "12.1"
             - runner: "linux-x86-a3-8g-h100-8gpu"
-              python: "3.10"
               enable-x64: 0
+            # B200 runs only a single config: CUDA 12.8, enable-x64 = 1 (enable-x64 is an integer, as in the matrix)
+            - runner: "linux-x86-a4-224-b200-1gpu"
+              enable-x64: 0
+            - runner: "linux-x86-a4-224-b200-1gpu"
+              cuda: "12.1"
+            - runner: "linux-x86-a4-224-b200-1gpu"
+              cuda: "12.3"
+
     name: "Pytest CUDA (JAX artifacts version =  ${{ format('{0}', 'head') }})"
     with:
       runner: ${{ matrix.runner }}