diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 717e91ca4..037704462 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -513,8 +513,6 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:absl_status_casters", - "//jaxlib:kernel_nanobind_helpers", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -522,6 +520,7 @@ cc_library( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", + "@com_google_absl//absl/base:dynamic_annotations", ], ) diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index e517b8c4f..d42199d37 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/base/dynamic_annotations.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" @@ -30,39 +31,45 @@ namespace jax::cuda { int CudaRuntimeGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CudaDriverGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } uint32_t CuptiGetVersion() { uint32_t version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CufftGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CusolverGetVersion() { int version; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CublasGetVersion() { int version; - // NVIDIA promise that it's safe to parse nullptr as the handle to this + // NVIDIA promise that it's safe to pass a null pointer as the handle to this // function. JAX_THROW_IF_ERROR( JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } @@ -73,6 +80,9 @@ int CusparseGetVersion() { JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&patch, sizeof patch); return major * 1000 + minor * 100 + patch; } size_t CudnnGetVersion() { @@ -82,6 +92,7 @@ size_t CudnnGetVersion() { if (version == 0) { throw std::runtime_error("cuDNN not found."); } + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version); return version; } int CudaComputeCapability(int device) { @@ -91,6 +102,8 @@ int CudaComputeCapability(int device) { &major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute( &minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor); return major * 10 + minor; } @@ -99,6 +112,7 @@ int CudaDeviceCount() { JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0))); JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count))); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&device_count, sizeof device_count); return device_count; }