This commit is contained in:
Meekail Zain 2024-03-03 19:57:26 +00:00
parent 0b70244b1c
commit 9fff9aeb69
4 changed files with 39 additions and 3 deletions

View File

@ -62,6 +62,7 @@ traceback_util.register_exclusion(__file__)
XlaBackend = xla_client.Client
MIN_COMPUTE_CAPABILITY = 52
# TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = config.DEFINE_string(
@ -252,6 +253,19 @@ register_backend_factory(
)
def _check_cuda_compute_capability(devices_to_check):
for idx in devices_to_check:
compute_cap = cuda_versions.cuda_compute_capability(idx)
if compute_cap < MIN_COMPUTE_CAPABILITY:
warnings.warn(
f"Device {idx} has CUDA compute capability {compute_cap/10} which is "
"lower than the minimum supported compute capability "
f"{MIN_COMPUTE_CAPABILITY/10}. See "
"https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for "
"more details",
RuntimeWarning
)
def _check_cuda_versions():
assert cuda_versions is not None
@ -311,15 +325,16 @@ def make_gpu_client(
if visible_devices != "all":
allowed_devices = {int(x) for x in visible_devices.split(",")}
if platform_name == "cuda":
_check_cuda_versions()
use_mock_gpu_client = _USE_MOCK_GPU_CLIENT.value
num_nodes = (
_MOCK_NUM_GPUS.value
if use_mock_gpu_client
else distributed.global_state.num_processes
)
if platform_name == "cuda":
_check_cuda_versions()
devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count())
_check_cuda_compute_capability(devices_to_check)
return xla_client.make_gpu_client(
distributed_client=distributed.global_state.client,

View File

@ -45,6 +45,8 @@ NB_MODULE(_versions, m) {
m.def("cusolver_get_version", &CusolverGetVersion);
m.def("cublas_get_version", &CublasGetVersion);
m.def("cusparse_get_version", &CusparseGetVersion);
m.def("cuda_compute_capability", &CudaComputeCapability);
m.def("cuda_device_count", &CudaDeviceCount);
}
} // namespace

View File

@ -84,5 +84,22 @@ size_t CudnnGetVersion() {
}
return version;
}
int CudaComputeCapability(int device) {
int major, minor;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
return major * 10 + minor;
}
int CudaDeviceCount() {
int device_count = 0;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count)));
return device_count;
}
} // namespace jax::cuda

View File

@ -29,6 +29,8 @@ int CusolverGetVersion();
int CublasGetVersion();
int CusparseGetVersion();
size_t CudnnGetVersion();
int CudaComputeCapability(int);
int CudaDeviceCount();
} // namespace jax::cuda