diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index e2c4c1d97..eda42fbb8 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -62,6 +62,9 @@ traceback_util.register_exclusion(__file__)
 
 XlaBackend = xla_client.Client
 
+# Minimum supported CUDA compute capability, encoded as major * 10 + minor
+# (i.e. 52 means compute capability 5.2).
+MIN_COMPUTE_CAPABILITY = 52
 
 # TODO(phawkins): Remove jax_xla_backend.
 _XLA_BACKEND = config.DEFINE_string(
@@ -252,6 +255,26 @@ register_backend_factory(
 )
 
 
+def _check_cuda_compute_capability(devices_to_check):
+  """Warns for each CUDA device below the minimum compute capability.
+
+  Args:
+    devices_to_check: iterable of CUDA device ordinals to query.
+  """
+  for idx in devices_to_check:
+    # Capabilities are encoded as major * 10 + minor, e.g. 52 for 5.2.
+    compute_cap = cuda_versions.cuda_compute_capability(idx)
+    if compute_cap < MIN_COMPUTE_CAPABILITY:
+      warnings.warn(
+        f"Device {idx} has CUDA compute capability {compute_cap/10} which is "
+        "lower than the minimum supported compute capability "
+        f"{MIN_COMPUTE_CAPABILITY/10}. See "
+        "https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for "
+        "more details",
+        RuntimeWarning
+      )
+
+
 def _check_cuda_versions():
   assert cuda_versions is not None
 
@@ -311,15 +334,22 @@ def make_gpu_client(
   if visible_devices != "all":
     allowed_devices = {int(x) for x in visible_devices.split(",")}
 
-  if platform_name == "cuda":
-    _check_cuda_versions()
-
   use_mock_gpu_client = _USE_MOCK_GPU_CLIENT.value
   num_nodes = (
       _MOCK_NUM_GPUS.value
       if use_mock_gpu_client
       else distributed.global_state.num_processes
   )
+  if platform_name == "cuda":
+    _check_cuda_versions()
+    # Probe only the explicitly allowed devices; otherwise every device the
+    # driver reports.
+    devices_to_check = (
+        allowed_devices
+        if allowed_devices
+        else range(cuda_versions.cuda_device_count())
+    )
+    _check_cuda_compute_capability(devices_to_check)
 
   return xla_client.make_gpu_client(
       distributed_client=distributed.global_state.client,
diff --git a/jaxlib/cuda/versions.cc b/jaxlib/cuda/versions.cc
index af6e0da22..8d6577f46 100644
--- a/jaxlib/cuda/versions.cc
+++ b/jaxlib/cuda/versions.cc
@@ -45,6 +45,8 @@ NB_MODULE(_versions, m) {
   m.def("cusolver_get_version", &CusolverGetVersion);
   m.def("cublas_get_version", &CublasGetVersion);
   m.def("cusparse_get_version", &CusparseGetVersion);
+  m.def("cuda_compute_capability", &CudaComputeCapability);
+  m.def("cuda_device_count", &CudaDeviceCount);
 }
 
 }  // namespace
diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc
index c9294131e..9ecd9a83c 100644
--- a/jaxlib/cuda/versions_helpers.cc
+++ b/jaxlib/cuda/versions_helpers.cc
@@ -84,5 +84,27 @@ size_t CudnnGetVersion() {
   }
   return version;
 }
+
+// Returns the compute capability of CUDA device `device`, encoded as
+// major * 10 + minor (e.g. 52 for compute capability 5.2).
+int CudaComputeCapability(int device) {
+  int major, minor;
+  // Initialize the driver here as well: CudaDeviceCount(), which also calls
+  // cuInit, is skipped when the caller passes an explicit device list.
+  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));
+  JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
+      &major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
+  JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
+      &minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
+  return major * 10 + minor;
+}
+
+// Returns the number of CUDA devices reported by the driver.
+int CudaDeviceCount() {
+  int device_count = 0;
+  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));
+  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count)));
+  return device_count;
+}
 
 }  // namespace jax::cuda
\ No newline at end of file
diff --git a/jaxlib/cuda/versions_helpers.h b/jaxlib/cuda/versions_helpers.h
index 01a8e9dc9..aa8c30976 100644
--- a/jaxlib/cuda/versions_helpers.h
+++ b/jaxlib/cuda/versions_helpers.h
@@ -29,6 +29,8 @@ int CusolverGetVersion();
 int CublasGetVersion();
 int CusparseGetVersion();
 size_t CudnnGetVersion();
+int CudaComputeCapability(int);
+int CudaDeviceCount();
 
 }  // namespace jax::cuda
 