cuInit before querying compute capability

This commit is contained in:
Olli Lupton 2024-04-04 15:27:57 +00:00
parent 498e81ab10
commit c97d955771

View File

@ -86,6 +86,7 @@ size_t CudnnGetVersion() {
}
int CudaComputeCapability(int device) {
int major, minor;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuInit(0)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
@ -102,4 +103,4 @@ int CudaDeviceCount() {
}
} // namespace jax::cuda
} // namespace jax::cuda