Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.

PiperOrigin-RevId: 641879836
This commit is contained in:
Thomas Köppe 2024-06-10 06:20:21 -07:00 committed by jax authors
parent 991797a8a9
commit cd93b46df4
2 changed files with 16 additions and 3 deletions

View File

@ -513,8 +513,6 @@ cc_library(
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:absl_status_casters",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cudnn",
@ -522,6 +520,7 @@ cc_library(
"@xla//xla/tsl/cuda:cupti",
"@xla//xla/tsl/cuda:cusolver",
"@xla//xla/tsl/cuda:cusparse",
"@com_google_absl//absl/base:dynamic_annotations",
],
)

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <cstddef>
#include <stdexcept>
#include "absl/base/dynamic_annotations.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
@ -30,39 +31,45 @@ namespace jax::cuda {
int CudaRuntimeGetVersion() {
int version;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
return version;
}
int CudaDriverGetVersion() {
int version;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
return version;
}
uint32_t CuptiGetVersion() {
uint32_t version;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
return version;
}
int CufftGetVersion() {
int version;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
return version;
}
int CusolverGetVersion() {
int version;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
return version;
}
int CublasGetVersion() {
int version;
// NVIDIA promise that it's safe to parse nullptr as the handle to this
// NVIDIA promise that it's safe to pass a null pointer as the handle to this
// function.
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
return version;
}
@ -73,6 +80,9 @@ int CusparseGetVersion() {
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&patch, sizeof patch);
return major * 1000 + minor * 100 + patch;
}
size_t CudnnGetVersion() {
@ -82,6 +92,7 @@ size_t CudnnGetVersion() {
if (version == 0) {
throw std::runtime_error("cuDNN not found.");
}
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
return version;
}
int CudaComputeCapability(int device) {
@ -91,6 +102,8 @@ int CudaComputeCapability(int device) {
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor);
return major * 10 + minor;
}
@ -99,6 +112,7 @@ int CudaDeviceCount() {
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count)));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&device_count, sizeof device_count);
return device_count;
}