mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.
PiperOrigin-RevId: 641879836
This commit is contained in:
parent
991797a8a9
commit
cd93b46df4
@ -513,8 +513,6 @@ cc_library(
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:absl_status_casters",
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cudnn",
|
||||
@ -522,6 +520,7 @@ cc_library(
|
||||
"@xla//xla/tsl/cuda:cupti",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
"@xla//xla/tsl/cuda:cusparse",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "absl/base/dynamic_annotations.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
@ -30,39 +31,45 @@ namespace jax::cuda {
|
||||
int CudaRuntimeGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
|
||||
return version;
|
||||
}
|
||||
|
||||
int CudaDriverGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
|
||||
return version;
|
||||
}
|
||||
|
||||
uint32_t CuptiGetVersion() {
|
||||
uint32_t version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
|
||||
return version;
|
||||
}
|
||||
|
||||
int CufftGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
|
||||
return version;
|
||||
}
|
||||
|
||||
int CusolverGetVersion() {
|
||||
int version;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
|
||||
return version;
|
||||
}
|
||||
|
||||
int CublasGetVersion() {
|
||||
int version;
|
||||
// NVIDIA promise that it's safe to parse nullptr as the handle to this
|
||||
// NVIDIA promise that it's safe to pass a null pointer as the handle to this
|
||||
// function.
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
|
||||
return version;
|
||||
}
|
||||
|
||||
@ -73,6 +80,9 @@ int CusparseGetVersion() {
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&patch, sizeof patch);
|
||||
return major * 1000 + minor * 100 + patch;
|
||||
}
|
||||
size_t CudnnGetVersion() {
|
||||
@ -82,6 +92,7 @@ size_t CudnnGetVersion() {
|
||||
if (version == 0) {
|
||||
throw std::runtime_error("cuDNN not found.");
|
||||
}
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
|
||||
return version;
|
||||
}
|
||||
int CudaComputeCapability(int device) {
|
||||
@ -91,6 +102,8 @@ int CudaComputeCapability(int device) {
|
||||
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
|
||||
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor);
|
||||
return major * 10 + minor;
|
||||
}
|
||||
|
||||
@ -99,6 +112,7 @@ int CudaDeviceCount() {
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count)));
|
||||
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&device_count, sizeof device_count);
|
||||
return device_count;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user