mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
cuInit before querying compute capability
This commit is contained in:
parent
498e81ab10
commit
c97d955771
@ -86,6 +86,7 @@ size_t CudnnGetVersion() {
|
||||
}
|
||||
int CudaComputeCapability(int device) {
|
||||
int major, minor;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuInit(0)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
|
||||
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
|
||||
@ -102,4 +103,4 @@ int CudaDeviceCount() {
|
||||
}
|
||||
|
||||
|
||||
} // namespace jax::cuda
|
||||
} // namespace jax::cuda
|
||||
|
Loading…
x
Reference in New Issue
Block a user