Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 12:56:07 +00:00

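We need to set the number of concurrent test jobs to `min(num_cpu_cores, num_gpus * max_tests_per_gpu, total_ram_in_gb / 6)`, where `max_tests_per_gpu = gpu_memory_in_gb / 2`. In other words, each test is budgeted roughly 2 GB of GPU memory and 6 GB of host RAM, and no more tests run at once than there are CPU cores.

PiperOrigin-RevId: 731730857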
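A minimal Python sketch of the formula above. It assumes that all memory sizes are given in GB and that fractional results are rounded down. The function name `compute_max_test_jobs` and its parameters are hypothetical and do not come from JAX's CI scripts.

```python
def compute_max_test_jobs(
    num_cpu_cores: int,
    num_gpus: int,
    gpu_memory_gb: float,
    total_ram_gb: float,
) -> int:
    """Hypothetical helper: cap test parallelism by CPU, GPU memory and host RAM."""
    # Assumption: each test needs about 2 GB of GPU memory (rounded down).
    max_tests_per_gpu = int(gpu_memory_gb // 2)
    gpu_limit = num_gpus * max_tests_per_gpu
    # Assumption: each test needs about 6 GB of host RAM (rounded down).
    ram_limit = int(total_ram_gb // 6)
    # The tightest of the three resource limits wins.
    return min(num_cpu_cores, gpu_limit, ram_limit)


# 16 cores, 2 GPUs with 16 GB each, 64 GB RAM:
# GPU limit = 2 * 8 = 16, RAM limit = 64 // 6 = 10, so RAM is the bottleneck.
print(compute_max_test_jobs(16, 2, 16, 64))  # prints 10
```

On a real machine, `num_cpu_cores` could come from `os.cpu_count()`. Note that this sketch returns 0 when `num_gpus` is 0; a production script would likely clamp the result to at least 1.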