From 45b871950e5a8355060244fc75989d4a67bbdc7a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 26 Aug 2024 17:03:27 -0700 Subject: [PATCH] Fix a number of minor problems in the ROCM build. Change in preparation for adding more presubmits for AMD ROCM. PiperOrigin-RevId: 667766343 --- jaxlib/cuda/BUILD | 2 + jaxlib/gpu/linalg_kernels.cc | 7 +-- jaxlib/gpu/linalg_kernels.cu.cc | 62 ++++++++++++---------- jaxlib/gpu/linalg_kernels.h | 10 ++-- jaxlib/gpu/solver_kernels.cc | 8 +-- jaxlib/gpu/solver_kernels_ffi.cc | 19 +++++-- jaxlib/gpu/sparse.cc | 84 ++++++++++++++++-------------- jaxlib/gpu/sparse_kernels.cc | 20 ++++--- jaxlib/gpu/sparse_kernels.h | 2 +- jaxlib/gpu/triton_kernels.cc | 22 ++++++-- jaxlib/gpu/vendor.h | 30 ++++++----- jaxlib/rocm/{BUILD.bazel => BUILD} | 13 ++++- jaxlib/rocm_plugin_extension.cc | 12 ++--- 13 files changed, 173 insertions(+), 118 deletions(-) rename jaxlib/rocm/{BUILD.bazel => BUILD} (96%) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index a7a47f431..72db9868e 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -295,6 +295,7 @@ cc_library( "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/service:custom_call_status", @@ -323,6 +324,7 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":cusparse_kernels", + "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index b22248409..039a9b5c1 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -40,7 +40,8 @@ absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers, auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); const CholeskyUpdateDescriptor& d = **s; - LaunchCholeskyUpdateKernel(stream, buffers, d); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(LaunchCholeskyUpdateKernel(stream, buffers, d))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); return absl::OkStatus(); } @@ -98,8 +99,8 @@ ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in, gpuMemcpyDeviceToDevice, stream))); } for (auto n = 0; n < batch; ++n) { - LaunchCholeskyUpdateFfiKernel(stream, matrix, vector, size, - is_single_precision); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(LaunchCholeskyUpdateFfiKernel( + stream, matrix, vector, size, is_single_precision))); FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); } return ffi::Error::Success(); diff --git a/jaxlib/gpu/linalg_kernels.cu.cc b/jaxlib/gpu/linalg_kernels.cu.cc index 50c653d8c..7f87d66fb 100644 --- a/jaxlib/gpu/linalg_kernels.cu.cc +++ b/jaxlib/gpu/linalg_kernels.cu.cc @@ -67,8 +67,9 @@ __global__ void CholeskyUpdateKernel(T* rMatrix, T* uVector, int nSize) { } // namespace template -void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, - int grid_dim, int block_dim, int nSize) { +gpuError_t LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, + int grid_dim, int block_dim, + int nSize) { T* rMatrix = reinterpret_cast(buffers[2]); T* uVector = reinterpret_cast(buffers[3]); @@ -77,39 +78,40 @@ void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, reinterpret_cast(&uVector), reinterpret_cast(&nSize), }; - gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, - block_dim, arg_ptrs, - /*dynamic_shared_mem_bytes=*/0, stream); + return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/0, stream); } -void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, - CholeskyUpdateDescriptor descriptor) { +gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, + CholeskyUpdateDescriptor descriptor) { int nSize = descriptor.matrix_size; LinalgType type = descriptor.linalg_type; int dev = 0; gpuDeviceProp deviceProp; - gpuGetDeviceProperties(&deviceProp, dev); + gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev); + if (err != gpuSuccess) { + return err; + } int block_dim = deviceProp.maxThreadsPerBlock; int grid_dim = deviceProp.multiProcessorCount; switch (type) { case LinalgType::F64: - LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, - block_dim, nSize); - break; + return LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, + block_dim, nSize); case LinalgType::F32: - LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, - block_dim, nSize); - break; + return LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, + block_dim, nSize); } } template -void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, - void* vector, int grid_dim, - int block_dim, int nSize) { +gpuError_t LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, + void* vector, int grid_dim, + int block_dim, int nSize) { T* rMatrix = reinterpret_cast(matrix); T* uVector = reinterpret_cast(vector); @@ -118,26 +120,30 @@ void LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, reinterpret_cast(&uVector), reinterpret_cast(&nSize), }; - gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, - block_dim, arg_ptrs, - /*dynamic_shared_mem_bytes=*/0, stream); + return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/0, stream); } -void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, - void* vector, int size, - bool is_single_precision) { +gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision) { int dev = 0; gpuDeviceProp deviceProp; - gpuGetDeviceProperties(&deviceProp, dev); + gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev); + if (err != gpuSuccess) { + return err; + } + int block_dim = deviceProp.maxThreadsPerBlock; int grid_dim = deviceProp.multiProcessorCount; if (is_single_precision) { - LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, grid_dim, - block_dim, size); + return LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, + grid_dim, block_dim, size); } else { - LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, grid_dim, - block_dim, size); + return LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, + grid_dim, block_dim, size); } } diff --git a/jaxlib/gpu/linalg_kernels.h b/jaxlib/gpu/linalg_kernels.h index 47ada398c..2c41b7f43 100644 --- a/jaxlib/gpu/linalg_kernels.h +++ b/jaxlib/gpu/linalg_kernels.h @@ -36,15 +36,15 @@ struct CholeskyUpdateDescriptor { std::int64_t matrix_size; // leading dim (N) for a square (NxN)matrix }; -void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, - CholeskyUpdateDescriptor descriptor); +gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, + CholeskyUpdateDescriptor descriptor); void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); -void LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, - void* vector, int size, - bool is_single_precision); +gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision); XLA_FFI_DECLARE_HANDLER_SYMBOL(CholeskyUpdateFfi); void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8d90c7053..8c22dfcdb 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" @@ -421,9 +421,9 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers, int output_idx = 1; // with static shapes buffers[1] is the first output if (d.batch == -1) { // the batch is passed as a second operand - gpuMemcpyAsync((void*)&batch, - reinterpret_cast(buffers[1]), - sizeof(batch), gpuMemcpyDeviceToHost, stream); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( + (void*)&batch, reinterpret_cast(buffers[1]), + sizeof(batch), gpuMemcpyDeviceToHost, stream))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); output_idx = 2; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 2b1f55529..6e988a6ca 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -61,6 +61,17 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, return impl(__VA_ARGS__); \ } +#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) \ + if (dataType == ffi::F32) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::F64) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::C64) { \ + return impl(__VA_ARGS__); \ + } else if (dataType == ffi::C128) { \ + return impl(__VA_ARGS__); \ + } + // LU decomposition: getrf namespace { @@ -189,8 +200,8 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ipiv->dimensions(), {batch, std::min(rows, cols)}, "ipiv", "getrf")); FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "getrf")); if (batch > 1 && rows == cols && rows / batch <= 128) { - SOLVER_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, out, - ipiv, info); + SOLVER_BLAS_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, + out, ipiv, info); } else { SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out, ipiv, info); @@ -345,8 +356,8 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, FFI_RETURN_IF_ERROR(CheckShape( tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqrf")); if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) { - SOLVER_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, scratch, - a, out, tau); + SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, + scratch, a, out, tau); } else { SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out, tau); diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index b9eb51388..2eeb94e30 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" @@ -57,14 +58,19 @@ gpusparseIndexType_t DtypeToCuSparseIndexType(const dtype& np_type) { gpuDataType DtypeToCudaDataType(const dtype& np_type) { static auto* types = new absl::flat_hash_map, gpuDataType>({ - {{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F}, - {{'c', 8}, GPU_C_32F}, {{'f', 8}, GPU_R_64F}, - {{'c', 16}, GPU_C_64F}, + {{'f', 2}, GPU_R_16F}, + {{'c', 4}, GPU_C_16F}, + {{'f', 4}, GPU_R_32F}, + {{'c', 8}, GPU_C_32F}, + {{'f', 8}, GPU_R_64F}, + {{'c', 16}, GPU_C_64F}, #ifdef JAX_GPU_CUDA - {{'i', 1}, CUDA_R_8I}, {{'u', 1}, CUDA_R_8U}, - {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U}, + {{'i', 1}, CUDA_R_8I}, + {{'u', 1}, CUDA_R_8U}, + {{'i', 4}, CUDA_R_32I}, + {{'u', 4}, CUDA_R_32U}, #if JAX_GPU_HAVE_SPARSE - {{'V', 2}, CUDA_R_16BF}, + {{'V', 2}, CUDA_R_16BF}, #endif // JAX_GPU_HAVE_SPARSE #endif // JAX_GPU_CUDA }); @@ -78,9 +84,8 @@ gpuDataType DtypeToCudaDataType(const dtype& np_type) { } // Returns the descriptor for a Sparse matrix. SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype, - const dtype& index_dtype, - int rows, int cols, int nnz, - int batch_count, + const dtype& index_dtype, int rows, + int cols, int nnz, int batch_count, int batch_stride) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype); @@ -89,16 +94,15 @@ SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype, } // Returns the descriptor for a Dense matrix. -DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, - int rows, int cols, int batch_count, +DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, int rows, + int cols, int batch_count, int batch_stride) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride}; } // Returns the descriptor for a Dense vector. -DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, - int size) { +DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, int size) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); return DenseVecDescriptor{value_type, size}; } @@ -107,9 +111,10 @@ DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, // CsrToDense: Convert CSR matrix to dense matrix // Returns the descriptor for a Sparse matrix. -std::pair BuildCsrToDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { +std::pair BuildCsrToDenseDescriptor(const dtype& data_dtype, + const dtype& index_dtype, + int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -185,8 +190,8 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, // Returns the descriptor for a CsrFromDense operation. std::pair BuildCsrFromDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -261,9 +266,8 @@ void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, // Returns the descriptor for a CsrMatvec operation. std::pair BuildCsrMatvecDescriptor( - const dtype& data_dtype, const dtype& x_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -292,7 +296,7 @@ std::pair BuildCsrMatvecDescriptor( JAX_THROW_IF_ERROR( JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - SparseConst alpha = ConstOne(y.type); + SparseConst alpha = ValueOrThrow(ConstOne(y.type)); SparseConst beta = ConstZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, @@ -309,9 +313,9 @@ std::pair BuildCsrMatvecDescriptor( // Returns the descriptor for a CsrMatmat operation. std::pair BuildCsrMatmatDescriptor( - const dtype& data_dtype, const dtype& b_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int BCcols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, + bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -344,7 +348,7 @@ std::pair BuildCsrMatmatDescriptor( JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, empty, C.type, GPUSPARSE_ORDER_ROW))); size_t buffer_size; - SparseConst alpha = ConstOne(C.type); + SparseConst alpha = ValueOrThrow(ConstOne(C.type)); SparseConst beta = ConstZero(C.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, @@ -360,9 +364,10 @@ std::pair BuildCsrMatmatDescriptor( // CooToDense: Convert COO matrix to dense matrix // Returns the descriptor for a CooToDense operation. -std::pair BuildCooToDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { +std::pair BuildCooToDenseDescriptor(const dtype& data_dtype, + const dtype& index_dtype, + int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -398,8 +403,8 @@ std::pair BuildCooToDenseDescriptor( // Returns the descriptor for a CooFromDense operation. std::pair BuildCooFromDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -434,9 +439,8 @@ std::pair BuildCooFromDenseDescriptor( // Returns the descriptor for a CooMatvec operation. std::pair BuildCooMatvecDescriptor( - const dtype& data_dtype, const dtype& x_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -465,7 +469,7 @@ std::pair BuildCooMatvecDescriptor( JAX_THROW_IF_ERROR( JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - SparseConst alpha = ConstOne(y.type); + SparseConst alpha = ValueOrThrow(ConstOne(y.type)); SparseConst beta = ConstZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, @@ -482,10 +486,10 @@ std::pair BuildCooMatvecDescriptor( // Returns the descriptor for a CooMatmat operation. std::pair BuildCooMatmatDescriptor( - const dtype& data_dtype, const dtype& b_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int BCcols, int nnz, bool transpose, int batch_count, - int lhs_batch_stride, int rhs_batch_stride) { + const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, + bool transpose, int batch_count, int lhs_batch_stride, + int rhs_batch_stride) { // Three batch modes are supported, C_i = A_i B, C_i = A B_i, and // Ci = A_i B_i, where `i` denotes the batch dimension. // All three matrices A, B, and C must have the same batch count. @@ -535,7 +539,7 @@ std::pair BuildCooMatmatDescriptor( JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch( mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride))); size_t buffer_size; - SparseConst alpha = ConstOne(C.type); + SparseConst alpha = ValueOrThrow(ConstOne(C.type)); SparseConst beta = ConstZero(C.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 93c6aef17..a44d4b331 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" @@ -58,7 +59,7 @@ SparseConst ConstZero(gpuDataType type) { return c; } -SparseConst ConstOne(gpuDataType type) { +absl::StatusOr ConstOne(gpuDataType type) { SparseConst c; std::memset(&c, 0, sizeof(c)); switch (type) { @@ -138,6 +139,9 @@ SparseConst ConstOne(gpuDataType type) { case GPU_C_64F: c.f64[0] = 1.0; break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported data type: ", type)); } return c; } @@ -248,7 +252,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.y.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type)); SparseConst beta = ConstZero(d.y.type); gpusparseSpMatDescr_t mat_a = 0; @@ -305,7 +309,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.C.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type)); SparseConst beta = ConstZero(d.C.type); gpusparseSpMatDescr_t mat_a = 0; @@ -446,7 +450,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.y.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type)); SparseConst beta = ConstZero(d.y.type); gpusparseSpMatDescr_t mat_a = 0; @@ -502,7 +506,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.C.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type)); SparseConst beta = ConstZero(d.C.type); gpusparseSpMatDescr_t mat_a = 0; @@ -567,7 +571,7 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, T* du = static_cast(buffers[2]); T* B = static_cast(buffers[3]); T* X = static_cast(buffers[4]); - void* buffer = static_cast(buffers[5]); + void* buffer = static_cast(buffers[5]); // The solution X is written in place to B. We need to therefore copy the // contents of B into the output buffer X and pass that into the kernel as B. @@ -581,8 +585,8 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream))); } for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(computeGtsv2( - handle.get(), m, n, dl, d, du, X, ldb, buffer))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer))); dl += m; d += m; du += m; diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 2180767b0..48433b3d6 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -51,7 +51,7 @@ union SparseConst { }; SparseConst ConstZero(gpuDataType type); -SparseConst ConstOne(gpuDataType type); +absl::StatusOr ConstOne(gpuDataType type); struct SparseMatDescriptor { gpuDataType value_type; diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index c96c6b5c5..c4a9af5ff 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -34,7 +34,11 @@ #ifdef JAX_GPU_CUDA #include "xla/stream_executor/cuda/cuda_asm_compiler.h" -#endif +#endif // JAX_GPU_CUDA + +#ifdef JAX_GPU_HIP +#include "tsl/platform/env.h" +#endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) @@ -44,7 +48,12 @@ namespace { constexpr float kBenchmarkTimeMillis = 10.; struct gpuModuleDeleter { - void operator()(gpuModule_t module) { gpuModuleUnload(module); } + void operator()(gpuModule_t module) { + absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module)); + if (!status.ok()) { + LOG(WARNING) << "Failed to unload GPU module: " << status; + } + } }; using OwnedGPUmodule = @@ -52,11 +61,11 @@ using OwnedGPUmodule = absl::StatusOr GetStreamDevice(gpuStream_t stream) { gpuDevice_t device; - gpuContext_t context; #ifdef JAX_GPU_HIP int device_id = gpuGetStreamDeviceId(stream); GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id)); #else // JAX_GPU_CUDA + gpuContext_t context; GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context)); GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; @@ -210,7 +219,12 @@ class ModuleImage { } GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); - absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; + absl::Cleanup ctx_restorer = [] { + absl::Status status = JAX_AS_STATUS(gpuCtxPopCurrent(nullptr)); + if (!status.ok()) { + LOG(WARNING) << "Failed to pop GPU context: " << status; + } + }; gpuModule_t module; GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data())); diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index ef635bebd..077d3bb54 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -22,18 +22,20 @@ limitations under the License. #if defined(JAX_GPU_CUDA) -#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cooperative_groups.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda_fp8.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export -#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export +// IWYU pragma: begin_exports +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" +#include "third_party/gpus/cuda/include/cooperative_groups.h" +#include "third_party/gpus/cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_fp8.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/cufft.h" +#include "third_party/gpus/cuda/include/cusolverDn.h" +#include "third_party/gpus/cuda/include/cusolver_common.h" +#include "third_party/gpus/cuda/include/cusparse.h" +#include "third_party/gpus/cudnn/cudnn.h" +// IWYU pragma: end_exports #if CUDA_VERSION < 11080 #error "JAX requires CUDA 11.8 or newer." @@ -305,11 +307,13 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) -#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h" +// IWYU pragma: begin_exports +#include "rocm/include/hip/hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" #include "rocm/include/hipblas/hipblas.h" #include "rocm/include/hipsolver/hipsolver.h" #include "rocm/include/hipsparse/hipsparse.h" +// IWYU pragma: end_exports #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD similarity index 96% rename from jaxlib/rocm/BUILD.bazel rename to jaxlib/rocm/BUILD index 58dfd076d..ce856ae5f 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD @@ -14,6 +14,7 @@ # AMD HIP kernels +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_rocm_is_configured", @@ -23,7 +24,10 @@ load( licenses(["notice"]) -package(default_visibility = ["//:__subpackages__"]) +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) cc_library( name = "hip_vendor", @@ -203,6 +207,7 @@ pybind_extension( ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", @@ -223,6 +228,7 @@ cc_library( "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", @@ -243,6 +249,7 @@ pybind_extension( ":hip_gpu_kernel_helpers", ":hip_vendor", ":hipsparse_kernels", + "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -277,6 +284,7 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -376,14 +384,15 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/gpu:asm_compiler", "@xla//xla/tsl/util:env_var", ], ) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc index dde4e57a9..8a732380d 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm_plugin_extension.cc @@ -35,8 +35,8 @@ namespace nb = nanobind; namespace xla { namespace { absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, - nb::capsule fn, int api_version, - XLA_FFI_Handler_Traits traits) { + nb::capsule fn, int api_version, + XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { return Unimplemented("The plugin does not have extension."); } @@ -139,11 +139,11 @@ NB_MODULE(rocm_plugin_extension, m) { void* data_ptr = reinterpret_cast(data_value); hipError_t result = hipPointerGetAttribute(static_cast(&device_ordinal), - HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, - reinterpret_cast(data_ptr)); + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(data_ptr)); if (result != hipSuccess) { - LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr - << ". Error: " << ToString(result); + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " + << data_ptr << ". Error: " << ToString(result); } return device_ordinal; },