mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add minimum version checks for cublas and cusparse.
Split code to determine CUDA library versions out of py_extension() module and into a cc_library(), because it fixes a linking problem in Google's build. (Long story, not worth it.) Fixes https://github.com/google/jax/issues/8289 PiperOrigin-RevId: 583544218
This commit is contained in:
parent
df9dd53c16
commit
41f0b336e3
@ -216,10 +216,7 @@ register_backend_factory(
|
||||
|
||||
|
||||
def _check_cuda_versions():
|
||||
# TODO(phawkins): remove the test for None cuda_versions after jaxlib 0.4.17
|
||||
# is the minimum.
|
||||
if cuda_versions is None:
|
||||
return
|
||||
assert cuda_versions is not None
|
||||
|
||||
def _version_check(name, get_version, get_build_version,
|
||||
scale_for_comparison=1):
|
||||
@ -256,9 +253,14 @@ def _check_cuda_versions():
|
||||
scale_for_comparison=100)
|
||||
_version_check("cuPTI", cuda_versions.cupti_get_version,
|
||||
cuda_versions.cupti_build_version)
|
||||
# TODO(phawkins): ideally we'd check cublas and cusparse here also, but their
|
||||
# "get version" APIs require initializing those libraries, which we don't want
|
||||
# to do here.
|
||||
_version_check("cuBLAS", cuda_versions.cublas_get_version,
|
||||
cuda_versions.cublas_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
|
||||
cuda_versions.cusparse_build_version,
|
||||
# Ignore patch versions.
|
||||
scale_for_comparison=100)
|
||||
|
||||
|
||||
def make_gpu_client(
|
||||
|
@ -458,6 +458,30 @@ pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "versions_helpers",
|
||||
srcs = ["versions_helpers.cc"],
|
||||
hdrs = ["versions_helpers.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:absl_status_casters",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@tsl//tsl/cuda:cublas",
|
||||
"@tsl//tsl/cuda:cudart",
|
||||
"@tsl//tsl/cuda:cudnn",
|
||||
"@tsl//tsl/cuda:cufft",
|
||||
"@tsl//tsl/cuda:cupti",
|
||||
"@tsl//tsl/cuda:cusolver",
|
||||
"@tsl//tsl/cuda:cusparse",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_versions",
|
||||
srcs = ["versions.cc"],
|
||||
@ -482,6 +506,7 @@ pybind_extension(
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
":versions_helpers",
|
||||
"//jaxlib:absl_status_casters",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@tsl//tsl/cuda:cublas",
|
||||
|
@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/versions_helpers.h"
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
namespace jax::cuda {
|
||||
@ -22,40 +23,6 @@ namespace {
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
#if CUDA_VERSION < 11080
|
||||
#error "JAX requires CUDA 11.8 or newer."
|
||||
#endif // CUDA_VERSION < 11080
|
||||
|
||||
int CudaRuntimeGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CudaDriverGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
uint32_t CuptiGetVersion() {
|
||||
uint32_t version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CufftGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CusolverGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
NB_MODULE(_versions, m) {
|
||||
// Nanobind's leak checking sometimes returns false positives for this file.
|
||||
// The problem appears related to forming a closure of a nanobind function.
|
||||
@ -70,14 +37,14 @@ NB_MODULE(_versions, m) {
|
||||
m.def("cusolver_build_version", []() { return CUSOLVER_VERSION; });
|
||||
m.def("cusparse_build_version", []() { return CUSPARSE_VERSION; });
|
||||
|
||||
// TODO(phawkins): annoyingly cublas and cusparse have "get version" APIs that
|
||||
// require the library to be initialized.
|
||||
m.def("cuda_runtime_get_version", &CudaRuntimeGetVersion);
|
||||
m.def("cuda_driver_get_version", &CudaDriverGetVersion);
|
||||
m.def("cudnn_get_version", &cudnnGetVersion);
|
||||
m.def("cupti_get_version", &CuptiGetVersion);
|
||||
m.def("cufft_get_version", &CufftGetVersion);
|
||||
m.def("cusolver_get_version", &CusolverGetVersion);
|
||||
m.def("cublas_get_version", &CublasGetVersion);
|
||||
m.def("cusparse_get_version", &CusparseGetVersion);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
76
jaxlib/cuda/versions_helpers.cc
Normal file
76
jaxlib/cuda/versions_helpers.cc
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2023 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/versions_helpers.h"
|
||||
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
namespace jax::cuda {
|
||||
|
||||
#if CUDA_VERSION < 11080
|
||||
#error "JAX requires CUDA 11.8 or newer."
|
||||
#endif // CUDA_VERSION < 11080
|
||||
|
||||
int CudaRuntimeGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CudaDriverGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
uint32_t CuptiGetVersion() {
|
||||
uint32_t version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CufftGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CusolverGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CublasGetVersion() {
|
||||
int version;
|
||||
// NVIDIA promise that it's safe to parse nullptr as the handle to this
|
||||
// function.
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version)));
|
||||
return version;
|
||||
}
|
||||
|
||||
int CusparseGetVersion() {
|
||||
// cusparseGetVersion is unhappy if passed a null library handle. But
|
||||
// cusparseGetProperty doesn't require one.
|
||||
int major, minor, patch;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch)));
|
||||
return major * 1000 + minor * 100 + patch;
|
||||
}
|
||||
|
||||
} // namespace jax::cuda
|
33
jaxlib/cuda/versions_helpers.h
Normal file
33
jaxlib/cuda/versions_helpers.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2023 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_CUDA_VERSIONS_HELPERS_H_
|
||||
#define JAXLIB_CUDA_VERSIONS_HELPERS_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace jax::cuda {
|
||||
|
||||
int CudaRuntimeGetVersion();
|
||||
int CudaDriverGetVersion();
|
||||
uint32_t CuptiGetVersion();
|
||||
int CufftGetVersion();
|
||||
int CusolverGetVersion();
|
||||
int CublasGetVersion();
|
||||
int CusparseGetVersion();
|
||||
|
||||
} // namespace jax::cuda
|
||||
|
||||
#endif // JAXLIB_CUDA_VERSIONS_HELPERS_H_
|
Loading…
x
Reference in New Issue
Block a user