Use the new "custom call status" facility to report errors in jaxlib

PiperOrigin-RevId: 389734200
This commit is contained in:
Aden Grue 2021-08-09 15:06:12 -07:00 committed by jax authors
parent f04464d210
commit c368969955
20 changed files with 1359 additions and 823 deletions

View File

@ -54,6 +54,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/base",
"@com_google_absl//absl/status:statusor",
],
)
@ -67,6 +68,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
],
)
@ -82,8 +84,9 @@ cc_library(
deps = [
"@org_tensorflow//tensorflow/stream_executor/cuda:cusolver_lib",
"@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cublas_headers",
@ -102,6 +105,8 @@ cc_library(
deps = [
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:rocm_headers",
],
@ -159,6 +164,7 @@ pybind_extension(
":cuda_gpu_kernel_helpers",
":handle_pool",
":kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/stream_executor/cuda:cublas_lib",
"@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
"@com_google_absl//absl/algorithm:container",
@ -189,6 +195,7 @@ pybind_extension(
":cuda_gpu_kernel_helpers",
":handle_pool",
":kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
"@org_tensorflow//tensorflow/stream_executor/cuda:cusolver_lib",
"@com_google_absl//absl/algorithm:container",
@ -218,6 +225,7 @@ pybind_extension(
":cuda_gpu_kernel_helpers",
":handle_pool",
":kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
"@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/algorithm:container",
@ -241,6 +249,8 @@ cuda_library(
deps = [
":cuda_gpu_kernel_helpers",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -270,6 +280,8 @@ cuda_library(
deps = [
":cuda_gpu_kernel_helpers",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/status",
"@local_config_cuda//cuda:cuda_headers",
],
)
@ -306,6 +318,7 @@ pybind_extension(
":handle_pool",
":kernel_pybind11_helpers",
":rocm_gpu_kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace {
@ -41,18 +42,19 @@ namespace py = pybind11;
using BlasHandlePool = HandlePool<cublasHandle_t, cudaStream_t>;
template <>
/*static*/ BlasHandlePool::Handle BlasHandlePool::Borrow(cudaStream_t stream) {
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
cudaStream_t stream) {
BlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cublasHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_THROW_IF_ERROR(cublasCreate(&handle));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_THROW_IF_ERROR(cublasSetStream(handle, stream));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
@ -122,26 +124,31 @@ std::pair<size_t, py::bytes> BuildTrsmBatchedDescriptor(
return {size, PackDescriptor(desc)};
}
void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const TrsmBatchedDescriptor& d =
*UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
auto handle = BlasHandlePool::Borrow(stream);
absl::Status TrsmBatched_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const TrsmBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[1]) {
JAX_THROW_IF_ERROR(cudaMemcpyAsync(buffers[2], buffers[1],
SizeOfType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[2], buffers[1], SizeOfType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n;
const int ldb = d.m;
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
SizeOfType(d.type) * lda * lda);
JAX_RETURN_IF_ERROR(a_batch_host.status());
auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(b_batch_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_THROW_IF_ERROR(cudaStreamSynchronize(stream));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[0]);
@ -150,10 +157,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
// NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault.
const float alpha = 1.0f;
JAX_THROW_IF_ERROR(cublasStrsmBatched(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasStrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch));
d.batch)));
break;
}
case Type::F64: {
@ -162,10 +169,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
const double alpha = 1.0;
JAX_THROW_IF_ERROR(cublasDtrsmBatched(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch));
d.batch)));
break;
}
case Type::C64: {
@ -174,10 +181,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
cuComplex** b_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
const cuComplex alpha = make_cuComplex(1.0f, 0.0f);
JAX_THROW_IF_ERROR(cublasCtrsmBatched(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const cuComplex**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch));
d.batch)));
break;
}
case Type::C128: {
@ -188,13 +195,23 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
cuDoubleComplex** b_batch_ptrs =
static_cast<cuDoubleComplex**>(buffers[4]);
const cuDoubleComplex alpha = make_cuDoubleComplex(1.0f, 0.0f);
JAX_THROW_IF_ERROR(cublasZtrsmBatched(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<const cuDoubleComplex**>(a_batch_ptrs), lda, b_batch_ptrs,
ldb, d.batch));
ldb, d.batch)));
break;
}
}
return absl::OkStatus();
}
// XLA custom-call entry point for batched triangular solve. Delegates to the
// absl::Status-returning implementation and reports any failure through the
// XLA custom-call status facility (instead of throwing across the XLA
// boundary).
void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
                 size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    // Materialize the message once; calling error_message() separately for
    // the pointer and the length builds the temporary string twice.
    auto message = s.error_message();
    XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
  }
}
// Batched LU decomposition: getrfbatched
@ -212,55 +229,69 @@ std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
}
void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const GetrfBatchedDescriptor& d =
*UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
auto handle = BlasHandlePool::Borrow(stream);
absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GetrfBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_THROW_IF_ERROR(cudaMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.n * d.n,
cudaMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
SizeOfType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_THROW_IF_ERROR(cudaStreamSynchronize(stream));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
float** batch_ptrs = static_cast<float**>(buffers[4]);
JAX_THROW_IF_ERROR(cublasSgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
ipiv, info, d.batch));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
double** batch_ptrs = static_cast<double**>(buffers[4]);
JAX_THROW_IF_ERROR(cublasDgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
ipiv, info, d.batch));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case Type::C64: {
cuComplex* a = static_cast<cuComplex*>(buffers[1]);
cuComplex** batch_ptrs = static_cast<cuComplex**>(buffers[4]);
JAX_THROW_IF_ERROR(cublasCgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
ipiv, info, d.batch));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case Type::C128: {
cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
cuDoubleComplex** batch_ptrs = static_cast<cuDoubleComplex**>(buffers[4]);
JAX_THROW_IF_ERROR(cublasZgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
ipiv, info, d.batch));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
}
return absl::OkStatus();
}
// XLA custom-call entry point for batched LU decomposition (getrfBatched).
// Delegates to the absl::Status-returning implementation and reports any
// failure via the XLA custom-call status facility.
void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    // Materialize the message once; calling error_message() separately for
    // the pointer and the length builds the temporary string twice.
    auto message = s.error_message();
    XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
  }
}
py::dict Registrations() {

View File

@ -23,15 +23,13 @@ limitations under the License.
namespace jax {
namespace {
std::string ErrorToString(cudaError_t error) {
return cudaGetErrorString(error);
}
std::string ErrorString(cudaError_t error) { return cudaGetErrorString(error); }
std::string ErrorToString(cusparseStatus_t status) {
// Renders a cuSPARSE status code as a human-readable string using the
// library's own helper.
std::string ErrorString(cusparseStatus_t status) {
  return cusparseGetErrorString(status);
}
std::string ErrorToString(cusolverStatus_t status) {
std::string ErrorString(cusolverStatus_t status) {
switch (status) {
case CUSOLVER_STATUS_SUCCESS:
return "cuSolver success.";
@ -62,7 +60,7 @@ std::string ErrorToString(cusolverStatus_t status) {
}
}
std::string ErrorToString(cublasStatus_t status) {
std::string ErrorString(cublasStatus_t status) {
switch (status) {
case CUBLAS_STATUS_SUCCESS:
return "cuBlas success";
@ -90,47 +88,53 @@ std::string ErrorToString(cublasStatus_t status) {
}
template <typename T>
void ThrowError(T status, const char* file, std::int64_t line,
const char* expr) {
throw std::runtime_error(absl::StrFormat("%s:%d: operation %s failed: %s",
file, line, expr,
ErrorToString(status)));
std::string ErrorString(T status, const char* file, std::int64_t line,
const char* expr) {
return absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr,
ErrorString(status));
}
} // namespace
void ThrowIfError(cudaError_t error, const char* file, std::int64_t line,
const char* expr) {
if (error != cudaSuccess) ThrowError(error, file, line, expr);
// Converts a CUDA runtime result into an absl::Status. Failures become
// InternalError annotated with the failing expression and its source
// location; success maps to OkStatus.
absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
                      const char* expr) {
  if (error == cudaSuccess) {
    return absl::OkStatus();
  }
  return absl::InternalError(ErrorString(error, file, line, expr));
}
void ThrowIfError(cusolverStatus_t status, const char* file, std::int64_t line,
const char* expr) {
if (status != CUSOLVER_STATUS_SUCCESS) ThrowError(status, file, line, expr);
// Converts a cuSOLVER status into an absl::Status. Failures become
// InternalError annotated with the failing expression and its source
// location; success maps to OkStatus.
absl::Status AsStatus(cusolverStatus_t status, const char* file,
                      std::int64_t line, const char* expr) {
  if (status == CUSOLVER_STATUS_SUCCESS) {
    return absl::OkStatus();
  }
  return absl::InternalError(ErrorString(status, file, line, expr));
}
void ThrowIfError(cusparseStatus_t status, const char* file, std::int64_t line,
const char* expr) {
if (status != CUSPARSE_STATUS_SUCCESS) ThrowError(status, file, line, expr);
// Converts a cuSPARSE status into an absl::Status. Failures become
// InternalError annotated with the failing expression and its source
// location; success maps to OkStatus.
absl::Status AsStatus(cusparseStatus_t status, const char* file,
                      std::int64_t line, const char* expr) {
  if (status == CUSPARSE_STATUS_SUCCESS) {
    return absl::OkStatus();
  }
  return absl::InternalError(ErrorString(status, file, line, expr));
}
void ThrowIfError(cublasStatus_t status, const char* file, std::int64_t line,
const char* expr) {
if (status != CUBLAS_STATUS_SUCCESS) ThrowError(status, file, line, expr);
// Converts a cuBLAS status into an absl::Status. Failures become
// InternalError annotated with the failing expression and its source
// location; success maps to OkStatus.
absl::Status AsStatus(cublasStatus_t status, const char* file,
                      std::int64_t line, const char* expr) {
  if (status == CUBLAS_STATUS_SUCCESS) {
    return absl::OkStatus();
  }
  return absl::InternalError(ErrorString(status, file, line, expr));
}
std::unique_ptr<void* []> MakeBatchPointers(cudaStream_t stream, void* buffer,
void* dev_ptrs, int batch,
int batch_elem_size) {
absl::StatusOr<std::unique_ptr<void* []>> MakeBatchPointers(
cudaStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
char* ptr = static_cast<char*>(buffer);
auto host_ptrs = absl::make_unique<void*[]>(batch);
for (int i = 0; i < batch; ++i) {
host_ptrs[i] = ptr;
ptr += batch_elem_size;
}
JAX_THROW_IF_ERROR(cudaMemcpyAsync(dev_ptrs, host_ptrs.get(),
sizeof(void*) * batch,
cudaMemcpyHostToDevice, stream));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudaMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
cudaMemcpyHostToDevice, stream)));
return host_ptrs;
}
} // namespace jax

View File

@ -18,33 +18,48 @@ limitations under the License.
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#define JAX_THROW_IF_ERROR(expr) \
jax::ThrowIfError(expr, __FILE__, __LINE__, #expr)
// Wraps a CUDA/cuBLAS/cuSOLVER/cuSPARSE call result in an absl::Status,
// capturing the call site and the expression text.
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)

// Throws std::runtime_error if `expr` evaluates to a non-OK absl::Status.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// composes safely with if/else; the previous bare-brace form left a stray
// ';' at call sites, breaking `if (...) JAX_THROW_IF_ERROR(...); else ...`.
// NOTE(review): relies on <stdexcept> being in scope — confirm it is
// included (directly or transitively) by this header's includes.
#define JAX_THROW_IF_ERROR(expr)                                    \
  do {                                                              \
    auto s___ = (expr);                                             \
    if (!s___.ok()) throw std::runtime_error(s___.error_message()); \
  } while (0)

// Returns the error status from the enclosing function if `expr` evaluates
// to a non-OK absl::Status. Same single-statement do/while(0) wrapping as
// JAX_THROW_IF_ERROR.
#define JAX_RETURN_IF_ERROR(expr) \
  do {                            \
    auto s___ = (expr);           \
    if (!s___.ok()) return s___;  \
  } while (0)
namespace jax {
// Used via JAX_THROW_IF_ERROR(expr) macro.
void ThrowIfError(cudaError_t error, const char* file, std::int64_t line,
const char* expr);
void ThrowIfError(cusolverStatus_t status, const char* file, std::int64_t line,
const char* expr);
void ThrowIfError(cusparseStatus_t status, const char* file, std::int64_t line,
const char* expr);
void ThrowIfError(cublasStatus_t status, const char* file, std::int64_t line,
const char* expr);
// Used via JAX_AS_STATUS(expr) macro.
absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
const char* expr);
absl::Status AsStatus(cusolverStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(cusparseStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(cublasStatus_t status, const char* file,
std::int64_t line, const char* expr);
// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
// completes.
std::unique_ptr<void*[]> MakeBatchPointers(cudaStream_t stream, void* buffer,
void* dev_ptrs, int batch,
int batch_elem_size);
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(cudaStream_t stream,
void* buffer,
void* dev_ptrs,
int batch,
int batch_elem_size);
} // namespace jax

View File

@ -58,4 +58,6 @@ def lu_pivots_to_permutation(c, pivots, *, permutation_size):
operands=(pivots,),
shape_with_layout=permutations_shape_with_layout,
operand_shapes_with_layout=(pivots_shape_with_layout,),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace {
@ -73,13 +74,16 @@ std::string BuildCudaLuPivotsToPermutationDescriptor(
batch_size, pivot_size, permutation_size});
}
void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
absl::Status CudaLuPivotsToPermutation_(cudaStream_t stream, void** buffers,
const char* opaque,
std::size_t opaque_len) {
const std::int32_t* pivots =
reinterpret_cast<const std::int32_t*>(buffers[0]);
std::int32_t* permutation_out = reinterpret_cast<std::int32_t*>(buffers[1]);
const auto& descriptor =
*UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
auto s =
UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const auto& descriptor = **s;
const int block_dim = 128;
const std::int64_t grid_dim = std::min<std::int64_t>(
@ -89,7 +93,18 @@ void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
/*dynamic_shared_mem_bytes=*/0, stream>>>(
pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size,
descriptor.permutation_size);
JAX_THROW_IF_ERROR(cudaGetLastError());
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
return absl::OkStatus();
}
// XLA custom-call entry point for converting LU pivots to a permutation.
// Delegates to the absl::Status-returning implementation and reports any
// failure via the XLA custom-call status facility.
void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
                               const char* opaque, size_t opaque_len,
                               XlaCustomCallStatus* status) {
  auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    // Materialize the message once; calling error_message() separately for
    // the pointer and the length builds the temporary string twice.
    auto message = s.error_message();
    XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
  }
}
} // namespace jax

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
@ -28,7 +29,8 @@ std::string BuildCudaLuPivotsToPermutationDescriptor(
std::int32_t permutation_size);
void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len);
const char* opaque, std::size_t opaque_len,
XlaCustomCallStatus* status);
} // namespace jax

View File

@ -56,4 +56,6 @@ def threefry2x32(c, keys, data):
operands=(keys[0], keys[1], data[0], data[1]),
shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),
operand_shapes_with_layout=(shape,) * 4,
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)

View File

@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda_prng_kernels.h"
#include <array>
#include <cstddef>
#include "jaxlib/cuda_prng_kernels.h"
#include "jaxlib/cuda_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace {
@ -106,8 +108,8 @@ std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
}
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len) {
absl::Status CudaThreeFry2x32_(cudaStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
std::array<const std::uint32_t*, 2> keys;
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
@ -117,15 +119,26 @@ void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
std::array<std::uint32_t*, 2> out;
out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
const auto& descriptor =
*UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const auto& descriptor = **s;
const int block_dim = 128;
const std::int64_t grid_dim =
std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
out[1], descriptor.n);
JAX_THROW_IF_ERROR(cudaGetLastError());
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
return absl::OkStatus();
}
// XLA custom-call entry point for the ThreeFry2x32 PRNG kernel. Delegates to
// the absl::Status-returning implementation and reports any failure via the
// XLA custom-call status facility.
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
                      size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    // Materialize the message once; calling error_message() separately for
    // the pointer and the length builds the temporary string twice.
    auto message = s.error_message();
    XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
  }
}
} // namespace jax

View File

@ -20,13 +20,14 @@ limitations under the License.
#include <string>
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n);
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len);
std::size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax

File diff suppressed because it is too large Load Diff

View File

@ -97,7 +97,9 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
_Shape.array_shape(dtype, a_shape.dimensions(), layout),
_Shape.array_shape(dtype, b_shape.dimensions(), layout),
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0)
@ -131,7 +133,9 @@ def potrf(c, a, lower):
operand_shapes_with_layout=(_Shape.array_shape(
dtype, batch_dims + (n, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)
@ -175,7 +179,9 @@ def getrf(c, a):
operand_shapes_with_layout=(_Shape.array_shape(
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
@ -213,7 +219,9 @@ def geqrf(c, a):
operand_shapes_with_layout=(_Shape.array_shape(
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
@ -257,7 +265,9 @@ def orgqr(c, a, tau):
dtype, batch_dims + (k,),
tuple(range(num_bd, -1, -1))),
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1))
@ -302,7 +312,9 @@ def syevd(c, a, lower=False):
operand_shapes_with_layout=(
_Shape.array_shape(dtype, dims, layout),
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
@ -342,7 +354,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
u = _ops.GetTupleElement(out, 2)
v = _ops.GetTupleElement(out, 3)
@ -371,7 +385,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
vt = _ops.GetTupleElement(out, 2)
u = _ops.GetTupleElement(out, 3)
@ -398,7 +414,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
u = _ops.GetTupleElement(out, 2)
vt = _ops.GetTupleElement(out, 3)

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
@ -148,19 +149,19 @@ CudaConst CudaOne(cudaDataType type) {
using SparseHandlePool = HandlePool<cusparseHandle_t, cudaStream_t>;
template <>
/*static*/ SparseHandlePool::Handle SparseHandlePool::Borrow(
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
cudaStream_t stream) {
SparseHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cusparseHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_THROW_IF_ERROR(cusparseCreate(&handle));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_THROW_IF_ERROR(cusparseSetStream(handle, stream));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
@ -246,7 +247,9 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto handle = SparseHandlePool::Borrow();
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
@ -258,47 +261,60 @@ std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, empty,
empty, empty, d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type,
CUSPARSE_ORDER_ROW));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
size_t buffer_size;
JAX_THROW_IF_ERROR(cusparseSparseToDense_bufferSize(
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size));
&buffer_size)));
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type,
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
d.value_type));
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3],
d.value_type, CUSPARSE_ORDER_ROW));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
buffers[4]));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
// XLA custom-call entry point for CSR-to-dense conversion. Delegates to the
// absl::Status-returning implementation and reports any failure via the XLA
// custom-call status facility.
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    // Materialize the message once; calling error_message() separately for
    // the pointer and the length builds the temporary string twice.
    auto message = s.error_message();
    XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
  }
}
// CsrFromDense: Convert dense matrix to CSR matrix
@ -307,7 +323,9 @@ void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
// Returns the cuSPARSE workspace size (in bytes) required by CsrFromDense for
// a dense matrix of the given shape/dtypes, plus a packed SparseMatDescriptor
// to pass as the custom-call opaque payload.
// Runs on the host at build/trace time, so errors are thrown (via
// JAX_THROW_IF_ERROR) rather than returned.
//
// Fix: this span was diff residue — the pre- and post-change versions of the
// function were interleaved (duplicate `handle` binding, every cuSPARSE call
// present both with and without JAX_AS_STATUS) with an embedded hunk-header
// line; restored the single post-change version. The mat_a/mat_b declarations
// were hidden by the hunk boundary; their types follow from the Create calls
// below (DnMat for the dense input, Csr/SpMat for the sparse output).
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
    int cols, int nnz) {
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor d =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);

  cusparseDnMatDescr_t mat_a = 0;
  cusparseSpMatDescr_t mat_b = 0;

  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_a, d.rows, d.cols,
      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
      &mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
      d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
  size_t buffer_size;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
      &buffer_size)));

  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
  return {buffer_size, PackDescriptor(d)};
}
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
absl::Status CsrFromDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0],
d.value_type, CUSPARSE_ORDER_ROW));
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type,
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
d.value_type));
JAX_THROW_IF_ERROR(cusparseDenseToSparse_analysis(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
JAX_THROW_IF_ERROR(cusparseDenseToSparse_convert(
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a));
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b));
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
// Custom-call entry point for "csr_fromdense": runs the status-returning
// implementation and, on failure, reports the error message to XLA through
// the custom-call status object instead of throwing.
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = CsrFromDense_(stream, buffers, opaque, opaque_len);
  if (result.ok()) return;
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// CsrMatvec: Product of CSR matrix and dense vector.
@ -374,7 +405,9 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
// NOTE(review): diff residue — this span appears to interleave the pre- and
// post-change versions of BuildCsrMatvecDescriptor (`handle` is bound twice,
// and each cuSPARSE call appears both with and without JAX_AS_STATUS). The
// embedded "@ -391,30 +424,35 @@" hunk-header line below marks context lines
// that are missing entirely from this view (presumably the `y` descriptor,
// the mat_a/vec_x/vec_y declarations, and the transpose-dependent `op` —
// TODO confirm against upstream). Not safe to rewrite from this view alone;
// restore the post-change function from upstream before compiling.
    const py::dtype& data_dtype, const py::dtype& x_dtype,
    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
    int cols, int nnz, bool transpose) {
  auto handle = SparseHandlePool::Borrow();
  // Post-change form: Borrow() failures are surfaced before dereferencing.
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor A =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
  DenseVecDescriptor x =
  @ -391,30 +424,35 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
                                       empty, empty, A.index_type, A.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
      &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
      A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
  size_t buffer_size;
  CudaConst alpha = CudaOne(y.type);
  CudaConst beta = CudaZero(y.type);
  JAX_THROW_IF_ERROR(cusparseSpMV_bufferSize(
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
      handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
      CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
      CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
  return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
}
// NOTE(review): diff residue — the removed throwing `void CsrMatvec` and the
// added status-returning `absl::Status CsrMatvec_` appear interleaved below,
// and the embedded "@ -434,20 +472,32 @@" hunk header marks missing context
// (presumably the remaining buffer-pointer bindings, alpha/beta constants,
// and the mat_a declaration — TODO confirm against upstream). Not safe to
// rewrite from this view alone.
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len) {
  const CsrMatvecDescriptor& d =
      *UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
// Post-change form: descriptor unpacking and handle borrowing are checked.
absl::Status CsrMatvec_(cudaStream_t stream, void** buffers, const char* opaque,
                        size_t opaque_len) {
  auto s = UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const CsrMatvecDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;
  void* csr_values = buffers[0];
  void* csr_col_ind = buffers[1];
  @ -434,20 +472,32 @@ void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
  cusparseDnVecDescr_t vec_x = 0;
  cusparseDnVecDescr_t vec_y = 0;
  JAX_THROW_IF_ERROR(
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
                        csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
                        CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
                        CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
  JAX_THROW_IF_ERROR(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
                                  &beta, vec_y, d.y.type,
                                  CUSPARSE_MV_ALG_DEFAULT, buf));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
                   d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
  return absl::OkStatus();
}
// Custom-call entry point for "csr_matvec": runs the status-returning
// implementation and, on failure, reports the error message to XLA through
// the custom-call status object instead of throwing.
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = CsrMatvec_(stream, buffers, opaque, opaque_len);
  if (result.ok()) return;
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// CsrMatmat: Product of CSR matrix and dense matrix.
@ -463,7 +513,9 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
// NOTE(review): diff residue — this span appears to interleave the pre- and
// post-change versions of BuildCsrMatmatDescriptor (`handle` is bound twice,
// and each cuSPARSE call appears both with and without JAX_AS_STATUS). The
// embedded "@ -480,32 +532,37 @@" hunk header marks context lines missing
// from this view (presumably the B/C descriptor initializers, the
// mat_a/mat_b/mat_c declarations, and the transpose-dependent `op_A` —
// TODO confirm against upstream). Not safe to rewrite from this view alone.
    const py::dtype& data_dtype, const py::dtype& b_dtype,
    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
    int cols, int BCcols, int nnz, bool transpose) {
  auto handle = SparseHandlePool::Borrow();
  // Post-change form: Borrow() failures are surfaced before dereferencing.
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor A =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
  DenseMatDescriptor B =
  @ -480,32 +532,37 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
                                       empty, empty, A.index_type, A.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
                                         empty, B.type, CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                         empty, C.type, CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
      &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
      A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
                                        empty, B.type, CUSPARSE_ORDER_ROW)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                        empty, C.type, CUSPARSE_ORDER_ROW)));
  size_t buffer_size;
  CudaConst alpha = CudaOne(C.type);
  CudaConst beta = CudaZero(C.type);
  JAX_THROW_IF_ERROR(cusparseSpMM_bufferSize(
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
      handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
  return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
}
// NOTE(review): diff residue — the removed throwing `void CsrMatmat` and the
// added status-returning `absl::Status CsrMatmat_` appear interleaved below,
// and the embedded "@ -525,23 +582,33 @@" hunk header marks missing context
// (presumably the remaining buffer-pointer bindings, alpha/beta constants,
// and the mat_a declaration — TODO confirm against upstream). Not safe to
// rewrite from this view alone.
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len) {
  const CsrMatmatDescriptor& d =
      *UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
// Post-change form: descriptor unpacking and handle borrowing are checked.
absl::Status CsrMatmat_(cudaStream_t stream, void** buffers, const char* opaque,
                        size_t opaque_len) {
  auto s = UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const CsrMatmatDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;
  void* csr_values = buffers[0];
  void* csr_col_ind = buffers[1];
  @ -525,23 +582,33 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
  cusparseDnMatDescr_t mat_b = 0;
  cusparseDnMatDescr_t mat_c = 0;
  JAX_THROW_IF_ERROR(
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
                        csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
                        CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
                                         /*ld=*/d.B.cols, Bbuf, d.B.type,
                                         CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
                                         /*ld=*/d.C.cols, Cbuf, d.C.type,
                                         CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseSpMM(
                        CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_b, d.B.rows, d.B.cols,
      /*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_c, d.C.rows, d.C.cols,
      /*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
      handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
      mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));
      mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
  return absl::OkStatus();
}
// Custom-call entry point for "csr_matmat": runs the status-returning
// implementation and, on failure, reports the error message to XLA through
// the custom-call status object instead of throwing.
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = CsrMatmat_(stream, buffers, opaque, opaque_len);
  if (result.ok()) return;
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// CooToDense: Convert COO matrix to dense matrix
@ -550,7 +617,9 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
// Returns the cuSPARSE workspace size (in bytes) required by CooToDense for a
// COO matrix of the given shape/dtypes, plus a packed SparseMatDescriptor to
// pass as the custom-call opaque payload.
// Runs on the host at build/trace time, so errors are thrown rather than
// returned.
//
// Fix: this span was diff residue — pre- and post-change versions were
// interleaved (duplicate `handle` binding, every cuSPARSE call present twice)
// with an embedded hunk-header line; restored the single post-change version.
// The mat_a/mat_b declarations were hidden by the hunk boundary; their types
// follow from the Create calls below (Coo/SpMat input, DnMat output).
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
    int cols, int nnz) {
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor d =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);

  cusparseSpMatDescr_t mat_a = 0;
  cusparseDnMatDescr_t mat_b = 0;

  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
                        d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_b, d.rows, d.cols,
      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
  size_t buffer_size;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
      handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
      &buffer_size)));

  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  return {buffer_size, PackDescriptor(d)};
}
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
absl::Status CooToDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[1],
/*cooColInd=*/buffers[2],
/*cooValues=*/buffers[0], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3],
d.value_type, CUSPARSE_ORDER_ROW));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[1],
/*cooColInd=*/buffers[2],
/*cooValues=*/buffers[0], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
buffers[4]));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
// Custom-call entry point for "coo_todense": runs the status-returning
// implementation and, on failure, reports the error message to XLA through
// the custom-call status object instead of throwing.
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = CooToDense_(stream, buffers, opaque, opaque_len);
  if (result.ok()) return;
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// CooFromDense: Convert dense matrix to COO matrix
@ -609,7 +692,9 @@ void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
// Returns the cuSPARSE workspace size (in bytes) required by CooFromDense for
// a dense matrix of the given shape/dtypes, plus a packed SparseMatDescriptor
// to pass as the custom-call opaque payload.
// Runs on the host at build/trace time, so errors are thrown rather than
// returned.
//
// Fix: this span was diff residue — pre- and post-change versions were
// interleaved (duplicate `handle` binding, every cuSPARSE call present twice)
// with an embedded hunk-header line; restored the single post-change version.
// The mat_a/mat_b declarations were hidden by the hunk boundary; their types
// follow from the Create calls below (DnMat input, Coo/SpMat output).
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
    int cols, int nnz) {
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor d =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);

  cusparseDnMatDescr_t mat_a = 0;
  cusparseSpMatDescr_t mat_b = 0;

  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_a, d.rows, d.cols,
      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
                        d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
  size_t buffer_size;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
      &buffer_size)));

  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
  return {buffer_size, PackDescriptor(d)};
}
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SparseMatDescriptor& d =
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
auto handle = SparseHandlePool::Borrow(stream);
absl::Status CooFromDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0],
d.value_type, CUSPARSE_ORDER_ROW));
JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[2],
/*cooColInd=*/buffers[3],
/*cooValues=*/buffers[1], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
JAX_THROW_IF_ERROR(cusparseDenseToSparse_analysis(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[2],
/*cooColInd=*/buffers[3],
/*cooValues=*/buffers[1], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
JAX_THROW_IF_ERROR(cusparseDenseToSparse_convert(
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4]));
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a));
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b));
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
// Custom-call entry point for "coo_fromdense": runs the status-returning
// implementation and, on failure, reports the error message to XLA through
// the custom-call status object instead of throwing.
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = CooFromDense_(stream, buffers, opaque, opaque_len);
  if (result.ok()) return;
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// CooMatvec: Product of COO matrix and dense vector.
@ -675,7 +774,9 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
// NOTE(review): diff residue — this span appears to interleave the pre- and
// post-change versions of BuildCooMatvecDescriptor (`handle` is bound twice,
// and each cuSPARSE call appears both with and without JAX_AS_STATUS). The
// embedded "@ -692,30 +793,35 @@" hunk header marks context lines missing
// from this view (presumably the `y` descriptor, the mat_a/vec_x/vec_y
// declarations, and the transpose-dependent `op` — TODO confirm against
// upstream). Not safe to rewrite from this view alone.
    const py::dtype& data_dtype, const py::dtype& x_dtype,
    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
    int cols, int nnz, bool transpose) {
  auto handle = SparseHandlePool::Borrow();
  // Post-change form: Borrow() failures are surfaced before dereferencing.
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor A =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
  DenseVecDescriptor x =
  @ -692,30 +793,35 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
                                       empty, empty, A.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
                        A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
  size_t buffer_size;
  CudaConst alpha = CudaOne(y.type);
  CudaConst beta = CudaZero(y.type);
  JAX_THROW_IF_ERROR(cusparseSpMV_bufferSize(
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
      handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
      CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
      CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
  return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
}
// NOTE(review): diff residue — the removed throwing `void CooMatvec` and the
// added status-returning `absl::Status CooMatvec_` appear interleaved below,
// and the embedded "@ -735,19 +841,31 @@" hunk header marks missing context
// (presumably the remaining buffer-pointer bindings, alpha/beta constants,
// and the mat_a declaration — TODO confirm against upstream). Not safe to
// rewrite from this view alone.
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len) {
  const CooMatvecDescriptor& d =
      *UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
// Post-change form: descriptor unpacking and handle borrowing are checked.
absl::Status CooMatvec_(cudaStream_t stream, void** buffers, const char* opaque,
                        size_t opaque_len) {
  auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const CooMatvecDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;
  void* coo_values = buffers[0];
  void* coo_row_ind = buffers[1];
  @ -735,19 +841,31 @@ void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
  cusparseDnVecDescr_t vec_x = 0;
  cusparseDnVecDescr_t vec_y = 0;
  JAX_THROW_IF_ERROR(cusparseCreateCoo(
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
      &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
  JAX_THROW_IF_ERROR(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
                                  &beta, vec_y, d.y.type,
                                  CUSPARSE_MV_ALG_DEFAULT, buf));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
                   d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
  return absl::OkStatus();
}
// Custom-call entry point for "coo_matvec": runs the status-returning
// implementation and, on failure, reports the error message to XLA through
// the custom-call status object instead of throwing.
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = CooMatvec_(stream, buffers, opaque, opaque_len);
  if (result.ok()) return;
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// CooMatmat: Product of COO matrix and dense matrix.
@ -763,7 +881,9 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
// NOTE(review): diff residue — this span appears to interleave the pre- and
// post-change versions of BuildCooMatmatDescriptor (`handle` is bound twice,
// and each cuSPARSE call appears both with and without JAX_AS_STATUS). The
// embedded "@ -780,32 +900,37 @@" hunk header marks context lines missing
// from this view (presumably the B/C descriptor initializers, the
// mat_a/mat_b/mat_c declarations, and the transpose-dependent `op_A` —
// TODO confirm against upstream). Not safe to rewrite from this view alone.
    const py::dtype& data_dtype, const py::dtype& b_dtype,
    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
    int cols, int BCcols, int nnz, bool transpose) {
  auto handle = SparseHandlePool::Borrow();
  // Post-change form: Borrow() failures are surfaced before dereferencing.
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor A =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
  DenseMatDescriptor B =
  @ -780,32 +900,37 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
                                       empty, empty, A.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
                                         empty, B.type, CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                         empty, C.type, CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
                        A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
                                        empty, B.type, CUSPARSE_ORDER_ROW)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                        empty, C.type, CUSPARSE_ORDER_ROW)));
  size_t buffer_size;
  CudaConst alpha = CudaOne(C.type);
  CudaConst beta = CudaZero(C.type);
  JAX_THROW_IF_ERROR(cusparseSpMM_bufferSize(
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
      handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
  return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
}
// NOTE(review): diff residue — the removed throwing `void CooMatmat` and the
// added status-returning `absl::Status CooMatmat_` appear interleaved below,
// and the embedded "@ -825,22 +950,32 @@" hunk header marks missing context
// (presumably the remaining buffer-pointer bindings, alpha/beta constants,
// and the mat_a declaration — TODO confirm against upstream). Not safe to
// rewrite from this view alone.
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len) {
  const CooMatmatDescriptor& d =
      *UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
// Post-change form: descriptor unpacking and handle borrowing are checked.
absl::Status CooMatmat_(cudaStream_t stream, void** buffers, const char* opaque,
                        size_t opaque_len) {
  auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const CooMatmatDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;
  void* coo_values = buffers[0];
  void* coo_row_ind = buffers[1];
  @ -825,22 +950,32 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
  cusparseDnMatDescr_t mat_b = 0;
  cusparseDnMatDescr_t mat_c = 0;
  JAX_THROW_IF_ERROR(cusparseCreateCoo(
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
      &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
                                         /*ld=*/d.B.cols, Bbuf, d.B.type,
                                         CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
                                         /*ld=*/d.C.cols, Cbuf, d.C.type,
                                         CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseSpMM(
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_b, d.B.rows, d.B.cols,
      /*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_c, d.C.rows, d.C.cols,
      /*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
      handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
      mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));
      mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
  return absl::OkStatus();
}
// Custom-call entry point for "coo_matmat": runs the status-returning
// implementation and, on failure, reports the error message to XLA through
// the custom-call status object instead of throwing.
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
  absl::Status result = CooMatmat_(stream, buffers, opaque, opaque_len);
  if (result.ok()) return;
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
#endif // if JAX_CUSPARSE_11030
@ -853,12 +988,15 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
}
// NOTE(review): diff residue — the removed throwing `void gtsv2` and the
// added status-returning `absl::Status gtsv2` appear interleaved below
// (duplicate signatures, duplicate descriptor unpacking), and the embedded
// "@ -878,31 +1016,42 @@" hunk header marks missing context (presumably the
// dl/d/du/B/X/buffer pointer bindings taken from `buffers` — TODO confirm
// against upstream). Not safe to rewrite from this view alone.
template <typename T, typename F>
void gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
           const char* opaque, std::size_t opaque_len) {
  auto handle = SparseHandlePool::Borrow();
// Post-change form: Borrow() and descriptor unpacking are checked.
absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
                   const char* opaque, std::size_t opaque_len) {
  auto h = SparseHandlePool::Borrow();
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;
  const Gtsv2Descriptor& descriptor =
      *UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
  auto s = UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const Gtsv2Descriptor& descriptor = **s;
  int m = descriptor.m;
  int n = descriptor.n;
  int ldb = descriptor.ldb;
  @ -878,31 +1016,42 @@ void gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
  // TODO(b/182906199): Update the comment here once copy insertion is WAI.
  if (X != B) {
    size_t B_bytes = ldb * n * sizeof(T);
    JAX_THROW_IF_ERROR(
        cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream));
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
        cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)));
  }
  JAX_THROW_IF_ERROR(
      computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)));
  return absl::OkStatus();
}
void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len) {
gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
}
}
void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len) {
gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
s.error_message().length());
}
}
// Queries the workspace size (in bytes) that cuSPARSE requires for a gtsv2
// solve of the given dimensions, by calling the *_bufferSizeExt-style
// function `f` with null data pointers.
// Throws (JAX_THROW_IF_ERROR) rather than returning a Status because this
// runs on the host at build/trace time, not inside an XLA custom call.
// Fix: this span contained interleaved pre-/post-change lines from a diff
// (duplicate template line, old handle acquisition, old error checks);
// restored to the coherent StatusOr-based form.
template <typename F>
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  size_t size;
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
                      /*du=*/nullptr, /*B=*/nullptr, ldb, &size)));
  return size;
}

View File

@ -59,6 +59,8 @@ def csr_todense(c, data, indices, indptr, *, shape):
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
@ -86,6 +88,8 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
@ -122,6 +126,8 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
@ -158,6 +164,8 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
_Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
@ -187,6 +195,8 @@ def coo_todense(c, data, row, col, *, shape):
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
@ -214,6 +224,8 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
@ -249,6 +261,8 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
@ -285,6 +299,8 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
_Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
@ -306,5 +322,7 @@ def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
(_Shape.array_shape(np.dtype(t), (ldb, n), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=cusparse_kernels.build_gtsv2_descriptor(m, n, ldb),
has_side_effect=False)
has_side_effect=False,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0)

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
namespace jax {
@ -77,7 +78,7 @@ class HandlePool {
// Borrows a handle from the pool. If 'stream' is non-null, sets the stream
// associated with the handle.
static Handle Borrow(StreamType stream = nullptr);
static absl::StatusOr<Handle> Borrow(StreamType stream = nullptr);
private:
static HandlePool<HandleType, StreamType>* Instance();

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "absl/base/casts.h"
#include "absl/status/statusor.h"
namespace jax {
@ -36,9 +37,10 @@ std::string PackDescriptorAsString(const T& descriptor) {
// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
absl::StatusOr<const T*> UnpackDescriptor(const char* opaque,
std::size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid size for linalg operation descriptor.");
return absl::InternalError("Invalid size for operation descriptor.");
}
return absl::bit_cast<const T*>(opaque);
}

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/rocm_gpu_kernel_helpers.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/rocblas.h"
@ -37,35 +38,36 @@ limitations under the License.
namespace jax {
// Converts a rocBLAS status code into an absl::Status: success maps to
// OkStatus, every other code to an InternalError carrying rocBLAS's own
// string description of the failure.
absl::Status AsStatus(rocblas_status status) {
  if (status == rocblas_status_success) {
    return absl::OkStatus();
  }
  return absl::InternalError(rocblas_status_to_string(status));
}
namespace {
namespace py = pybind11;
// Throws a std::runtime_error (with rocBLAS's own description) for any
// non-success rocBLAS status; returns silently on success.
void ThrowIfErrorStatus(rocblas_status status) {
  if (status != rocblas_status_success) {
    throw std::runtime_error(rocblas_status_to_string(status));
  }
}
using rocBlasHandlePool = HandlePool<rocblas_handle, hipStream_t>;
template <>
/*static*/ rocBlasHandlePool::Handle rocBlasHandlePool::Borrow(
/*static*/ absl::StatusOr<rocBlasHandlePool::Handle> rocBlasHandlePool::Borrow(
hipStream_t stream) {
rocBlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
rocblas_handle handle;
if (pool->handles_[stream].empty()) {
ThrowIfErrorStatus(rocblas_create_handle(&handle));
JAX_RETURN_IF_ERROR(AsStatus(rocblas_create_handle(&handle)))
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
ThrowIfErrorStatus(rocblas_set_stream(handle, stream));
JAX_RETURN_IF_ERROR(AsStatus(rocblas_set_stream(handle, stream)))
}
return rocBlasHandlePool::Handle(pool, handle, stream);
}
@ -148,18 +150,21 @@ std::pair<size_t, py::bytes> BuildTrsmDescriptor(const py::dtype& dtype,
return {lwork, PackDescriptor(desc)};
}
void Trsm(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const TrsmDescriptor& d =
*UnpackDescriptor<TrsmDescriptor>(opaque, opaque_len);
auto handle = rocBlasHandlePool::Borrow(stream);
absl::Status Trsm_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<TrsmDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const TrsmDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// b is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[2] != buffers[1]) {
ThrowIfError(hipMemcpyAsync(buffers[2], buffers[1],
SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[2], buffers[1], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
const int lda = d.side == rocblas_side_left ? d.m : d.n;
const int ldb = d.m;
@ -170,18 +175,18 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
float* a = static_cast<float*>(buffers[0]);
float* b = static_cast<float*>(buffers[2]);
const float alpha = 1.0f;
ThrowIfErrorStatus(rocblas_strsm(handle.get(), d.side, d.uplo, d.trans,
d.diag, d.m, d.n, &alpha,
const_cast<float*>(a), lda, b, ldb));
JAX_RETURN_IF_ERROR(AsStatus(
rocblas_strsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m,
d.n, &alpha, const_cast<float*>(a), lda, b, ldb)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[0]);
double* b = static_cast<double*>(buffers[2]);
const double alpha = 1.0;
ThrowIfErrorStatus(rocblas_dtrsm(handle.get(), d.side, d.uplo, d.trans,
d.diag, d.m, d.n, &alpha,
const_cast<double*>(a), lda, b, ldb));
JAX_RETURN_IF_ERROR(AsStatus(
rocblas_dtrsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m,
d.n, &alpha, const_cast<double*>(a), lda, b, ldb)))
break;
}
case Type::C64: {
@ -190,9 +195,9 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
rocblas_float_complex* b =
static_cast<rocblas_float_complex*>(buffers[2]);
const rocblas_float_complex alpha = {1.0f, 0.0f};
ThrowIfErrorStatus(rocblas_ctrsm(
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_float_complex*>(a), lda, b, ldb));
const_cast<rocblas_float_complex*>(a), lda, b, ldb)))
break;
}
case Type::C128: {
@ -200,10 +205,10 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
static_cast<rocblas_double_complex*>(buffers[0]);
rocblas_double_complex* b =
static_cast<rocblas_double_complex*>(buffers[2]);
const rocblas_double_complex alpha = {1.0d, 0.0d};
ThrowIfErrorStatus(rocblas_ztrsm(
const rocblas_double_complex alpha = {1.0f, 0.0f};
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_double_complex*>(a), lda, b, ldb));
const_cast<rocblas_double_complex*>(a), lda, b, ldb)))
break;
}
}
@ -211,33 +216,35 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
auto a_batch_host =
MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
SizeOfType(d.type) * lda * lda);
JAX_RETURN_IF_ERROR(a_batch_host.status());
auto b_batch_host =
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(b_batch_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
ThrowIfError(hipStreamSynchronize(stream));
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
const float alpha = 1.0f;
ThrowIfErrorStatus(rocblas_strsm_batched(
JAX_RETURN_IF_ERROR(AsStatus(rocblas_strsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch));
d.batch)))
break;
}
case Type::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
const double alpha = 1.0;
ThrowIfErrorStatus(rocblas_dtrsm_batched(
JAX_RETURN_IF_ERROR(AsStatus(rocblas_dtrsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch));
d.batch)))
break;
}
case Type::C64: {
@ -246,10 +253,10 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
rocblas_float_complex** b_batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[4]);
const rocblas_float_complex alpha = {1.0f, 0.0f};
ThrowIfErrorStatus(rocblas_ctrsm_batched(
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_float_complex**>(a_batch_ptrs), lda,
b_batch_ptrs, ldb, d.batch));
b_batch_ptrs, ldb, d.batch)))
break;
}
case Type::C128: {
@ -257,15 +264,25 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
static_cast<rocblas_double_complex**>(buffers[3]);
rocblas_double_complex** b_batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[4]);
const rocblas_double_complex alpha = {1.0d, 0.0d};
ThrowIfErrorStatus(rocblas_ztrsm_batched(
const rocblas_double_complex alpha = {1.0f, 0.0f};
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_double_complex**>(a_batch_ptrs), lda,
b_batch_ptrs, ldb, d.batch));
b_batch_ptrs, ldb, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
// Custom-call entry point for triangular solve: delegates to the
// Status-returning implementation and, on failure, reports the error
// message to XLA through the custom-call status channel.
void Trsm(hipStream_t stream, void** buffers, const char* opaque,
          size_t opaque_len, XlaCustomCallStatus* status) {
  const auto result = Trsm_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
//##########################
@ -290,17 +307,20 @@ std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
return {lwork, PackDescriptor(PotrfDescriptor{type, uplo, b, n})};
}
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const PotrfDescriptor& d =
*UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
auto handle = rocBlasHandlePool::Borrow(stream);
absl::Status Potrf_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const PotrfDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)))
}
int* info = static_cast<int*>(buffers[2]);
@ -308,28 +328,28 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque,
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_spotrf(handle.get(), d.uplo, d.n, a, d.n, info));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_spotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_dpotrf(handle.get(), d.uplo, d.n, a, d.n, info));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_dpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_cpotrf(handle.get(), d.uplo, d.n, a, d.n, info));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_cpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_zpotrf(handle.get(), d.uplo, d.n, a, d.n, info));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_zpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
}
@ -337,40 +357,51 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque,
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
ThrowIfError(hipStreamSynchronize(stream));
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
ThrowIfErrorStatus(rocsolver_spotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_spotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
case Type::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
ThrowIfErrorStatus(rocsolver_dpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
case Type::C64: {
rocblas_float_complex** a_batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[3]);
ThrowIfErrorStatus(rocsolver_cpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
case Type::C128: {
rocblas_double_complex** a_batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[3]);
ThrowIfErrorStatus(rocsolver_zpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
// Custom-call entry point for Cholesky factorization: delegates to the
// Status-returning implementation and, on failure, reports the error
// message to XLA through the custom-call status channel.
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status) {
  const auto result = Potrf_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// getrf: LU decomposition
@ -389,18 +420,21 @@ std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
}
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const GetrfDescriptor& d =
*UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
auto handle = rocBlasHandlePool::Borrow(stream);
absl::Status Getrf_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GetrfDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
int* ipiv = static_cast<int*>(buffers[2]);
@ -410,28 +444,28 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque,
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_sgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_dgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_cgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[1]);
ThrowIfErrorStatus(
rocsolver_zgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
}
@ -439,44 +473,55 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque,
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
ThrowIfError(hipStreamSynchronize(stream));
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** batch_ptrs = static_cast<float**>(buffers[4]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch));
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
case Type::F64: {
double** batch_ptrs = static_cast<double**>(buffers[4]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch));
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
case Type::C64: {
rocblas_float_complex** batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[4]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch));
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
case Type::C128: {
rocblas_double_complex** batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[4]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch));
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
// Custom-call entry point for LU factorization: delegates to the
// Status-returning implementation and, on failure, reports the error
// message to XLA through the custom-call status channel.
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status) {
  const auto result = Getrf_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// geqrf: QR decomposition
@ -494,18 +539,21 @@ std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n})};
}
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const GeqrfDescriptor& d =
*UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
auto handle = rocBlasHandlePool::Borrow(stream);
absl::Status Geqrf_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GeqrfDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
// here tau is tau
@ -515,15 +563,15 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
float* tau = static_cast<float*>(buffers[2]);
ThrowIfErrorStatus(
rocsolver_sgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_sgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
double* tau = static_cast<double*>(buffers[2]);
ThrowIfErrorStatus(
rocsolver_dgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_dgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
case Type::C64: {
@ -531,8 +579,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
static_cast<rocblas_float_complex*>(buffers[1]);
rocblas_float_complex* tau =
static_cast<rocblas_float_complex*>(buffers[2]);
ThrowIfErrorStatus(
rocsolver_cgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_cgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
case Type::C128: {
@ -540,8 +588,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
static_cast<rocblas_double_complex*>(buffers[1]);
rocblas_double_complex* tau =
static_cast<rocblas_double_complex*>(buffers[2]);
ThrowIfErrorStatus(
rocsolver_zgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_zgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
}
@ -549,26 +597,27 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
ThrowIfError(hipStreamSynchronize(stream));
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** batch_ptrs = static_cast<float**>(buffers[3]);
float* tau = static_cast<float*>(buffers[2]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch));
tau, std::min(d.m, d.n), d.batch)))
break;
}
case Type::F64: {
double** batch_ptrs = static_cast<double**>(buffers[3]);
double* tau = static_cast<double*>(buffers[2]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch));
tau, std::min(d.m, d.n), d.batch)))
break;
}
case Type::C64: {
@ -576,9 +625,9 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
static_cast<rocblas_float_complex**>(buffers[3]);
rocblas_float_complex* tau =
static_cast<rocblas_float_complex*>(buffers[2]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch));
tau, std::min(d.m, d.n), d.batch)))
break;
}
case Type::C128: {
@ -586,13 +635,23 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
static_cast<rocblas_double_complex**>(buffers[3]);
rocblas_double_complex* tau =
static_cast<rocblas_double_complex*>(buffers[2]);
ThrowIfErrorStatus(
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch));
tau, std::min(d.m, d.n), d.batch)))
break;
}
}
}
return absl::OkStatus();
}
// Custom-call entry point for QR factorization: delegates to the
// Status-returning implementation and, on failure, reports the error
// message to XLA through the custom-call status channel.
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status) {
  const auto result = Geqrf_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// orgqr/ungqr: apply elementary Householder transformations
@ -608,18 +667,21 @@ std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k})};
}
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const OrgqrDescriptor& d =
*UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
auto handle = rocBlasHandlePool::Borrow(stream);
absl::Status Orgqr_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const OrgqrDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[2] != buffers[0]) {
ThrowIfError(hipMemcpyAsync(buffers[2], buffers[0],
SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[2], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
switch (d.type) {
@ -629,8 +691,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
float* a = static_cast<float*>(buffers[2]);
float* tau = static_cast<float*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(
rocsolver_sorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
@ -640,8 +702,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
double* a = static_cast<double*>(buffers[2]);
double* tau = static_cast<double*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(
rocsolver_dorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
@ -656,8 +718,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
rocblas_float_complex* tau =
static_cast<rocblas_float_complex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(
rocsolver_cungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
@ -669,14 +731,24 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
rocblas_double_complex* tau =
static_cast<rocblas_double_complex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(
rocsolver_zungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
break;
}
}
return absl::OkStatus();
}
// Custom-call entry point for orgqr/ungqr (materializing Q from Householder
// reflectors): delegates to the Status-returning implementation and, on
// failure, reports the error message to XLA through the custom-call status
// channel.
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status) {
  const auto result = Orgqr_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  XlaCustomCallStatusSetFailure(status, result.error_message().c_str(),
                                result.error_message().length());
}
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
@ -715,18 +787,21 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
return {lwork, PackDescriptor(GesvdDescriptor{type, b, m, n, jobu, jobvt})};
}
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const GesvdDescriptor& d =
*UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
auto handle = rocBlasHandlePool::Borrow(stream);
absl::Status Gesvd_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GesvdDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream));
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
int* info = static_cast<int*>(buffers[5]);
@ -743,9 +818,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
float* u = static_cast<float*>(buffers[3]);
float* vt = static_cast<float*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_sgesvd(handle.get(), d.jobu, d.jobvt, d.m,
d.n, a, lda, s, u, ldu, vt, ldv, e,
rocblas_inplace, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
case Type::F64: {
@ -754,9 +829,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
double* u = static_cast<double*>(buffers[3]);
double* vt = static_cast<double*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_dgesvd(handle.get(), d.jobu, d.jobvt, d.m,
d.n, a, lda, s, u, ldu, vt, ldv, e,
rocblas_inplace, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
case Type::C64: {
@ -768,9 +843,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
rocblas_float_complex* vt =
static_cast<rocblas_float_complex*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_cgesvd(handle.get(), d.jobu, d.jobvt, d.m,
d.n, a, lda, s, u, ldu, vt, ldv, e,
rocblas_inplace, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
case Type::C128: {
@ -782,9 +857,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
rocblas_double_complex* vt =
static_cast<rocblas_double_complex*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_zgesvd(handle.get(), d.jobu, d.jobvt, d.m,
d.n, a, lda, s, u, ldu, vt, ldv, e,
rocblas_inplace, info));
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
}
@ -797,10 +872,11 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[7], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
ThrowIfError(hipStreamSynchronize(stream));
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
@ -809,10 +885,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
float* u = static_cast<float*>(buffers[3]);
float* vt = static_cast<float*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_sgesvd_batched(
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_sgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch));
rocblas_inplace, info, d.batch)))
break;
}
case Type::F64: {
@ -821,10 +897,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
double* u = static_cast<double*>(buffers[3]);
double* vt = static_cast<double*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_dgesvd_batched(
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch));
rocblas_inplace, info, d.batch)))
break;
}
case Type::C64: {
@ -836,10 +912,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
rocblas_float_complex* vt =
static_cast<rocblas_float_complex*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_cgesvd_batched(
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch));
rocblas_inplace, info, d.batch)))
break;
}
case Type::C128: {
@ -851,14 +927,24 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
rocblas_double_complex* vt =
static_cast<rocblas_double_complex*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
ThrowIfErrorStatus(rocsolver_zgesvd_batched(
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch));
rocblas_inplace, info, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
// XLA custom-call entry point for gesvd. Runs the Status-returning
// implementation; a non-OK status is forwarded to XLA via the custom-call
// status facility rather than thrown.
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status) {
  const auto result = Gesvd_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  const auto msg = result.error_message();
  XlaCustomCallStatusSetFailure(status, msg.c_str(), msg.length());
}
// Singular value decomposition using Jacobi algorithm: gesvdj

View File

@ -22,24 +22,26 @@ limitations under the License.
namespace jax {
void ThrowIfError(hipError_t error) {
absl::Status AsStatus(hipError_t error) {
if (error != hipSuccess) {
throw std::runtime_error(
return absl::InternalError(
absl::StrCat("ROCm operation failed: ", hipGetErrorString(error)));
}
return absl::OkStatus();
}
std::unique_ptr<void* []> MakeBatchPointers(hipStream_t stream, void* buffer,
void* dev_ptrs, int batch,
int batch_elem_size) {
absl::StatusOr<std::unique_ptr<void* []>> MakeBatchPointers(
hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
char* ptr = static_cast<char*>(buffer);
auto host_ptrs = absl::make_unique<void*[]>(batch);
for (int i = 0; i < batch; ++i) {
host_ptrs[i] = ptr;
ptr += batch_elem_size;
}
ThrowIfError(hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
hipMemcpyHostToDevice, stream));
JAX_RETURN_IF_ERROR(
AsStatus(hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
hipMemcpyHostToDevice, stream)));
return host_ptrs;
}
} // namespace jax

View File

@ -19,18 +19,28 @@ limitations under the License.
#include <memory>
#include "rocm/include/hip/hip_runtime_api.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
// Evaluates `expr` (an absl::Status) and early-returns it from the enclosing
// function if it is not OK.
// NOTE: deliberately a bare brace block rather than the conventional
// do { ... } while (0): call sites in this codebase invoke the macro both
// with and without a trailing semicolon, and do/while would make the
// semicolon mandatory.
#define JAX_RETURN_IF_ERROR(expr) \
  {                               \
    auto s___ = (expr);           \
    if (!s___.ok()) return s___;  \
  }
namespace jax {
// Converts a HIP runtime error code into an absl::Status (OkStatus on
// hipSuccess, InternalError with the HIP error string otherwise).
absl::Status AsStatus(hipError_t error);

// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
// completes.
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(hipStream_t stream,
                                                           void* buffer,
                                                           void* dev_ptrs,
                                                           int batch,
                                                           int batch_elem_size);
} // namespace jax

View File

@ -96,7 +96,9 @@ def trsm(c,
_Shape.array_shape(dtype, a_shape.dimensions(), layout), # buffers[0] (a)
_Shape.array_shape(dtype, b_shape.dimensions(), layout), # buffers[1] (b, IN)
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0)
@ -133,7 +135,9 @@ def potrf(c, a, lower):
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[0] (a, IN)
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)
@ -171,7 +175,9 @@ def getrf(c, a):
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[0] (a, IN)
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
@ -208,7 +214,9 @@ def geqrf(c, a):
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[0] (a, IN)
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
# rocsolver geqrf does not return info
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), None)
@ -247,7 +255,9 @@ def orgqr(c, a, tau):
_Shape.array_shape(dtype, batch_dims + (k,),
tuple(range(num_bd, -1, -1))), # buffers[1] (tau IN)
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), None) # ROCSolver orgqr does not return info
@ -303,7 +313,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
_Shape.array_shape(dtype, batch_dims + (m, n),
matrix_layout), # buffers[0] (a, IN)
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
vt = _ops.GetTupleElement(out, 2)
u = _ops.GetTupleElement(out, 3)
@ -338,7 +350,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
_Shape.array_shape(dtype, batch_dims + (m, n),
matrix_layout), # buffers[0] (a, IN)
),
opaque=opaque)
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
u = _ops.GetTupleElement(out, 2)
vt = _ops.GetTupleElement(out, 3)