mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Update
This commit is contained in:
parent
0b70244b1c
commit
9fff9aeb69
@ -62,6 +62,7 @@ traceback_util.register_exclusion(__file__)
|
||||
|
||||
XlaBackend = xla_client.Client
|
||||
|
||||
MIN_COMPUTE_CAPABILITY = 52
|
||||
|
||||
# TODO(phawkins): Remove jax_xla_backend.
|
||||
_XLA_BACKEND = config.DEFINE_string(
|
||||
@ -252,6 +253,19 @@ register_backend_factory(
|
||||
)
|
||||
|
||||
|
||||
def _check_cuda_compute_capability(devices_to_check):
|
||||
for idx in devices_to_check:
|
||||
compute_cap = cuda_versions.cuda_compute_capability(idx)
|
||||
if compute_cap < MIN_COMPUTE_CAPABILITY:
|
||||
warnings.warn(
|
||||
f"Device {idx} has CUDA compute capability {compute_cap/10} which is "
|
||||
"lower than the minimum supported compute capability "
|
||||
f"{MIN_COMPUTE_CAPABILITY/10}. See "
|
||||
"https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for "
|
||||
"more details",
|
||||
RuntimeWarning
|
||||
)
|
||||
|
||||
def _check_cuda_versions():
|
||||
assert cuda_versions is not None
|
||||
|
||||
@ -311,15 +325,16 @@ def make_gpu_client(
|
||||
if visible_devices != "all":
|
||||
allowed_devices = {int(x) for x in visible_devices.split(",")}
|
||||
|
||||
if platform_name == "cuda":
|
||||
_check_cuda_versions()
|
||||
|
||||
use_mock_gpu_client = _USE_MOCK_GPU_CLIENT.value
|
||||
num_nodes = (
|
||||
_MOCK_NUM_GPUS.value
|
||||
if use_mock_gpu_client
|
||||
else distributed.global_state.num_processes
|
||||
)
|
||||
if platform_name == "cuda":
|
||||
_check_cuda_versions()
|
||||
devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count())
|
||||
_check_cuda_compute_capability(devices_to_check)
|
||||
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
|
@ -45,6 +45,8 @@ NB_MODULE(_versions, m) {
|
||||
m.def("cusolver_get_version", &CusolverGetVersion);
|
||||
m.def("cublas_get_version", &CublasGetVersion);
|
||||
m.def("cusparse_get_version", &CusparseGetVersion);
|
||||
m.def("cuda_compute_capability", &CudaComputeCapability);
|
||||
m.def("cuda_device_count", &CudaDeviceCount);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -84,5 +84,22 @@ size_t CudnnGetVersion() {
|
||||
}
|
||||
return version;
|
||||
}
|
||||
int CudaComputeCapability(int device) {
|
||||
int major, minor;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
|
||||
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
|
||||
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
|
||||
return major * 10 + minor;
|
||||
}
|
||||
|
||||
int CudaDeviceCount() {
|
||||
int device_count = 0;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count)));
|
||||
|
||||
return device_count;
|
||||
}
|
||||
|
||||
|
||||
} // namespace jax::cuda
|
@ -29,6 +29,8 @@ int CusolverGetVersion();
|
||||
int CublasGetVersion();
|
||||
int CusparseGetVersion();
|
||||
size_t CudnnGetVersion();
|
||||
int CudaComputeCapability(int);
|
||||
int CudaDeviceCount();
|
||||
|
||||
} // namespace jax::cuda
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user