From c3689699555d63cff39e86b51885fbfd89148915 Mon Sep 17 00:00:00 2001 From: Aden Grue Date: Mon, 9 Aug 2021 15:06:12 -0700 Subject: [PATCH] Use the new "custom call status" facility to report errors in jaxlib PiperOrigin-RevId: 389734200 --- jaxlib/BUILD | 15 +- jaxlib/cublas.cc | 105 +++-- jaxlib/cuda_gpu_kernel_helpers.cc | 64 +-- jaxlib/cuda_gpu_kernel_helpers.h | 43 +- jaxlib/cuda_linalg.py | 4 +- jaxlib/cuda_lu_pivot_kernels.cu.cc | 25 +- jaxlib/cuda_lu_pivot_kernels.h | 4 +- jaxlib/cuda_prng.py | 4 +- jaxlib/cuda_prng_kernels.cu.cc | 25 +- jaxlib/cuda_prng_kernels.h | 3 +- jaxlib/cusolver.cc | 696 +++++++++++++++++------------ jaxlib/cusolver.py | 36 +- jaxlib/cusparse.cc | 659 ++++++++++++++++----------- jaxlib/cusparse.py | 20 +- jaxlib/handle_pool.h | 3 +- jaxlib/kernel_helpers.h | 6 +- jaxlib/rocblas.cc | 408 ++++++++++------- jaxlib/rocm_gpu_kernel_helpers.cc | 16 +- jaxlib/rocm_gpu_kernel_helpers.h | 18 +- jaxlib/rocsolver.py | 28 +- 20 files changed, 1359 insertions(+), 823 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 636c163b0..67ad43d92 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -54,6 +54,7 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base", + "@com_google_absl//absl/status:statusor", ], ) @@ -67,6 +68,7 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", ], ) @@ -82,8 +84,9 @@ cc_library( deps = [ "@org_tensorflow//tensorflow/stream_executor/cuda:cusolver_lib", "@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib", - "@com_google_absl//absl/base", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cublas_headers", @@ -102,6 +105,8 @@ cc_library( deps = [ "@com_google_absl//absl/base", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_config_rocm//rocm:rocm_headers", ], @@ -159,6 +164,7 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":handle_pool", ":kernel_pybind11_helpers", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", "@org_tensorflow//tensorflow/stream_executor/cuda:cublas_lib", "@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub", "@com_google_absl//absl/algorithm:container", @@ -189,6 +195,7 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":handle_pool", ":kernel_pybind11_helpers", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", "@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub", "@org_tensorflow//tensorflow/stream_executor/cuda:cusolver_lib", "@com_google_absl//absl/algorithm:container", @@ -218,6 +225,7 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":handle_pool", ":kernel_pybind11_helpers", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", "@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub", "@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib", "@com_google_absl//absl/algorithm:container", @@ -241,6 +249,8 @@ cuda_library( deps = [ ":cuda_gpu_kernel_helpers", ":kernel_helpers", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@com_google_absl//absl/status", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -270,6 +280,8 @@ cuda_library( deps = [ ":cuda_gpu_kernel_helpers", ":kernel_helpers", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + "@com_google_absl//absl/status", "@local_config_cuda//cuda:cuda_headers", ], ) @@ -306,6 +318,7 @@ pybind_extension( ":handle_pool", ":kernel_pybind11_helpers", ":rocm_gpu_kernel_helpers", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", diff --git a/jaxlib/cublas.cc b/jaxlib/cublas.cc index 0e1ce8cf0..927a8abdc 100644 --- a/jaxlib/cublas.cc +++ b/jaxlib/cublas.cc @@ -32,6 +32,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { namespace { @@ -41,18 +42,19 @@ namespace py = pybind11; using BlasHandlePool = HandlePool; template <> -/*static*/ BlasHandlePool::Handle BlasHandlePool::Borrow(cudaStream_t stream) { +/*static*/ absl::StatusOr BlasHandlePool::Borrow( + cudaStream_t stream) { BlasHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); cublasHandle_t handle; if (pool->handles_[stream].empty()) { - JAX_THROW_IF_ERROR(cublasCreate(&handle)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCreate(&handle))); } else { handle = pool->handles_[stream].back(); pool->handles_[stream].pop_back(); } if (stream) { - JAX_THROW_IF_ERROR(cublasSetStream(handle, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSetStream(handle, stream))); } return Handle(pool, handle, stream); } @@ -122,26 +124,31 @@ std::pair BuildTrsmBatchedDescriptor( return {size, PackDescriptor(desc)}; } -void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const TrsmBatchedDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = BlasHandlePool::Borrow(stream); +absl::Status TrsmBatched_(cudaStream_t stream, void** buffers, + const char* opaque, size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const TrsmBatchedDescriptor& d = **s; + auto h = BlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; if (buffers[2] != buffers[1]) { - JAX_THROW_IF_ERROR(cudaMemcpyAsync(buffers[2], buffers[1], - SizeOfType(d.type) * d.batch * d.m * d.n, - cudaMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + buffers[2], buffers[1], SizeOfType(d.type) * d.batch * d.m * d.n, + cudaMemcpyDeviceToDevice, stream))); } const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n; const int ldb = d.m; auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch, SizeOfType(d.type) * lda * lda); + JAX_RETURN_IF_ERROR(a_batch_host.status()); auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, SizeOfType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(b_batch_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - JAX_THROW_IF_ERROR(cudaStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream))); switch (d.type) { case Type::F32: { float* a = static_cast(buffers[0]); @@ -150,10 +157,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, float** b_batch_ptrs = static_cast(buffers[4]); // NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault. const float alpha = 1.0f; - JAX_THROW_IF_ERROR(cublasStrsmBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasStrsmBatched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch)); + d.batch))); break; } case Type::F64: { @@ -162,10 +169,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, double** a_batch_ptrs = static_cast(buffers[3]); double** b_batch_ptrs = static_cast(buffers[4]); const double alpha = 1.0; - JAX_THROW_IF_ERROR(cublasDtrsmBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDtrsmBatched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch)); + d.batch))); break; } case Type::C64: { @@ -174,10 +181,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, cuComplex** a_batch_ptrs = static_cast(buffers[3]); cuComplex** b_batch_ptrs = static_cast(buffers[4]); const cuComplex alpha = make_cuComplex(1.0f, 0.0f); - JAX_THROW_IF_ERROR(cublasCtrsmBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCtrsmBatched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch)); + d.batch))); break; } case Type::C128: { @@ -188,13 +195,23 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, cuDoubleComplex** b_batch_ptrs = static_cast(buffers[4]); const cuDoubleComplex alpha = make_cuDoubleComplex(1.0f, 0.0f); - JAX_THROW_IF_ERROR(cublasZtrsmBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZtrsmBatched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, b_batch_ptrs, - ldb, d.batch)); + ldb, d.batch))); break; } } + return absl::OkStatus(); +} + +void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = TrsmBatched_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // Batched LU decomposition: getrfbatched @@ -212,55 +229,69 @@ std::pair BuildGetrfBatchedDescriptor(const py::dtype& dtype, return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})}; } -void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GetrfBatchedDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = BlasHandlePool::Borrow(stream); +absl::Status GetrfBatched_(cudaStream_t stream, void** buffers, + const char* opaque, size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GetrfBatchedDescriptor& d = **s; + auto h = BlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; if (buffers[0] != buffers[1]) { - JAX_THROW_IF_ERROR(cudaMemcpyAsync(buffers[1], buffers[0], - SizeOfType(d.type) * d.batch * d.n * d.n, - cudaMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n, + cudaMemcpyDeviceToDevice, stream))); } int* ipiv = static_cast(buffers[2]); int* info = static_cast(buffers[3]); auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch, SizeOfType(d.type) * d.n * d.n); + JAX_RETURN_IF_ERROR(a_ptrs_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - JAX_THROW_IF_ERROR(cudaStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream))); switch (d.type) { case Type::F32: { float* a = static_cast(buffers[1]); float** batch_ptrs = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cublasSgetrfBatched(handle.get(), d.n, batch_ptrs, d.n, - ipiv, info, d.batch)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSgetrfBatched( + handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } case Type::F64: { double* a = static_cast(buffers[1]); double** batch_ptrs = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cublasDgetrfBatched(handle.get(), d.n, batch_ptrs, d.n, - ipiv, info, d.batch)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDgetrfBatched( + handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } case Type::C64: { cuComplex* a = static_cast(buffers[1]); cuComplex** batch_ptrs = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cublasCgetrfBatched(handle.get(), d.n, batch_ptrs, d.n, - ipiv, info, d.batch)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCgetrfBatched( + handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } case Type::C128: { cuDoubleComplex* a = static_cast(buffers[1]); cuDoubleComplex** batch_ptrs = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cublasZgetrfBatched(handle.get(), d.n, batch_ptrs, d.n, - ipiv, info, d.batch)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZgetrfBatched( + handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } } + return absl::OkStatus(); +} + +void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = GetrfBatched_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } py::dict Registrations() { diff --git a/jaxlib/cuda_gpu_kernel_helpers.cc b/jaxlib/cuda_gpu_kernel_helpers.cc index 89cdb213a..8d32e4bd4 100644 --- a/jaxlib/cuda_gpu_kernel_helpers.cc +++ b/jaxlib/cuda_gpu_kernel_helpers.cc @@ -23,15 +23,13 @@ limitations under the License. namespace jax { namespace { -std::string ErrorToString(cudaError_t error) { - return cudaGetErrorString(error); -} +std::string ErrorString(cudaError_t error) { return cudaGetErrorString(error); } -std::string ErrorToString(cusparseStatus_t status) { +std::string ErrorString(cusparseStatus_t status) { return cusparseGetErrorString(status); } -std::string ErrorToString(cusolverStatus_t status) { +std::string ErrorString(cusolverStatus_t status) { switch (status) { case CUSOLVER_STATUS_SUCCESS: return "cuSolver success."; @@ -62,7 +60,7 @@ std::string ErrorToString(cusolverStatus_t status) { } } -std::string ErrorToString(cublasStatus_t status) { +std::string ErrorString(cublasStatus_t status) { switch (status) { case CUBLAS_STATUS_SUCCESS: return "cuBlas success"; @@ -90,47 +88,53 @@ std::string ErrorToString(cublasStatus_t status) { } template -void ThrowError(T status, const char* file, std::int64_t line, - const char* expr) { - throw std::runtime_error(absl::StrFormat("%s:%d: operation %s failed: %s", - file, line, expr, - ErrorToString(status))); +std::string ErrorString(T status, const char* file, std::int64_t line, + const char* expr) { + return absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr, + ErrorString(status)); } } // namespace -void ThrowIfError(cudaError_t error, const char* file, std::int64_t line, - const char* expr) { - if (error != cudaSuccess) ThrowError(error, file, line, expr); +absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line, + const char* expr) { + if (error != cudaSuccess) + return absl::InternalError(ErrorString(error, file, line, expr)); + return absl::OkStatus(); } -void ThrowIfError(cusolverStatus_t status, const char* file, std::int64_t line, - const char* expr) { - if (status != CUSOLVER_STATUS_SUCCESS) ThrowError(status, file, line, expr); +absl::Status AsStatus(cusolverStatus_t status, const char* file, + std::int64_t line, const char* expr) { + if (status != CUSOLVER_STATUS_SUCCESS) + return absl::InternalError(ErrorString(status, file, line, expr)); + return absl::OkStatus(); } -void ThrowIfError(cusparseStatus_t status, const char* file, std::int64_t line, - const char* expr) { - if (status != CUSPARSE_STATUS_SUCCESS) ThrowError(status, file, line, expr); +absl::Status AsStatus(cusparseStatus_t status, const char* file, + std::int64_t line, const char* expr) { + if (status != CUSPARSE_STATUS_SUCCESS) + return absl::InternalError(ErrorString(status, file, line, expr)); + return absl::OkStatus(); } -void ThrowIfError(cublasStatus_t status, const char* file, std::int64_t line, - const char* expr) { - if (status != CUBLAS_STATUS_SUCCESS) ThrowError(status, file, line, expr); +absl::Status AsStatus(cublasStatus_t status, const char* file, + std::int64_t line, const char* expr) { + if (status != CUBLAS_STATUS_SUCCESS) + return absl::InternalError(ErrorString(status, file, line, expr)); + return absl::OkStatus(); } -std::unique_ptr MakeBatchPointers(cudaStream_t stream, void* buffer, - void* dev_ptrs, int batch, - int batch_elem_size) { +absl::StatusOr> MakeBatchPointers( + cudaStream_t stream, void* buffer, void* dev_ptrs, int batch, + int batch_elem_size) { char* ptr = static_cast(buffer); auto host_ptrs = absl::make_unique(batch); for (int i = 0; i < batch; ++i) { host_ptrs[i] = ptr; ptr += batch_elem_size; } - JAX_THROW_IF_ERROR(cudaMemcpyAsync(dev_ptrs, host_ptrs.get(), - sizeof(void*) * batch, - cudaMemcpyHostToDevice, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cudaMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, + cudaMemcpyHostToDevice, stream))); return host_ptrs; } } // namespace jax - diff --git a/jaxlib/cuda_gpu_kernel_helpers.h b/jaxlib/cuda_gpu_kernel_helpers.h index 1b45dc2af..0294928f5 100644 --- a/jaxlib/cuda_gpu_kernel_helpers.h +++ b/jaxlib/cuda_gpu_kernel_helpers.h @@ -18,33 +18,48 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusparse.h" -#define JAX_THROW_IF_ERROR(expr) \ - jax::ThrowIfError(expr, __FILE__, __LINE__, #expr) +#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr) + +#define JAX_THROW_IF_ERROR(expr) \ + { \ + auto s___ = (expr); \ + if (!s___.ok()) throw std::runtime_error(s___.error_message()); \ + } + +#define JAX_RETURN_IF_ERROR(expr) \ + { \ + auto s___ = (expr); \ + if (!s___.ok()) return s___; \ + } namespace jax { -// Used via JAX_THROW_IF_ERROR(expr) macro. -void ThrowIfError(cudaError_t error, const char* file, std::int64_t line, - const char* expr); -void ThrowIfError(cusolverStatus_t status, const char* file, std::int64_t line, - const char* expr); -void ThrowIfError(cusparseStatus_t status, const char* file, std::int64_t line, - const char* expr); -void ThrowIfError(cublasStatus_t status, const char* file, std::int64_t line, - const char* expr); +// Used via JAX_AS_STATUS(expr) macro. +absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line, + const char* expr); +absl::Status AsStatus(cusolverStatus_t status, const char* file, + std::int64_t line, const char* expr); +absl::Status AsStatus(cusparseStatus_t status, const char* file, + std::int64_t line, const char* expr); +absl::Status AsStatus(cublasStatus_t status, const char* file, + std::int64_t line, const char* expr); // Builds an array of pointers to each array in a batch, in device memory. // Caution: the return value must be kept alive (e.g., via a stream // synchronization) until the copy enqueued by MakeBatchPointers on `stream` // completes. -std::unique_ptr MakeBatchPointers(cudaStream_t stream, void* buffer, - void* dev_ptrs, int batch, - int batch_elem_size); +absl::StatusOr> MakeBatchPointers(cudaStream_t stream, + void* buffer, + void* dev_ptrs, + int batch, + int batch_elem_size); } // namespace jax diff --git a/jaxlib/cuda_linalg.py b/jaxlib/cuda_linalg.py index ec470452d..22c0f6f95 100644 --- a/jaxlib/cuda_linalg.py +++ b/jaxlib/cuda_linalg.py @@ -58,4 +58,6 @@ def lu_pivots_to_permutation(c, pivots, *, permutation_size): operands=(pivots,), shape_with_layout=permutations_shape_with_layout, operand_shapes_with_layout=(pivots_shape_with_layout,), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) diff --git a/jaxlib/cuda_lu_pivot_kernels.cu.cc b/jaxlib/cuda_lu_pivot_kernels.cu.cc index f50c8951b..0e28ea6b7 100644 --- a/jaxlib/cuda_lu_pivot_kernels.cu.cc +++ b/jaxlib/cuda_lu_pivot_kernels.cu.cc @@ -20,6 +20,7 @@ limitations under the License. #include "jaxlib/cuda_gpu_kernel_helpers.h" #include "jaxlib/kernel_helpers.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { namespace { @@ -73,13 +74,16 @@ std::string BuildCudaLuPivotsToPermutationDescriptor( batch_size, pivot_size, permutation_size}); } -void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { +absl::Status CudaLuPivotsToPermutation_(cudaStream_t stream, void** buffers, + const char* opaque, + std::size_t opaque_len) { const std::int32_t* pivots = reinterpret_cast(buffers[0]); std::int32_t* permutation_out = reinterpret_cast(buffers[1]); - const auto& descriptor = - *UnpackDescriptor(opaque, opaque_len); + auto s = + UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const auto& descriptor = **s; const int block_dim = 128; const std::int64_t grid_dim = std::min( @@ -89,7 +93,18 @@ void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers, /*dynamic_shared_mem_bytes=*/0, stream>>>( pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size, descriptor.permutation_size); - JAX_THROW_IF_ERROR(cudaGetLastError()); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError())); + return absl::OkStatus(); +} + +void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } } // namespace jax diff --git a/jaxlib/cuda_lu_pivot_kernels.h b/jaxlib/cuda_lu_pivot_kernels.h index 10bb086ae..84f3e0b08 100644 --- a/jaxlib/cuda_lu_pivot_kernels.h +++ b/jaxlib/cuda_lu_pivot_kernels.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { @@ -28,7 +29,8 @@ std::string BuildCudaLuPivotsToPermutationDescriptor( std::int32_t permutation_size); void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len); + const char* opaque, std::size_t opaque_len, + XlaCustomCallStatus* status); } // namespace jax diff --git a/jaxlib/cuda_prng.py b/jaxlib/cuda_prng.py index 09ece3fd2..d7ed088e9 100644 --- a/jaxlib/cuda_prng.py +++ b/jaxlib/cuda_prng.py @@ -56,4 +56,6 @@ def threefry2x32(c, keys, data): operands=(keys[0], keys[1], data[0], data[1]), shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]), operand_shapes_with_layout=(shape,) * 4, - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) diff --git a/jaxlib/cuda_prng_kernels.cu.cc b/jaxlib/cuda_prng_kernels.cu.cc index 5f452d23a..26b1daeca 100644 --- a/jaxlib/cuda_prng_kernels.cu.cc +++ b/jaxlib/cuda_prng_kernels.cu.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "jaxlib/cuda_prng_kernels.h" + #include #include -#include "jaxlib/cuda_prng_kernels.h" #include "jaxlib/cuda_gpu_kernel_helpers.h" #include "jaxlib/kernel_helpers.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { namespace { @@ -106,8 +108,8 @@ std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) { return PackDescriptorAsString(ThreeFry2x32Descriptor{n}); } -void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { +absl::Status CudaThreeFry2x32_(cudaStream_t stream, void** buffers, + const char* opaque, std::size_t opaque_len) { std::array keys; keys[0] = reinterpret_cast(buffers[0]); keys[1] = reinterpret_cast(buffers[1]); @@ -117,15 +119,26 @@ void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, std::array out; out[0] = reinterpret_cast(buffers[4]); out[1] = reinterpret_cast(buffers[5]); - const auto& descriptor = - *UnpackDescriptor(opaque, opaque_len); + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const auto& descriptor = **s; const int block_dim = 128; const std::int64_t grid_dim = std::min(1024, (descriptor.n + block_dim - 1) / block_dim); ThreeFry2x32Kernel<<>>(keys[0], keys[1], data[0], data[1], out[0], out[1], descriptor.n); - JAX_THROW_IF_ERROR(cudaGetLastError()); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError())); + return absl::OkStatus(); +} + +void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } } // namespace jax diff --git a/jaxlib/cuda_prng_kernels.h b/jaxlib/cuda_prng_kernels.h index 6512bee59..19a48c4e5 100644 --- a/jaxlib/cuda_prng_kernels.h +++ b/jaxlib/cuda_prng_kernels.h @@ -20,13 +20,14 @@ limitations under the License. #include #include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n); void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len); + std::size_t opaque_len, XlaCustomCallStatus* status); } // namespace jax diff --git a/jaxlib/cusolver.cc b/jaxlib/cusolver.cc index ce25e15f8..f459fc763 100644 --- a/jaxlib/cusolver.cc +++ b/jaxlib/cusolver.cc @@ -34,6 +34,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/cusolverDn.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { namespace { @@ -42,19 +43,19 @@ namespace py = pybind11; using SolverHandlePool = HandlePool; template <> -/*static*/ SolverHandlePool::Handle SolverHandlePool::Borrow( +/*static*/ absl::StatusOr SolverHandlePool::Borrow( cudaStream_t stream) { SolverHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); cusolverDnHandle_t handle; if (pool->handles_[stream].empty()) { - JAX_THROW_IF_ERROR(cusolverDnCreate(&handle)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreate(&handle))); } else { handle = pool->handles_[stream].back(); pool->handles_[stream].pop_back(); } if (stream) { - JAX_THROW_IF_ERROR(cusolverDnSetStream(handle, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSetStream(handle, stream))); } return Handle(pool, handle, stream); } @@ -109,7 +110,9 @@ struct PotrfDescriptor { std::pair BuildPotrfDescriptor(const py::dtype& dtype, bool lower, int b, int n) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; std::int64_t workspace_size; cublasFillMode_t uplo = @@ -117,27 +120,31 @@ std::pair BuildPotrfDescriptor(const py::dtype& dtype, if (b == 1) { switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnSpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); workspace_size = lwork * sizeof(float); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnDpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); workspace_size = lwork * sizeof(double); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnCpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); workspace_size = lwork * sizeof(cuComplex); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnZpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); workspace_size = lwork * sizeof(cuDoubleComplex); break; } @@ -149,15 +156,18 @@ std::pair BuildPotrfDescriptor(const py::dtype& dtype, PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})}; } -void Potrf(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const PotrfDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); +absl::Status Potrf_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const PotrfDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_THROW_IF_ERROR(cudaMemcpyAsync(buffers[1], buffers[0], - SizeOfType(d.type) * d.batch * d.n * d.n, - cudaMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n, + cudaMemcpyDeviceToDevice, stream))); } int* info = static_cast(buffers[2]); @@ -166,67 +176,78 @@ void Potrf(cudaStream_t stream, void** buffers, const char* opaque, switch (d.type) { case Type::F32: { float* a = static_cast(buffers[1]); - JAX_THROW_IF_ERROR(cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, + static_cast(workspace), d.lwork, info))); break; } case Type::F64: { double* a = static_cast(buffers[1]); - JAX_THROW_IF_ERROR(cusolverDnDpotrf(handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnDpotrf(handle.get(), d.uplo, d.n, a, d.n, + static_cast(workspace), d.lwork, info))); break; } case Type::C64: { cuComplex* a = static_cast(buffers[1]); - JAX_THROW_IF_ERROR(cusolverDnCpotrf(handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCpotrf( + handle.get(), d.uplo, d.n, a, d.n, + static_cast(workspace), d.lwork, info))); break; } case Type::C128: { cuDoubleComplex* a = static_cast(buffers[1]); - JAX_THROW_IF_ERROR(cusolverDnZpotrf( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZpotrf( handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), d.lwork, info)); + static_cast(workspace), d.lwork, info))); break; } } } else { auto buffer_ptrs_host = MakeBatchPointers( stream, buffers[1], workspace, d.batch, SizeOfType(d.type) * d.n * d.n); + JAX_RETURN_IF_ERROR(buffer_ptrs_host.status()); // Make sure that accesses to buffer_ptrs_host complete before we delete it. // TODO(phawkins): avoid synchronization here. - JAX_THROW_IF_ERROR(cudaStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream))); switch (d.type) { case Type::F32: { - JAX_THROW_IF_ERROR(cusolverDnSpotrfBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - info, d.batch)); + info, d.batch))); break; } case Type::F64: { - JAX_THROW_IF_ERROR(cusolverDnDpotrfBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - info, d.batch)); + info, d.batch))); break; } case Type::C64: { - JAX_THROW_IF_ERROR(cusolverDnCpotrfBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - info, d.batch)); + info, d.batch))); break; } case Type::C128: { - JAX_THROW_IF_ERROR(cusolverDnZpotrfBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZpotrfBatched( handle.get(), d.uplo, d.n, - static_cast(workspace), d.n, info, d.batch)); + static_cast(workspace), d.n, info, d.batch))); break; } } } + return absl::OkStatus(); +} + +void Potrf(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Potrf_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // getrf: LU decomposition @@ -240,44 +261,53 @@ struct GetrfDescriptor { std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, int m, int n) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnSgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnDgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnCgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnZgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; } return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})}; } -void Getrf(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GetrfDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); +absl::Status Getrf_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GetrfDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_THROW_IF_ERROR(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( buffers[1], buffers[0], SizeOfType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream))); } int* ipiv = static_cast(buffers[2]); @@ -287,9 +317,9 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque, case Type::F32: { float* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnSgetrf(handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), - ipiv, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnSgetrf(handle.get(), d.m, d.n, a, d.m, + static_cast(workspace), ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; @@ -299,9 +329,9 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque, case Type::F64: { double* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnDgetrf(handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), - ipiv, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnDgetrf(handle.get(), d.m, d.n, a, d.m, + static_cast(workspace), ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; @@ -311,9 +341,9 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque, case Type::C64: { cuComplex* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnCgetrf(handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), - ipiv, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnCgetrf(handle.get(), d.m, d.n, a, d.m, + static_cast(workspace), ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; @@ -323,9 +353,9 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque, case Type::C128: { cuDoubleComplex* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnZgetrf( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgetrf( handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), ipiv, info)); + static_cast(workspace), ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; @@ -333,6 +363,16 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque, break; } } + return absl::OkStatus(); +} + +void Getrf(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Getrf_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // geqrf: QR decomposition @@ -346,44 +386,53 @@ struct GeqrfDescriptor { std::pair BuildGeqrfDescriptor(const py::dtype& dtype, int b, int m, int n) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnSgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnDgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnCgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnZgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; } return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})}; } -void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GeqrfDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); +absl::Status Geqrf_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GeqrfDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_THROW_IF_ERROR(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( buffers[1], buffers[0], SizeOfType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream))); } int* info = static_cast(buffers[3]); @@ -393,9 +442,9 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, float* a = static_cast(buffers[1]); float* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; @@ -406,9 +455,9 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, double* a = static_cast(buffers[1]); double* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; @@ -419,9 +468,9 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, cuComplex* a = static_cast(buffers[1]); cuComplex* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnCgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgeqrf( + handle.get(), d.m, d.n, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; @@ -432,9 +481,9 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, cuDoubleComplex* a = static_cast(buffers[1]); cuDoubleComplex* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnZgeqrf( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgeqrf( handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info)); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; @@ -442,6 +491,16 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, break; } } + return absl::OkStatus(); +} + +void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Geqrf_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // orgqr/ungqr: apply elementary Householder transformations @@ -455,48 +514,57 @@ struct OrgqrDescriptor { std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, int m, int n, int k) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, /*tau=*/nullptr, - &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnSorgqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, /*tau=*/nullptr, - &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnDorgqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, /*tau=*/nullptr, - &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnCungqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, /*tau=*/nullptr, - &lwork)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusolverDnZungqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; } return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})}; } -void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const OrgqrDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); +absl::Status Orgqr_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const OrgqrDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; if (buffers[2] != buffers[0]) { - JAX_THROW_IF_ERROR(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( buffers[2], buffers[0], SizeOfType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream))); } int* info = static_cast(buffers[3]); @@ -506,9 +574,9 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, float* a = static_cast(buffers[2]); float* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, - tau, static_cast(workspace), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; @@ -519,9 +587,9 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, double* a = static_cast(buffers[2]); double* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info)); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; @@ -532,9 +600,9 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, cuComplex* a = static_cast(buffers[2]); cuComplex* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnCungqr( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCungqr( handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info)); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; @@ -545,9 +613,9 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, cuDoubleComplex* a = static_cast(buffers[2]); cuDoubleComplex* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnZungqr( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZungqr( handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info)); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; @@ -555,6 +623,16 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, break; } } + return absl::OkStatus(); +} + +void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Orgqr_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd @@ -570,46 +648,51 @@ struct SyevdDescriptor { std::pair BuildSyevdDescriptor(const py::dtype& dtype, bool lower, int b, int n) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSsyevd_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork)); + &lwork))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDsyevd_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork)); + &lwork))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCheevd_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork)); + &lwork))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZheevd_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork)); + &lwork))); break; } return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})}; } -void Syevd(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const SyevdDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); - JAX_THROW_IF_ERROR(cudaMemcpyAsync( +absl::Status Syevd_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const SyevdDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( buffers[1], buffers[0], SizeOfType(d.type) * static_cast(d.batch) * static_cast(d.n) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream))); cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; int* info = static_cast(buffers[3]); void* work = buffers[4]; @@ -618,9 +701,9 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque, float* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, - d.n, w, static_cast(work), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; @@ -631,9 +714,9 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque, double* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a, - d.n, w, static_cast(work), - d.lwork, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; @@ -644,9 +727,9 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque, cuComplex* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info)); + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; @@ -657,9 +740,9 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque, cuDoubleComplex* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnZheevd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevd( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info)); + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; @@ -667,6 +750,16 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque, break; } } + return absl::OkStatus(); +} + +void Syevd(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Syevd_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj @@ -683,10 +776,12 @@ struct SyevjDescriptor { std::pair BuildSyevjDescriptor(const py::dtype& dtype, bool lower, int batch, int n) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; syevjInfo_t params; - JAX_THROW_IF_ERROR(cusolverDnCreateSyevjInfo(¶ms)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(¶ms))); std::unique_ptr params_cleanup( params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); }); cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; @@ -695,67 +790,70 @@ std::pair BuildSyevjDescriptor(const py::dtype& dtype, if (batch == 1) { switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSsyevj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params)); + /*W=*/nullptr, &lwork, params))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDsyevj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params)); + /*W=*/nullptr, &lwork, params))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCheevj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params)); + /*W=*/nullptr, &lwork, params))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZheevj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params)); + /*W=*/nullptr, &lwork, params))); break; } } else { switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSsyevjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch)); + /*W=*/nullptr, &lwork, params, batch))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDsyevjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch)); + /*W=*/nullptr, &lwork, params, batch))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCheevjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch)); + /*W=*/nullptr, &lwork, params, batch))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZheevjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch)); + /*W=*/nullptr, &lwork, params, batch))); break; } } return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})}; } -void Syevj(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const SyevjDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); +absl::Status Syevj_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const SyevjDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_THROW_IF_ERROR(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( buffers[1], buffers[0], SizeOfType(d.type) * static_cast(d.batch) * static_cast(d.n) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream))); } syevjInfo_t params; - JAX_THROW_IF_ERROR(cusolverDnCreateSyevjInfo(¶ms)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(¶ms))); std::unique_ptr params_cleanup( params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); }); @@ -767,33 +865,33 @@ void Syevj(cudaStream_t stream, void** buffers, const char* opaque, case Type::F32: { float* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR(cusolverDnSsyevj(handle.get(), jobz, d.uplo, d.n, a, - d.n, w, static_cast(work), - d.lwork, info, params)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevj( + handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), d.lwork, info, params))); break; } case Type::F64: { double* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR(cusolverDnDsyevj(handle.get(), jobz, d.uplo, d.n, a, - d.n, w, static_cast(work), - d.lwork, info, params)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevj( + handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), d.lwork, info, params))); break; } case Type::C64: { cuComplex* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR(cusolverDnCheevj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevj( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params)); + static_cast(work), d.lwork, info, params))); break; } case Type::C128: { cuDoubleComplex* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR(cusolverDnZheevj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevj( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params)); + static_cast(work), d.lwork, info, params))); break; } } @@ -802,38 +900,48 @@ void Syevj(cudaStream_t stream, void** buffers, const char* opaque, case Type::F32: { float* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR(cusolverDnSsyevjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevjBatched( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch)); + static_cast(work), d.lwork, info, params, d.batch))); break; } case Type::F64: { double* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR(cusolverDnDsyevjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevjBatched( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch)); + static_cast(work), d.lwork, info, params, d.batch))); break; } case Type::C64: { cuComplex* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR(cusolverDnCheevjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevjBatched( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch)); + static_cast(work), d.lwork, info, params, d.batch))); break; } case Type::C128: { cuDoubleComplex* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); - JAX_THROW_IF_ERROR( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast(work), - d.lwork, info, params, d.batch)); + d.lwork, info, params, d.batch))); break; } } } + return absl::OkStatus(); +} + +void Syevj(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Syevj_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // Singular value decomposition using QR algorithm: gesvd @@ -850,24 +958,26 @@ std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, int m, int n, bool compute_uv, bool full_matrices) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; switch (type) { case Type::F32: - JAX_THROW_IF_ERROR( - cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork))); break; case Type::F64: - JAX_THROW_IF_ERROR( - cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork))); break; case Type::C64: - JAX_THROW_IF_ERROR( - cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork))); break; case Type::C128: - JAX_THROW_IF_ERROR( - cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork))); break; } signed char jobu, jobvt; @@ -884,16 +994,19 @@ std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})}; } -void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GesvdDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); - JAX_THROW_IF_ERROR(cudaMemcpyAsync( +absl::Status Gesvd_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GesvdDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( buffers[1], buffers[0], SizeOfType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream))); int* info = static_cast(buffers[5]); void* work = buffers[6]; switch (d.type) { @@ -903,10 +1016,10 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, float* u = static_cast(buffers[3]); float* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnSgesvd(handle.get(), d.jobu, d.jobvt, d.m, - d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvd( + handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, + static_cast(work), d.lwork, + /*rwork=*/nullptr, info))); a += d.m * d.n; s += std::min(d.m, d.n); u += d.m * d.m; @@ -921,10 +1034,10 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, double* u = static_cast(buffers[3]); double* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnDgesvd(handle.get(), d.jobu, d.jobvt, d.m, - d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvd( + handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, + static_cast(work), d.lwork, + /*rwork=*/nullptr, info))); a += d.m * d.n; s += std::min(d.m, d.n); u += d.m * d.m; @@ -939,9 +1052,9 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, cuComplex* u = static_cast(buffers[3]); cuComplex* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnCgesvd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvd( handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, /*rwork=*/nullptr, info)); + static_cast(work), d.lwork, /*rwork=*/nullptr, info))); a += d.m * d.n; s += std::min(d.m, d.n); u += d.m * d.m; @@ -956,10 +1069,10 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, cuDoubleComplex* u = static_cast(buffers[3]); cuDoubleComplex* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_THROW_IF_ERROR(cusolverDnZgesvd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvd( handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, static_cast(work), d.lwork, - /*rwork=*/nullptr, info)); + /*rwork=*/nullptr, info))); a += d.m * d.n; s += std::min(d.m, d.n); u += d.m * d.m; @@ -969,6 +1082,16 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, break; } } + return absl::OkStatus(); +} + +void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Gesvd_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // Singular value decomposition using Jacobi algorithm: gesvdj @@ -985,74 +1108,76 @@ std::pair BuildGesvdjDescriptor(const py::dtype& dtype, int batch, int m, int n, bool compute_uv) { Type type = DtypeToType(dtype); - auto handle = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; int lwork; cusolverEigMode_t jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; gesvdjInfo_t params; - JAX_THROW_IF_ERROR(cusolverDnCreateGesvdjInfo(¶ms)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); std::unique_ptr params_cleanup( params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); if (batch == 1) { switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSgesvdj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize( handle.get(), jobz, /*econ=*/0, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params)); + /*ldv=*/n, &lwork, params))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDgesvdj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize( handle.get(), jobz, /*econ=*/0, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params)); + /*ldv=*/n, &lwork, params))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCgesvdj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize( handle.get(), jobz, /*econ=*/0, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params)); + /*ldv=*/n, &lwork, params))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZgesvdj_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize( handle.get(), jobz, /*econ=*/0, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params)); + /*ldv=*/n, &lwork, params))); break; } } else { switch (type) { case Type::F32: - JAX_THROW_IF_ERROR(cusolverDnSgesvdjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch)); + /*ldv=*/n, &lwork, params, batch))); break; case Type::F64: - JAX_THROW_IF_ERROR(cusolverDnDgesvdjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch)); + /*ldv=*/n, &lwork, params, batch))); break; case Type::C64: - JAX_THROW_IF_ERROR(cusolverDnCgesvdjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch)); + /*ldv=*/n, &lwork, params, batch))); break; case Type::C128: - JAX_THROW_IF_ERROR(cusolverDnZgesvdjBatched_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch)); + /*ldv=*/n, &lwork, params, batch))); break; } } @@ -1060,20 +1185,23 @@ std::pair BuildGesvdjDescriptor(const py::dtype& dtype, PackDescriptor(GesvdjDescriptor{type, batch, m, n, lwork, jobz})}; } -void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GesvdjDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SolverHandlePool::Borrow(stream); - JAX_THROW_IF_ERROR(cudaMemcpyAsync( +absl::Status Gesvdj_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GesvdjDescriptor& d = **s; + auto h = SolverHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( buffers[1], buffers[0], SizeOfType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream)); + cudaMemcpyDeviceToDevice, stream))); int* info = static_cast(buffers[5]); void* work = buffers[6]; gesvdjInfo_t params; - JAX_THROW_IF_ERROR(cusolverDnCreateGesvdjInfo(¶ms)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); std::unique_ptr params_cleanup( params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); if (d.batch == 1) { @@ -1083,9 +1211,9 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, float* s = static_cast(buffers[2]); float* u = static_cast(buffers[3]); float* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnSgesvdj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj( handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params)); + d.n, static_cast(work), d.lwork, info, params))); break; } case Type::F64: { @@ -1093,9 +1221,9 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, double* s = static_cast(buffers[2]); double* u = static_cast(buffers[3]); double* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnDgesvdj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj( handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params)); + d.n, static_cast(work), d.lwork, info, params))); break; } case Type::C64: { @@ -1103,9 +1231,9 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, float* s = static_cast(buffers[2]); cuComplex* u = static_cast(buffers[3]); cuComplex* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnCgesvdj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj( handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params)); + d.n, static_cast(work), d.lwork, info, params))); break; } case Type::C128: { @@ -1113,9 +1241,9 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, double* s = static_cast(buffers[2]); cuDoubleComplex* u = static_cast(buffers[3]); cuDoubleComplex* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnZgesvdj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj( handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params)); + d.n, static_cast(work), d.lwork, info, params))); break; } } @@ -1126,9 +1254,9 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, float* s = static_cast(buffers[2]); float* u = static_cast(buffers[3]); float* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnSgesvdjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched( handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch)); + static_cast(work), d.lwork, info, params, d.batch))); break; } case Type::F64: { @@ -1136,9 +1264,9 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, double* s = static_cast(buffers[2]); double* u = static_cast(buffers[3]); double* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnDgesvdjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched( handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch)); + static_cast(work), d.lwork, info, params, d.batch))); break; } case Type::C64: { @@ -1146,9 +1274,9 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, float* s = static_cast(buffers[2]); cuComplex* u = static_cast(buffers[3]); cuComplex* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnCgesvdjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched( handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch)); + static_cast(work), d.lwork, info, params, d.batch))); break; } case Type::C128: { @@ -1156,14 +1284,24 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, double* s = static_cast(buffers[2]); cuDoubleComplex* u = static_cast(buffers[3]); cuDoubleComplex* v = static_cast(buffers[4]); - JAX_THROW_IF_ERROR(cusolverDnZgesvdjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched( handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, static_cast(work), d.lwork, info, params, - d.batch)); + d.batch))); break; } } } + return absl::OkStatus(); +} + +void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Gesvdj_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } py::dict Registrations() { diff --git a/jaxlib/cusolver.py b/jaxlib/cusolver.py index 21ff17bf7..9265c86a5 100644 --- a/jaxlib/cusolver.py +++ b/jaxlib/cusolver.py @@ -97,7 +97,9 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False, _Shape.array_shape(dtype, a_shape.dimensions(), layout), _Shape.array_shape(dtype, b_shape.dimensions(), layout), ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return _ops.GetTupleElement(out, 0) @@ -131,7 +133,9 @@ def potrf(c, a, lower): operand_shapes_with_layout=(_Shape.array_shape( dtype, batch_dims + (n, n), (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1) @@ -175,7 +179,9 @@ def getrf(c, a): operand_shapes_with_layout=(_Shape.array_shape( dtype, batch_dims + (m, n), (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2)) @@ -213,7 +219,9 @@ def geqrf(c, a): operand_shapes_with_layout=(_Shape.array_shape( dtype, batch_dims + (m, n), (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2)) @@ -257,7 +265,9 @@ def orgqr(c, a, tau): dtype, batch_dims + (k,), tuple(range(num_bd, -1, -1))), ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)) @@ -302,7 +312,9 @@ def syevd(c, a, lower=False): operand_shapes_with_layout=( _Shape.array_shape(dtype, dims, layout), ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2)) @@ -342,7 +354,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True): operand_shapes_with_layout=( _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout), ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) s = _ops.GetTupleElement(out, 1) u = _ops.GetTupleElement(out, 2) v = _ops.GetTupleElement(out, 3) @@ -371,7 +385,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True): operand_shapes_with_layout=( _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout), ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) s = _ops.GetTupleElement(out, 1) vt = _ops.GetTupleElement(out, 2) u = _ops.GetTupleElement(out, 3) @@ -398,7 +414,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True): operand_shapes_with_layout=( _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout), ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) s = _ops.GetTupleElement(out, 1) u = _ops.GetTupleElement(out, 2) vt = _ops.GetTupleElement(out, 3) diff --git a/jaxlib/cusparse.cc b/jaxlib/cusparse.cc index e1ea6f267..dc0776999 100644 --- a/jaxlib/cusparse.cc +++ b/jaxlib/cusparse.cc @@ -34,6 +34,7 @@ limitations under the License. #include "include/pybind11/numpy.h" #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" // Some functionality defined here is only available in CUSPARSE 11.3 or newer. #define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300) @@ -148,19 +149,19 @@ CudaConst CudaOne(cudaDataType type) { using SparseHandlePool = HandlePool; template <> -/*static*/ SparseHandlePool::Handle SparseHandlePool::Borrow( +/*static*/ absl::StatusOr SparseHandlePool::Borrow( cudaStream_t stream) { SparseHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); cusparseHandle_t handle; if (pool->handles_[stream].empty()) { - JAX_THROW_IF_ERROR(cusparseCreate(&handle)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreate(&handle))); } else { handle = pool->handles_[stream].back(); pool->handles_[stream].pop_back(); } if (stream) { - JAX_THROW_IF_ERROR(cusparseSetStream(handle, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSetStream(handle, stream))); } return Handle(pool, handle, stream); } @@ -246,7 +247,9 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype, std::pair BuildCsrToDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor d = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); @@ -258,47 +261,60 @@ std::pair BuildCsrToDenseDescriptor( int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, empty, - empty, empty, d.index_type, d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, - CUSPARSE_ORDER_ROW)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr( + &mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type, + d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_b, d.rows, d.cols, + /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW))); size_t buffer_size; - JAX_THROW_IF_ERROR(cusparseSparseToDense_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize( handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT, - &buffer_size)); + &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } -void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const SparseMatDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CsrToDense_(cudaStream_t stream, void** buffers, + const char* opaque, size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const SparseMatDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; cusparseSpMatDescr_t mat_a = 0; cusparseDnMatDescr_t mat_b = 0; - JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[2], - /*csrColInd=*/buffers[1], - /*csrValues=*/buffers[0], d.index_type, - d.index_type, CUSPARSE_INDEX_BASE_ZERO, - d.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], - d.value_type, CUSPARSE_ORDER_ROW)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, + /*csrRowOffsets=*/buffers[2], + /*csrColInd=*/buffers[1], + /*csrValues=*/buffers[0], d.index_type, d.index_type, + CUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_b, d.rows, d.cols, + /*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW))); - JAX_THROW_IF_ERROR(cusparseSparseToDense(handle.get(), mat_a, mat_b, - CUSPARSE_SPARSETODENSE_ALG_DEFAULT, - buffers[4])); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusparseSparseToDense(handle.get(), mat_a, mat_b, + CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); + return absl::OkStatus(); +} + +void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CsrToDense_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // CsrFromDense: Convert dense matrix to CSR matrix @@ -307,7 +323,9 @@ void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque, std::pair BuildCsrFromDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor d = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); @@ -317,48 +335,61 @@ std::pair BuildCsrFromDenseDescriptor( // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, - CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, empty, - empty, empty, d.index_type, d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_a, d.rows, d.cols, + /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr( + &mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type, + d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type))); size_t buffer_size; - JAX_THROW_IF_ERROR(cusparseDenseToSparse_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize( handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - &buffer_size)); + &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } -void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const SparseMatDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CsrFromDense_(cudaStream_t stream, void** buffers, + const char* opaque, size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const SparseMatDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; cusparseDnMatDescr_t mat_a = 0; cusparseSpMatDescr_t mat_b = 0; - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], - d.value_type, CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[3], - /*csrColInd=*/buffers[2], - /*csrValues=*/buffers[1], d.index_type, - d.index_type, CUSPARSE_INDEX_BASE_ZERO, - d.value_type)); - JAX_THROW_IF_ERROR(cusparseDenseToSparse_analysis( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_a, d.rows, d.cols, + /*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, + /*csrRowOffsets=*/buffers[3], + /*csrColInd=*/buffers[2], + /*csrValues=*/buffers[1], d.index_type, d.index_type, + CUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis( handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4])); - JAX_THROW_IF_ERROR(cusparseDenseToSparse_convert( + buffers[4]))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert( handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4])); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b)); + buffers[4]))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b))); + return absl::OkStatus(); +} + +void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // CsrMatvec: Product of CSR matrix and dense vector. @@ -374,7 +405,9 @@ std::pair BuildCsrMatvecDescriptor( const py::dtype& data_dtype, const py::dtype& x_dtype, const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor A = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); DenseVecDescriptor x = @@ -391,30 +424,35 @@ std::pair BuildCsrMatvecDescriptor( // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty, - empty, empty, A.index_type, A.index_type, - CUSPARSE_INDEX_BASE_ZERO, A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr( + &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type, + A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; CudaConst alpha = CudaOne(y.type); CudaConst beta = CudaZero(y.type); - JAX_THROW_IF_ERROR(cusparseSpMV_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - CUSPARSE_MV_ALG_DEFAULT, &buffer_size)); + CUSPARSE_MV_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y))); return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})}; } -void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const CsrMatvecDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CsrMatvec_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const CsrMatvecDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; void* csr_values = buffers[0]; void* csr_col_ind = buffers[1]; @@ -434,20 +472,32 @@ void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque, cusparseDnVecDescr_t vec_x = 0; cusparseDnVecDescr_t vec_y = 0; - JAX_THROW_IF_ERROR( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind, csr_values, d.A.index_type, d.A.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)); + CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type))); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type))); - JAX_THROW_IF_ERROR(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, - &beta, vec_y, d.y.type, - CUSPARSE_MV_ALG_DEFAULT, buf)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, + d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y))); + return absl::OkStatus(); +} + +void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CsrMatvec_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // CsrMatmat: Product of CSR matrix and dense matrix. @@ -463,7 +513,9 @@ std::pair BuildCsrMatmatDescriptor( const py::dtype& data_dtype, const py::dtype& b_dtype, const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, int cols, int BCcols, int nnz, bool transpose) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor A = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); DenseMatDescriptor B = @@ -480,32 +532,37 @@ std::pair BuildCsrMatmatDescriptor( // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty, - empty, empty, A.index_type, A.index_type, - CUSPARSE_INDEX_BASE_ZERO, A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, - empty, B.type, CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, - empty, C.type, CUSPARSE_ORDER_ROW)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr( + &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type, + A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, + empty, B.type, CUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, + empty, C.type, CUSPARSE_ORDER_ROW))); size_t buffer_size; CudaConst alpha = CudaOne(C.type); CudaConst beta = CudaZero(C.type); - JAX_THROW_IF_ERROR(cusparseSpMM_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize( handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, - mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)); + mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c))); return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})}; } -void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const CsrMatmatDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CsrMatmat_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const CsrMatmatDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; void* csr_values = buffers[0]; void* csr_col_ind = buffers[1]; @@ -525,23 +582,33 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque, cusparseDnMatDescr_t mat_b = 0; cusparseDnMatDescr_t mat_c = 0; - JAX_THROW_IF_ERROR( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind, csr_values, d.A.index_type, d.A.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols, - /*ld=*/d.B.cols, Bbuf, d.B.type, - CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols, - /*ld=*/d.C.cols, Cbuf, d.C.type, - CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseSpMM( + CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_b, d.B.rows, d.B.cols, + /*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_c, d.C.rows, d.C.cols, + /*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM( handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, - mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)); + mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c))); + return absl::OkStatus(); +} + +void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CsrMatmat_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // CooToDense: Convert COO matrix to dense matrix @@ -550,7 +617,9 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque, std::pair BuildCooToDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor d = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); @@ -561,46 +630,60 @@ std::pair BuildCooToDenseDescriptor( int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, - empty, empty, d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, - CUSPARSE_ORDER_ROW)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, + d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_b, d.rows, d.cols, + /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW))); size_t buffer_size; - JAX_THROW_IF_ERROR(cusparseSparseToDense_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize( handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT, - &buffer_size)); + &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } -void CooToDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const SparseMatDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CooToDense_(cudaStream_t stream, void** buffers, + const char* opaque, size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const SparseMatDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; cusparseSpMatDescr_t mat_a = 0; cusparseDnMatDescr_t mat_b = 0; - JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, - /*cooRowInd=*/buffers[1], - /*cooColInd=*/buffers[2], - /*cooValues=*/buffers[0], d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], - d.value_type, CUSPARSE_ORDER_ROW)); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, + /*cooRowInd=*/buffers[1], + /*cooColInd=*/buffers[2], + /*cooValues=*/buffers[0], d.index_type, + CUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_b, d.rows, d.cols, + /*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW))); - JAX_THROW_IF_ERROR(cusparseSparseToDense(handle.get(), mat_a, mat_b, - CUSPARSE_SPARSETODENSE_ALG_DEFAULT, - buffers[4])); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusparseSparseToDense(handle.get(), mat_a, mat_b, + CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); + return absl::OkStatus(); +} + +void CooToDense(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CooToDense_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // CooFromDense: Convert dense matrix to COO matrix @@ -609,7 +692,9 @@ void CooToDense(cudaStream_t stream, void** buffers, const char* opaque, std::pair BuildCooFromDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor d = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); @@ -619,47 +704,61 @@ std::pair BuildCooFromDenseDescriptor( // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, - CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, - empty, empty, d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_a, d.rows, d.cols, + /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, + d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type))); size_t buffer_size; - JAX_THROW_IF_ERROR(cusparseDenseToSparse_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize( handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - &buffer_size)); + &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } -void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const SparseMatDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CooFromDense_(cudaStream_t stream, void** buffers, + const char* opaque, size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const SparseMatDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; cusparseDnMatDescr_t mat_a = 0; cusparseSpMatDescr_t mat_b = 0; - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], - d.value_type, CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, - /*cooRowInd=*/buffers[2], - /*cooColInd=*/buffers[3], - /*cooValues=*/buffers[1], d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type)); - JAX_THROW_IF_ERROR(cusparseDenseToSparse_analysis( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_a, d.rows, d.cols, + /*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW))); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, + /*cooRowInd=*/buffers[2], + /*cooColInd=*/buffers[3], + /*cooValues=*/buffers[1], d.index_type, + CUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis( handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4])); - JAX_THROW_IF_ERROR(cusparseDenseToSparse_convert( + buffers[4]))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert( handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4])); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b)); + buffers[4]))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b))); + return absl::OkStatus(); +} + +void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CooFromDense_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // CooMatvec: Product of COO matrix and dense vector. @@ -675,7 +774,9 @@ std::pair BuildCooMatvecDescriptor( const py::dtype& data_dtype, const py::dtype& x_dtype, const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor A = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); DenseVecDescriptor x = @@ -692,30 +793,35 @@ std::pair BuildCooMatvecDescriptor( // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, - empty, empty, A.index_type, - CUSPARSE_INDEX_BASE_ZERO, A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, + A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; CudaConst alpha = CudaOne(y.type); CudaConst beta = CudaZero(y.type); - JAX_THROW_IF_ERROR(cusparseSpMV_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - CUSPARSE_MV_ALG_DEFAULT, &buffer_size)); + CUSPARSE_MV_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y))); return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})}; } -void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const CooMatvecDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CooMatvec_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const CooMatvecDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; void* coo_values = buffers[0]; void* coo_row_ind = buffers[1]; @@ -735,19 +841,31 @@ void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque, cusparseDnVecDescr_t vec_x = 0; cusparseDnVecDescr_t vec_y = 0; - JAX_THROW_IF_ERROR(cusparseCreateCoo( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo( &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values, - d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)); - JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)); + d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type))); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type))); - JAX_THROW_IF_ERROR(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, - &beta, vec_y, d.y.type, - CUSPARSE_MV_ALG_DEFAULT, buf)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, + d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x)); - JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y))); + return absl::OkStatus(); +} + +void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CooMatvec_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // CooMatmat: Product of COO matrix and dense matrix. @@ -763,7 +881,9 @@ std::pair BuildCooMatmatDescriptor( const py::dtype& data_dtype, const py::dtype& b_dtype, const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, int cols, int BCcols, int nnz, bool transpose) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; SparseMatDescriptor A = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz); DenseMatDescriptor B = @@ -780,32 +900,37 @@ std::pair BuildCooMatmatDescriptor( // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, - empty, empty, A.index_type, - CUSPARSE_INDEX_BASE_ZERO, A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, - empty, B.type, CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, - empty, C.type, CUSPARSE_ORDER_ROW)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, + A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, + empty, B.type, CUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, + empty, C.type, CUSPARSE_ORDER_ROW))); size_t buffer_size; CudaConst alpha = CudaOne(C.type); CudaConst beta = CudaZero(C.type); - JAX_THROW_IF_ERROR(cusparseSpMM_bufferSize( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize( handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, - mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)); + mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c)); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c))); return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})}; } -void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const CooMatmatDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = SparseHandlePool::Borrow(stream); +absl::Status CooMatmat_(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const CooMatmatDescriptor& d = **s; + auto h = SparseHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; void* coo_values = buffers[0]; void* coo_row_ind = buffers[1]; @@ -825,22 +950,32 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque, cusparseDnMatDescr_t mat_b = 0; cusparseDnMatDescr_t mat_c = 0; - JAX_THROW_IF_ERROR(cusparseCreateCoo( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo( &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values, - d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols, - /*ld=*/d.B.cols, Bbuf, d.B.type, - CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols, - /*ld=*/d.C.cols, Cbuf, d.C.type, - CUSPARSE_ORDER_ROW)); - JAX_THROW_IF_ERROR(cusparseSpMM( + d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_b, d.B.rows, d.B.cols, + /*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( + &mat_c, d.C.rows, d.C.cols, + /*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM( handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, - mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)); + mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf))); - JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b)); - JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c))); + return absl::OkStatus(); +} + +void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = CooMatmat_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } #endif // if JAX_CUSPARSE_11030 @@ -853,12 +988,15 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) { } template -void gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { - auto handle = SparseHandlePool::Borrow(); +absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers, + const char* opaque, std::size_t opaque_len) { + auto h = SparseHandlePool::Borrow(); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; - const Gtsv2Descriptor& descriptor = - *UnpackDescriptor(opaque, opaque_len); + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const Gtsv2Descriptor& descriptor = **s; int m = descriptor.m; int n = descriptor.n; int ldb = descriptor.ldb; @@ -878,31 +1016,42 @@ void gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers, // TODO(b/182906199): Update the comment here once copy insertion is WAI. if (X != B) { size_t B_bytes = ldb * n * sizeof(T); - JAX_THROW_IF_ERROR( - cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream))); } - JAX_THROW_IF_ERROR( - computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer))); + return absl::OkStatus(); } void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - gtsv2(cusparseSgtsv2, stream, buffers, opaque, opaque_len); + std::size_t opaque_len, XlaCustomCallStatus* status) { + auto s = gtsv2(cusparseSgtsv2, stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } - void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - gtsv2(cusparseDgtsv2, stream, buffers, opaque, opaque_len); + std::size_t opaque_len, XlaCustomCallStatus* status) { + auto s = gtsv2(cusparseDgtsv2, stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } -template +template size_t Gtsv2BufferSize(F f, int m, int n, int ldb) { - auto handle = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(); + JAX_THROW_IF_ERROR(h.status()); + auto& handle = *h; size_t size; - JAX_THROW_IF_ERROR(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr, - /*du=*/nullptr, /*B=*/nullptr, ldb, &size)); + JAX_THROW_IF_ERROR( + JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr, + /*du=*/nullptr, /*B=*/nullptr, ldb, &size))); return size; } diff --git a/jaxlib/cusparse.py b/jaxlib/cusparse.py index ebf0abd99..6264c7ba4 100644 --- a/jaxlib/cusparse.py +++ b/jaxlib/cusparse.py @@ -59,6 +59,8 @@ def csr_todense(c, data, indices, indptr, *, shape): _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)), )), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return _ops.GetTupleElement(out, 0) @@ -86,6 +88,8 @@ def csr_fromdense(c, mat, *, nnz, index_dtype): _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)), )), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return tuple(_ops.GetTupleElement(out, i) for i in range(3)) @@ -122,6 +126,8 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d _Shape.array_shape(compute_dtype, (out_size,), (0,)), _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return _ops.GetTupleElement(out, 0) @@ -158,6 +164,8 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)), _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return _ops.GetTupleElement(out, 0) @@ -187,6 +195,8 @@ def coo_todense(c, data, row, col, *, shape): _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)), )), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return _ops.GetTupleElement(out, 0) @@ -214,6 +224,8 @@ def coo_fromdense(c, mat, *, nnz, index_dtype): _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)), )), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return tuple(_ops.GetTupleElement(out, i) for i in range(3)) @@ -249,6 +261,8 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No _Shape.array_shape(compute_dtype, (out_size,), (0,)), _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return _ops.GetTupleElement(out, 0) @@ -285,6 +299,8 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)), _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))), opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING, ) return _ops.GetTupleElement(out, 0) @@ -306,5 +322,7 @@ def gtsv2(c, dl, d, du, B, *, m, n, ldb, t): (_Shape.array_shape(np.dtype(t), (ldb, n), (1, 0)), _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))), opaque=cusparse_kernels.build_gtsv2_descriptor(m, n, ldb), - has_side_effect=False) + has_side_effect=False, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return _ops.GetTupleElement(out, 0) diff --git a/jaxlib/handle_pool.h b/jaxlib/handle_pool.h index a75e000b0..dd8bb2fbc 100644 --- a/jaxlib/handle_pool.h +++ b/jaxlib/handle_pool.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/base/thread_annotations.h" +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" namespace jax { @@ -77,7 +78,7 @@ class HandlePool { // Borrows a handle from the pool. If 'stream' is non-null, sets the stream // associated with the handle. - static Handle Borrow(StreamType stream = nullptr); + static absl::StatusOr Borrow(StreamType stream = nullptr); private: static HandlePool* Instance(); diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index e2c7ba19c..192483765 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/status/statusor.h" namespace jax { @@ -36,9 +37,10 @@ std::string PackDescriptorAsString(const T& descriptor) { // Unpacks a descriptor object from a byte string. template -const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) { +absl::StatusOr UnpackDescriptor(const char* opaque, + std::size_t opaque_len) { if (opaque_len != sizeof(T)) { - throw std::runtime_error("Invalid size for linalg operation descriptor."); + return absl::InternalError("Invalid size for operation descriptor."); } return absl::bit_cast(opaque); } diff --git a/jaxlib/rocblas.cc b/jaxlib/rocblas.cc index e7ab05f69..b4d806294 100644 --- a/jaxlib/rocblas.cc +++ b/jaxlib/rocblas.cc @@ -30,6 +30,7 @@ limitations under the License. #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_pybind11_helpers.h" #include "jaxlib/rocm_gpu_kernel_helpers.h" +#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h" #include "rocm/include/hip/hip_runtime.h" #include "rocm/include/hip/hip_runtime_api.h" #include "rocm/include/rocblas.h" @@ -37,35 +38,36 @@ limitations under the License. namespace jax { + +absl::Status AsStatus(rocblas_status status) { + switch (status) { + case rocblas_status_success: + return absl::OkStatus(); + default: + return absl::InternalError(rocblas_status_to_string(status)); + } +} + namespace { namespace py = pybind11; -void ThrowIfErrorStatus(rocblas_status status) { - switch (status) { - case rocblas_status_success: - return; - default: - throw std::runtime_error(rocblas_status_to_string(status)); - } -} - using rocBlasHandlePool = HandlePool; template <> -/*static*/ rocBlasHandlePool::Handle rocBlasHandlePool::Borrow( +/*static*/ absl::StatusOr rocBlasHandlePool::Borrow( hipStream_t stream) { rocBlasHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); rocblas_handle handle; if (pool->handles_[stream].empty()) { - ThrowIfErrorStatus(rocblas_create_handle(&handle)); + JAX_RETURN_IF_ERROR(AsStatus(rocblas_create_handle(&handle))) } else { handle = pool->handles_[stream].back(); pool->handles_[stream].pop_back(); } if (stream) { - ThrowIfErrorStatus(rocblas_set_stream(handle, stream)); + JAX_RETURN_IF_ERROR(AsStatus(rocblas_set_stream(handle, stream))) } return rocBlasHandlePool::Handle(pool, handle, stream); } @@ -148,18 +150,21 @@ std::pair BuildTrsmDescriptor(const py::dtype& dtype, return {lwork, PackDescriptor(desc)}; } -void Trsm(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const TrsmDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = rocBlasHandlePool::Borrow(stream); +absl::Status Trsm_(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const TrsmDescriptor& d = **s; + auto h = rocBlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; // b is INOUT, so we copy the input to the output and use that if they are not // already the same if (buffers[2] != buffers[1]) { - ThrowIfError(hipMemcpyAsync(buffers[2], buffers[1], - SizeOfType(d.type) * d.batch * d.m * d.n, - hipMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync( + buffers[2], buffers[1], SizeOfType(d.type) * d.batch * d.m * d.n, + hipMemcpyDeviceToDevice, stream))) } const int lda = d.side == rocblas_side_left ? d.m : d.n; const int ldb = d.m; @@ -170,18 +175,18 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque, float* a = static_cast(buffers[0]); float* b = static_cast(buffers[2]); const float alpha = 1.0f; - ThrowIfErrorStatus(rocblas_strsm(handle.get(), d.side, d.uplo, d.trans, - d.diag, d.m, d.n, &alpha, - const_cast(a), lda, b, ldb)); + JAX_RETURN_IF_ERROR(AsStatus( + rocblas_strsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, + d.n, &alpha, const_cast(a), lda, b, ldb))) break; } case Type::F64: { double* a = static_cast(buffers[0]); double* b = static_cast(buffers[2]); const double alpha = 1.0; - ThrowIfErrorStatus(rocblas_dtrsm(handle.get(), d.side, d.uplo, d.trans, - d.diag, d.m, d.n, &alpha, - const_cast(a), lda, b, ldb)); + JAX_RETURN_IF_ERROR(AsStatus( + rocblas_dtrsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, + d.n, &alpha, const_cast(a), lda, b, ldb))) break; } case Type::C64: { @@ -190,9 +195,9 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque, rocblas_float_complex* b = static_cast(buffers[2]); const rocblas_float_complex alpha = {1.0f, 0.0f}; - ThrowIfErrorStatus(rocblas_ctrsm( + JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a), lda, b, ldb)); + const_cast(a), lda, b, ldb))) break; } case Type::C128: { @@ -200,10 +205,10 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque, static_cast(buffers[0]); rocblas_double_complex* b = static_cast(buffers[2]); - const rocblas_double_complex alpha = {1.0d, 0.0d}; - ThrowIfErrorStatus(rocblas_ztrsm( + const rocblas_double_complex alpha = {1.0f, 0.0f}; + JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, - const_cast(a), lda, b, ldb)); + const_cast(a), lda, b, ldb))) break; } } @@ -211,33 +216,35 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque, auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch, SizeOfType(d.type) * lda * lda); + JAX_RETURN_IF_ERROR(a_batch_host.status()); auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, SizeOfType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(b_batch_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - ThrowIfError(hipStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream))) switch (d.type) { case Type::F32: { float** a_batch_ptrs = static_cast(buffers[3]); float** b_batch_ptrs = static_cast(buffers[4]); const float alpha = 1.0f; - ThrowIfErrorStatus(rocblas_strsm_batched( + JAX_RETURN_IF_ERROR(AsStatus(rocblas_strsm_batched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch)); + d.batch))) break; } case Type::F64: { double** a_batch_ptrs = static_cast(buffers[3]); double** b_batch_ptrs = static_cast(buffers[4]); const double alpha = 1.0; - ThrowIfErrorStatus(rocblas_dtrsm_batched( + JAX_RETURN_IF_ERROR(AsStatus(rocblas_dtrsm_batched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, b_batch_ptrs, ldb, - d.batch)); + d.batch))) break; } case Type::C64: { @@ -246,10 +253,10 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque, rocblas_float_complex** b_batch_ptrs = static_cast(buffers[4]); const rocblas_float_complex alpha = {1.0f, 0.0f}; - ThrowIfErrorStatus(rocblas_ctrsm_batched( + JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm_batched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, - b_batch_ptrs, ldb, d.batch)); + b_batch_ptrs, ldb, d.batch))) break; } case Type::C128: { @@ -257,15 +264,25 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque, static_cast(buffers[3]); rocblas_double_complex** b_batch_ptrs = static_cast(buffers[4]); - const rocblas_double_complex alpha = {1.0d, 0.0d}; - ThrowIfErrorStatus(rocblas_ztrsm_batched( + const rocblas_double_complex alpha = {1.0f, 0.0f}; + JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm_batched( handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha, const_cast(a_batch_ptrs), lda, - b_batch_ptrs, ldb, d.batch)); + b_batch_ptrs, ldb, d.batch))) break; } } } + return absl::OkStatus(); +} + +void Trsm(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Trsm_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } //########################## @@ -290,17 +307,20 @@ std::pair BuildPotrfDescriptor(const py::dtype& dtype, return {lwork, PackDescriptor(PotrfDescriptor{type, uplo, b, n})}; } -void Potrf(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const PotrfDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = rocBlasHandlePool::Borrow(stream); +absl::Status Potrf_(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const PotrfDescriptor& d = **s; + auto h = rocBlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; // a is INOUT, so we copy the input to the output and use that if they are not // already the same if (buffers[1] != buffers[0]) { - ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0], - SizeOfType(d.type) * d.batch * d.n * d.n, - hipMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync( + buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n, + hipMemcpyDeviceToDevice, stream))) } int* info = static_cast(buffers[2]); @@ -308,28 +328,28 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque, switch (d.type) { case Type::F32: { float* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_spotrf(handle.get(), d.uplo, d.n, a, d.n, info)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_spotrf(handle.get(), d.uplo, d.n, a, d.n, info))) break; } case Type::F64: { double* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_dpotrf(handle.get(), d.uplo, d.n, a, d.n, info)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_dpotrf(handle.get(), d.uplo, d.n, a, d.n, info))) break; } case Type::C64: { rocblas_float_complex* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_cpotrf(handle.get(), d.uplo, d.n, a, d.n, info)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_cpotrf(handle.get(), d.uplo, d.n, a, d.n, info))) break; } case Type::C128: { rocblas_double_complex* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_zpotrf(handle.get(), d.uplo, d.n, a, d.n, info)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_zpotrf(handle.get(), d.uplo, d.n, a, d.n, info))) break; } } @@ -337,40 +357,51 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque, auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch, SizeOfType(d.type) * d.n * d.n); + JAX_RETURN_IF_ERROR(a_ptrs_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - ThrowIfError(hipStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream))) switch (d.type) { case Type::F32: { float** a_batch_ptrs = static_cast(buffers[3]); - ThrowIfErrorStatus(rocsolver_spotrf_batched( - handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)); + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_spotrf_batched( + handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch))) break; } case Type::F64: { double** a_batch_ptrs = static_cast(buffers[3]); - ThrowIfErrorStatus(rocsolver_dpotrf_batched( - handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)); + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dpotrf_batched( + handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch))) break; } case Type::C64: { rocblas_float_complex** a_batch_ptrs = static_cast(buffers[3]); - ThrowIfErrorStatus(rocsolver_cpotrf_batched( - handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)); + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cpotrf_batched( + handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch))) break; } case Type::C128: { rocblas_double_complex** a_batch_ptrs = static_cast(buffers[3]); - ThrowIfErrorStatus(rocsolver_zpotrf_batched( - handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)); + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zpotrf_batched( + handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch))) break; } } } + return absl::OkStatus(); +} + +void Potrf(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Potrf_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // getrf: LU decomposition @@ -389,18 +420,21 @@ std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})}; } -void Getrf(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GetrfDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = rocBlasHandlePool::Borrow(stream); +absl::Status Getrf_(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GetrfDescriptor& d = **s; + auto h = rocBlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; // a is INOUT, so we copy the input to the output and use that if they are not // already the same if (buffers[1] != buffers[0]) { - ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0], - SizeOfType(d.type) * d.batch * d.m * d.n, - hipMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync( + buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n, + hipMemcpyDeviceToDevice, stream))) } int* ipiv = static_cast(buffers[2]); @@ -410,28 +444,28 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque, switch (d.type) { case Type::F32: { float* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_sgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_sgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info))) break; } case Type::F64: { double* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_dgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_dgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info))) break; } case Type::C64: { rocblas_float_complex* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_cgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_cgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info))) break; } case Type::C128: { rocblas_double_complex* a = static_cast(buffers[1]); - ThrowIfErrorStatus( - rocsolver_zgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_zgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info))) break; } } @@ -439,44 +473,55 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque, auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch, SizeOfType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(a_ptrs_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - ThrowIfError(hipStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream))) switch (d.type) { case Type::F32: { float** batch_ptrs = static_cast(buffers[4]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_sgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - ipiv, std::min(d.m, d.n), info, d.batch)); + ipiv, std::min(d.m, d.n), info, d.batch))) break; } case Type::F64: { double** batch_ptrs = static_cast(buffers[4]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_dgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - ipiv, std::min(d.m, d.n), info, d.batch)); + ipiv, std::min(d.m, d.n), info, d.batch))) break; } case Type::C64: { rocblas_float_complex** batch_ptrs = static_cast(buffers[4]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_cgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - ipiv, std::min(d.m, d.n), info, d.batch)); + ipiv, std::min(d.m, d.n), info, d.batch))) break; } case Type::C128: { rocblas_double_complex** batch_ptrs = static_cast(buffers[4]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_zgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - ipiv, std::min(d.m, d.n), info, d.batch)); + ipiv, std::min(d.m, d.n), info, d.batch))) break; } } } + return absl::OkStatus(); +} + +void Getrf(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Getrf_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // geqrf: QR decomposition @@ -494,18 +539,21 @@ std::pair BuildGeqrfDescriptor(const py::dtype& dtype, int b, return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n})}; } -void Geqrf(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GeqrfDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = rocBlasHandlePool::Borrow(stream); +absl::Status Geqrf_(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GeqrfDescriptor& d = **s; + auto h = rocBlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; // a is INOUT, so we copy the input to the output and use that if they are not // already the same if (buffers[1] != buffers[0]) { - ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0], - SizeOfType(d.type) * d.batch * d.m * d.n, - hipMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync( + buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n, + hipMemcpyDeviceToDevice, stream))) } // here tau is tau @@ -515,15 +563,15 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque, case Type::F32: { float* a = static_cast(buffers[1]); float* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( - rocsolver_sgeqrf(handle.get(), d.m, d.n, a, d.m, tau)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_sgeqrf(handle.get(), d.m, d.n, a, d.m, tau))) break; } case Type::F64: { double* a = static_cast(buffers[1]); double* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( - rocsolver_dgeqrf(handle.get(), d.m, d.n, a, d.m, tau)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_dgeqrf(handle.get(), d.m, d.n, a, d.m, tau))) break; } case Type::C64: { @@ -531,8 +579,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque, static_cast(buffers[1]); rocblas_float_complex* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( - rocsolver_cgeqrf(handle.get(), d.m, d.n, a, d.m, tau)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_cgeqrf(handle.get(), d.m, d.n, a, d.m, tau))) break; } case Type::C128: { @@ -540,8 +588,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque, static_cast(buffers[1]); rocblas_double_complex* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( - rocsolver_zgeqrf(handle.get(), d.m, d.n, a, d.m, tau)); + JAX_RETURN_IF_ERROR( + AsStatus(rocsolver_zgeqrf(handle.get(), d.m, d.n, a, d.m, tau))) break; } } @@ -549,26 +597,27 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque, auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch, SizeOfType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(a_ptrs_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - ThrowIfError(hipStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream))) switch (d.type) { case Type::F32: { float** batch_ptrs = static_cast(buffers[3]); float* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_sgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - tau, std::min(d.m, d.n), d.batch)); + tau, std::min(d.m, d.n), d.batch))) break; } case Type::F64: { double** batch_ptrs = static_cast(buffers[3]); double* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_dgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - tau, std::min(d.m, d.n), d.batch)); + tau, std::min(d.m, d.n), d.batch))) break; } case Type::C64: { @@ -576,9 +625,9 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque, static_cast(buffers[3]); rocblas_float_complex* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_cgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - tau, std::min(d.m, d.n), d.batch)); + tau, std::min(d.m, d.n), d.batch))) break; } case Type::C128: { @@ -586,13 +635,23 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque, static_cast(buffers[3]); rocblas_double_complex* tau = static_cast(buffers[2]); - ThrowIfErrorStatus( + JAX_RETURN_IF_ERROR(AsStatus( rocsolver_zgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m, - tau, std::min(d.m, d.n), d.batch)); + tau, std::min(d.m, d.n), d.batch))) break; } } } + return absl::OkStatus(); +} + +void Geqrf(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Geqrf_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // orgqr/ungqr: apply elementary Householder transformations @@ -608,18 +667,21 @@ std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k})}; } -void Orgqr(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const OrgqrDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = rocBlasHandlePool::Borrow(stream); +absl::Status Orgqr_(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const OrgqrDescriptor& d = **s; + auto h = rocBlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; // a is INOUT, so we copy the input to the output and use that if they are not // already the same if (buffers[2] != buffers[0]) { - ThrowIfError(hipMemcpyAsync(buffers[2], buffers[0], - SizeOfType(d.type) * d.batch * d.m * d.n, - hipMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync( + buffers[2], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n, + hipMemcpyDeviceToDevice, stream))) } switch (d.type) { @@ -629,8 +691,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque, float* a = static_cast(buffers[2]); float* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - ThrowIfErrorStatus( - rocsolver_sorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_sorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau))) a += d.m * d.n; tau += d.k; } @@ -640,8 +702,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque, double* a = static_cast(buffers[2]); double* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - ThrowIfErrorStatus( - rocsolver_dorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_dorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau))) a += d.m * d.n; tau += d.k; } @@ -656,8 +718,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque, rocblas_float_complex* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - ThrowIfErrorStatus( - rocsolver_cungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_cungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau))) a += d.m * d.n; tau += d.k; } @@ -669,14 +731,24 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque, rocblas_double_complex* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - ThrowIfErrorStatus( - rocsolver_zungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_zungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau))) a += d.m * d.n; tau += d.k; } break; } } + return absl::OkStatus(); +} + +void Orgqr(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Orgqr_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd @@ -715,18 +787,21 @@ std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, return {lwork, PackDescriptor(GesvdDescriptor{type, b, m, n, jobu, jobvt})}; } -void Gesvd(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - const GesvdDescriptor& d = - *UnpackDescriptor(opaque, opaque_len); - auto handle = rocBlasHandlePool::Borrow(stream); +absl::Status Gesvd_(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { + auto s = UnpackDescriptor(opaque, opaque_len); + JAX_RETURN_IF_ERROR(s.status()); + const GesvdDescriptor& d = **s; + auto h = rocBlasHandlePool::Borrow(stream); + JAX_RETURN_IF_ERROR(h.status()); + auto& handle = *h; // a is INOUT, so we copy the input to the output and use that if they are not // already the same if (buffers[1] != buffers[0]) { - ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0], - SizeOfType(d.type) * d.batch * d.m * d.n, - hipMemcpyDeviceToDevice, stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync( + buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n, + hipMemcpyDeviceToDevice, stream))) } int* info = static_cast(buffers[5]); @@ -743,9 +818,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, float* u = static_cast(buffers[3]); float* vt = static_cast(buffers[4]); float* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_sgesvd(handle.get(), d.jobu, d.jobvt, d.m, - d.n, a, lda, s, u, ldu, vt, ldv, e, - rocblas_inplace, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_sgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s, + u, ldu, vt, ldv, e, rocblas_inplace, info))) break; } case Type::F64: { @@ -754,9 +829,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, double* u = static_cast(buffers[3]); double* vt = static_cast(buffers[4]); double* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_dgesvd(handle.get(), d.jobu, d.jobvt, d.m, - d.n, a, lda, s, u, ldu, vt, ldv, e, - rocblas_inplace, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_dgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s, + u, ldu, vt, ldv, e, rocblas_inplace, info))) break; } case Type::C64: { @@ -768,9 +843,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, rocblas_float_complex* vt = static_cast(buffers[4]); float* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_cgesvd(handle.get(), d.jobu, d.jobvt, d.m, - d.n, a, lda, s, u, ldu, vt, ldv, e, - rocblas_inplace, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_cgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s, + u, ldu, vt, ldv, e, rocblas_inplace, info))) break; } case Type::C128: { @@ -782,9 +857,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, rocblas_double_complex* vt = static_cast(buffers[4]); double* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_zgesvd(handle.get(), d.jobu, d.jobvt, d.m, - d.n, a, lda, s, u, ldu, vt, ldv, e, - rocblas_inplace, info)); + JAX_RETURN_IF_ERROR(AsStatus( + rocsolver_zgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s, + u, ldu, vt, ldv, e, rocblas_inplace, info))) break; } } @@ -797,10 +872,11 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[7], d.batch, SizeOfType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(a_ptrs_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - ThrowIfError(hipStreamSynchronize(stream)); + JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream))) switch (d.type) { case Type::F32: { @@ -809,10 +885,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, float* u = static_cast(buffers[3]); float* vt = static_cast(buffers[4]); float* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_sgesvd_batched( + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_sgesvd_batched( handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s, stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e, - rocblas_inplace, info, d.batch)); + rocblas_inplace, info, d.batch))) break; } case Type::F64: { @@ -821,10 +897,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, double* u = static_cast(buffers[3]); double* vt = static_cast(buffers[4]); double* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_dgesvd_batched( + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dgesvd_batched( handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s, stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e, - rocblas_inplace, info, d.batch)); + rocblas_inplace, info, d.batch))) break; } case Type::C64: { @@ -836,10 +912,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, rocblas_float_complex* vt = static_cast(buffers[4]); float* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_cgesvd_batched( + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cgesvd_batched( handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s, stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e, - rocblas_inplace, info, d.batch)); + rocblas_inplace, info, d.batch))) break; } case Type::C128: { @@ -851,14 +927,24 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque, rocblas_double_complex* vt = static_cast(buffers[4]); double* e = static_cast(buffers[6]); - ThrowIfErrorStatus(rocsolver_zgesvd_batched( + JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zgesvd_batched( handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s, stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e, - rocblas_inplace, info, d.batch)); + rocblas_inplace, info, d.batch))) break; } } } + return absl::OkStatus(); +} + +void Gesvd(hipStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = Gesvd_(stream, buffers, opaque, opaque_len); + if (!s.ok()) { + XlaCustomCallStatusSetFailure(status, s.error_message().c_str(), + s.error_message().length()); + } } // Singular value decomposition using Jacobi algorithm: gesvdj diff --git a/jaxlib/rocm_gpu_kernel_helpers.cc b/jaxlib/rocm_gpu_kernel_helpers.cc index fd0e6a6c0..30294e3e7 100644 --- a/jaxlib/rocm_gpu_kernel_helpers.cc +++ b/jaxlib/rocm_gpu_kernel_helpers.cc @@ -22,24 +22,26 @@ limitations under the License. namespace jax { -void ThrowIfError(hipError_t error) { +absl::Status AsStatus(hipError_t error) { if (error != hipSuccess) { - throw std::runtime_error( + return absl::InternalError( absl::StrCat("ROCm operation failed: ", hipGetErrorString(error))); } + return absl::OkStatus(); } -std::unique_ptr MakeBatchPointers(hipStream_t stream, void* buffer, - void* dev_ptrs, int batch, - int batch_elem_size) { +absl::StatusOr> MakeBatchPointers( + hipStream_t stream, void* buffer, void* dev_ptrs, int batch, + int batch_elem_size) { char* ptr = static_cast(buffer); auto host_ptrs = absl::make_unique(batch); for (int i = 0; i < batch; ++i) { host_ptrs[i] = ptr; ptr += batch_elem_size; } - ThrowIfError(hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, - hipMemcpyHostToDevice, stream)); + JAX_RETURN_IF_ERROR( + AsStatus(hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, + hipMemcpyHostToDevice, stream))); return host_ptrs; } } // namespace jax diff --git a/jaxlib/rocm_gpu_kernel_helpers.h b/jaxlib/rocm_gpu_kernel_helpers.h index a80a08fb8..dde4a8357 100644 --- a/jaxlib/rocm_gpu_kernel_helpers.h +++ b/jaxlib/rocm_gpu_kernel_helpers.h @@ -19,18 +19,28 @@ limitations under the License. #include #include "rocm/include/hip/hip_runtime_api.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +#define JAX_RETURN_IF_ERROR(expr) \ + { \ + auto s___ = (expr); \ + if (!s___.ok()) return s___; \ + } namespace jax { -void ThrowIfError(hipError_t error); +absl::Status AsStatus(hipError_t error); // Builds an array of pointers to each array in a batch, in device memory. // Caution: the return value must be kept alive (e.g., via a stream // synchronization) until the copy enqueued by MakeBatchPointers on `stream` // completes. -std::unique_ptr MakeBatchPointers(hipStream_t stream, void* buffer, - void* dev_ptrs, int batch, - int batch_elem_size); +absl::StatusOr> MakeBatchPointers(hipStream_t stream, + void* buffer, + void* dev_ptrs, + int batch, + int batch_elem_size); } // namespace jax diff --git a/jaxlib/rocsolver.py b/jaxlib/rocsolver.py index 4d0952246..02a07c09a 100644 --- a/jaxlib/rocsolver.py +++ b/jaxlib/rocsolver.py @@ -96,7 +96,9 @@ def trsm(c, _Shape.array_shape(dtype, a_shape.dimensions(), layout), # buffers[0] (a) _Shape.array_shape(dtype, b_shape.dimensions(), layout), # buffers[1] (b, IN) ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return _ops.GetTupleElement(out, 0) @@ -133,7 +135,9 @@ def potrf(c, a, lower): (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) ), # buffers[0] (a, IN) ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1) @@ -171,7 +175,9 @@ def getrf(c, a): (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) ), # buffers[0] (a, IN) ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2)) @@ -208,7 +214,9 @@ def geqrf(c, a): (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) ), # buffers[0] (a, IN) ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) # rocsolver geqrf does not return info return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), None) @@ -247,7 +255,9 @@ def orgqr(c, a, tau): _Shape.array_shape(dtype, batch_dims + (k,), tuple(range(num_bd, -1, -1))), # buffers[1] (tau IN) ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) return (_ops.GetTupleElement(out, 0), None) # ROCSolver orgqr does not return info @@ -303,7 +313,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True): _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout), # buffers[0] (a, IN) ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) s = _ops.GetTupleElement(out, 1) vt = _ops.GetTupleElement(out, 2) u = _ops.GetTupleElement(out, 3) @@ -338,7 +350,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True): _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout), # buffers[0] (a, IN) ), - opaque=opaque) + opaque=opaque, + api_version=xla_client.ops.CustomCallApiVersion + .API_VERSION_STATUS_RETURNING) s = _ops.GetTupleElement(out, 1) u = _ops.GetTupleElement(out, 2) vt = _ops.GetTupleElement(out, 3)