Add a compile-time version test that verifies CUDA is version 11.8 or newer.

Issue https://github.com/google/jax/issues/17829

PiperOrigin-RevId: 569302585
This commit is contained in:
Peter Hawkins 2023-09-28 15:10:08 -07:00 committed by jax authors
parent 528b035ee5
commit 2eca5b34b3

View File

@ -22,6 +22,10 @@ namespace {
namespace nb = nanobind;
#if CUDA_VERSION < 11080
#error "JAX requires CUDA 11.8 or newer."
#endif // CUDA_VERSION < 11080
int CudaRuntimeGetVersion() {
int version;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version)));