Use the new "custom call status" facility to report errors in jaxlib
PiperOrigin-RevId: 389734200
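
In outline, each GPU custom call is split in two: an inner function that returns absl::Status, converting every CUDA/cuBLAS/cuSPARSE return code through the JAX_AS_STATUS and JAX_RETURN_IF_ERROR macros added in cuda_gpu_kernel_helpers.h, and a thin wrapper with the status-returning custom-call signature that reports failures to XLA instead of throwing across the C ABI. A condensed sketch of the pattern, using a hypothetical DoWork kernel (everything else mirrors the diff below):

#include "jaxlib/cuda_gpu_kernel_helpers.h"  // JAX_AS_STATUS, JAX_RETURN_IF_ERROR
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"

namespace jax {

// Inner body: returns a status instead of throwing. Each library call is
// wrapped so that the file, line, and expression are captured in the message.
static absl::Status DoWork_(cudaStream_t stream, void** buffers,
                            const char* opaque, size_t opaque_len) {
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
  return absl::OkStatus();
}

// Outer wrapper: the signature XLA invokes for status-returning custom calls
// (API_VERSION_STATUS_RETURNING). Failures are reported through
// XlaCustomCallStatusSetFailure rather than a C++ exception.
void DoWork(cudaStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = DoWork_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}

}  // namespace jax
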
This commit is contained in:
parent f04464d210
commit c368969955

jaxlib/BUILD | 15
@@ -54,6 +54,7 @@ cc_library(
     features = ["-use_header_modules"],
     deps = [
         "@com_google_absl//absl/base",
+        "@com_google_absl//absl/status:statusor",
     ],
 )
@@ -67,6 +68,7 @@ cc_library(
     features = ["-use_header_modules"],
     deps = [
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/synchronization",
     ],
 )
@@ -82,8 +84,9 @@ cc_library(
     deps = [
         "@org_tensorflow//tensorflow/stream_executor/cuda:cusolver_lib",
         "@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@local_config_cuda//cuda:cublas_headers",
@@ -102,6 +105,8 @@ cc_library(
     deps = [
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@local_config_rocm//rocm:rocm_headers",
     ],
@@ -159,6 +164,7 @@ pybind_extension(
         ":cuda_gpu_kernel_helpers",
         ":handle_pool",
         ":kernel_pybind11_helpers",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@org_tensorflow//tensorflow/stream_executor/cuda:cublas_lib",
         "@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
         "@com_google_absl//absl/algorithm:container",
@@ -189,6 +195,7 @@ pybind_extension(
         ":cuda_gpu_kernel_helpers",
         ":handle_pool",
         ":kernel_pybind11_helpers",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
         "@org_tensorflow//tensorflow/stream_executor/cuda:cusolver_lib",
         "@com_google_absl//absl/algorithm:container",
@@ -218,6 +225,7 @@ pybind_extension(
         ":cuda_gpu_kernel_helpers",
         ":handle_pool",
         ":kernel_pybind11_helpers",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@org_tensorflow//tensorflow/stream_executor/cuda:cudart_stub",
         "@org_tensorflow//tensorflow/stream_executor/cuda:cusparse_lib",
         "@com_google_absl//absl/algorithm:container",
@@ -241,6 +249,8 @@ cuda_library(
     deps = [
         ":cuda_gpu_kernel_helpers",
         ":kernel_helpers",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
+        "@com_google_absl//absl/status",
         "@local_config_cuda//cuda:cuda_headers",
     ],
 )
@@ -270,6 +280,8 @@ cuda_library(
     deps = [
         ":cuda_gpu_kernel_helpers",
         ":kernel_helpers",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
+        "@com_google_absl//absl/status",
         "@local_config_cuda//cuda:cuda_headers",
     ],
 )
@@ -306,6 +318,7 @@ pybind_extension(
         ":handle_pool",
         ":kernel_pybind11_helpers",
         ":rocm_gpu_kernel_helpers",
+        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
105
jaxlib/cublas.cc
105
jaxlib/cublas.cc
@@ -32,6 +32,7 @@ limitations under the License.
 #include "third_party/gpus/cuda/include/cublas_v2.h"
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
 
 namespace jax {
 namespace {
@@ -41,18 +42,19 @@ namespace py = pybind11;
 using BlasHandlePool = HandlePool<cublasHandle_t, cudaStream_t>;
 
 template <>
-/*static*/ BlasHandlePool::Handle BlasHandlePool::Borrow(cudaStream_t stream) {
+/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
+    cudaStream_t stream) {
   BlasHandlePool* pool = Instance();
   absl::MutexLock lock(&pool->mu_);
   cublasHandle_t handle;
   if (pool->handles_[stream].empty()) {
-    JAX_THROW_IF_ERROR(cublasCreate(&handle));
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCreate(&handle)));
   } else {
     handle = pool->handles_[stream].back();
     pool->handles_[stream].pop_back();
   }
   if (stream) {
-    JAX_THROW_IF_ERROR(cublasSetStream(handle, stream));
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSetStream(handle, stream)));
   }
   return Handle(pool, handle, stream);
 }
@@ -122,26 +124,31 @@ std::pair<size_t, py::bytes> BuildTrsmBatchedDescriptor(
   return {size, PackDescriptor(desc)};
 }
 
-void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
-                 size_t opaque_len) {
-  const TrsmBatchedDescriptor& d =
-      *UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
-  auto handle = BlasHandlePool::Borrow(stream);
+absl::Status TrsmBatched_(cudaStream_t stream, void** buffers,
+                          const char* opaque, size_t opaque_len) {
+  auto s = UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
+  JAX_RETURN_IF_ERROR(s.status());
+  const TrsmBatchedDescriptor& d = **s;
+  auto h = BlasHandlePool::Borrow(stream);
+  JAX_RETURN_IF_ERROR(h.status());
+  auto& handle = *h;
   if (buffers[2] != buffers[1]) {
-    JAX_THROW_IF_ERROR(cudaMemcpyAsync(buffers[2], buffers[1],
-                                       SizeOfType(d.type) * d.batch * d.m * d.n,
-                                       cudaMemcpyDeviceToDevice, stream));
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[2], buffers[1], SizeOfType(d.type) * d.batch * d.m * d.n,
+        cudaMemcpyDeviceToDevice, stream)));
   }
   const int lda = d.side == CUBLAS_SIDE_LEFT ? d.m : d.n;
   const int ldb = d.m;
   auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
                                         SizeOfType(d.type) * lda * lda);
+  JAX_RETURN_IF_ERROR(a_batch_host.status());
   auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
                                         SizeOfType(d.type) * d.m * d.n);
+  JAX_RETURN_IF_ERROR(b_batch_host.status());
   // TODO(phawkins): ideally we would not need to synchronize here, but to
   // avoid it we need a way to keep the host-side buffer alive until the copy
   // completes.
-  JAX_THROW_IF_ERROR(cudaStreamSynchronize(stream));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
   switch (d.type) {
     case Type::F32: {
       float* a = static_cast<float*>(buffers[0]);
@@ -150,10 +157,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
       float** b_batch_ptrs = static_cast<float**>(buffers[4]);
       // NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault.
       const float alpha = 1.0f;
-      JAX_THROW_IF_ERROR(cublasStrsmBatched(
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasStrsmBatched(
           handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
           const_cast<const float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
-          d.batch));
+          d.batch)));
       break;
     }
     case Type::F64: {
@@ -162,10 +169,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
       double** a_batch_ptrs = static_cast<double**>(buffers[3]);
       double** b_batch_ptrs = static_cast<double**>(buffers[4]);
       const double alpha = 1.0;
-      JAX_THROW_IF_ERROR(cublasDtrsmBatched(
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDtrsmBatched(
           handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
           const_cast<const double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
-          d.batch));
+          d.batch)));
       break;
     }
     case Type::C64: {
@@ -174,10 +181,10 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
       cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
       cuComplex** b_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
       const cuComplex alpha = make_cuComplex(1.0f, 0.0f);
-      JAX_THROW_IF_ERROR(cublasCtrsmBatched(
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCtrsmBatched(
          handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
           const_cast<const cuComplex**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
-          d.batch));
+          d.batch)));
       break;
     }
     case Type::C128: {
@@ -188,13 +195,23 @@ void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
       cuDoubleComplex** b_batch_ptrs =
           static_cast<cuDoubleComplex**>(buffers[4]);
       const cuDoubleComplex alpha = make_cuDoubleComplex(1.0f, 0.0f);
-      JAX_THROW_IF_ERROR(cublasZtrsmBatched(
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZtrsmBatched(
           handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
           const_cast<const cuDoubleComplex**>(a_batch_ptrs), lda, b_batch_ptrs,
-          ldb, d.batch));
+          ldb, d.batch)));
       break;
     }
   }
+  return absl::OkStatus();
+}
+
+void TrsmBatched(cudaStream_t stream, void** buffers, const char* opaque,
+                 size_t opaque_len, XlaCustomCallStatus* status) {
+  auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
+  if (!s.ok()) {
+    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
+                                  s.error_message().length());
+  }
+}
 
 // Batched LU decomposition: getrfbatched
@@ -212,55 +229,69 @@ std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
   return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
 }
 
-void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
-                  size_t opaque_len) {
-  const GetrfBatchedDescriptor& d =
-      *UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
-  auto handle = BlasHandlePool::Borrow(stream);
+absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
+                           const char* opaque, size_t opaque_len) {
+  auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
+  JAX_RETURN_IF_ERROR(s.status());
+  const GetrfBatchedDescriptor& d = **s;
+  auto h = BlasHandlePool::Borrow(stream);
+  JAX_RETURN_IF_ERROR(h.status());
+  auto& handle = *h;
   if (buffers[0] != buffers[1]) {
-    JAX_THROW_IF_ERROR(cudaMemcpyAsync(buffers[1], buffers[0],
-                                       SizeOfType(d.type) * d.batch * d.n * d.n,
-                                       cudaMemcpyDeviceToDevice, stream));
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
+        buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n,
+        cudaMemcpyDeviceToDevice, stream)));
  }
 
   int* ipiv = static_cast<int*>(buffers[2]);
   int* info = static_cast<int*>(buffers[3]);
   auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
                                        SizeOfType(d.type) * d.n * d.n);
+  JAX_RETURN_IF_ERROR(a_ptrs_host.status());
   // TODO(phawkins): ideally we would not need to synchronize here, but to
   // avoid it we need a way to keep the host-side buffer alive until the copy
   // completes.
-  JAX_THROW_IF_ERROR(cudaStreamSynchronize(stream));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
   switch (d.type) {
     case Type::F32: {
       float* a = static_cast<float*>(buffers[1]);
       float** batch_ptrs = static_cast<float**>(buffers[4]);
-      JAX_THROW_IF_ERROR(cublasSgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
-                                             ipiv, info, d.batch));
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSgetrfBatched(
+          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
       break;
     }
     case Type::F64: {
       double* a = static_cast<double*>(buffers[1]);
       double** batch_ptrs = static_cast<double**>(buffers[4]);
-      JAX_THROW_IF_ERROR(cublasDgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
-                                             ipiv, info, d.batch));
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDgetrfBatched(
          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
       break;
     }
     case Type::C64: {
       cuComplex* a = static_cast<cuComplex*>(buffers[1]);
       cuComplex** batch_ptrs = static_cast<cuComplex**>(buffers[4]);
-      JAX_THROW_IF_ERROR(cublasCgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
-                                             ipiv, info, d.batch));
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCgetrfBatched(
          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
       break;
     }
     case Type::C128: {
       cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
       cuDoubleComplex** batch_ptrs = static_cast<cuDoubleComplex**>(buffers[4]);
-      JAX_THROW_IF_ERROR(cublasZgetrfBatched(handle.get(), d.n, batch_ptrs, d.n,
-                                             ipiv, info, d.batch));
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZgetrfBatched(
          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
       break;
     }
   }
+  return absl::OkStatus();
+}
+
+void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
+                  size_t opaque_len, XlaCustomCallStatus* status) {
+  auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
+  if (!s.ok()) {
+    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
+                                  s.error_message().length());
+  }
+}
 
 py::dict Registrations() {

jaxlib/cuda_gpu_kernel_helpers.cc
@@ -23,15 +23,13 @@ limitations under the License.
 
 namespace jax {
 namespace {
-std::string ErrorToString(cudaError_t error) {
-  return cudaGetErrorString(error);
-}
+std::string ErrorString(cudaError_t error) { return cudaGetErrorString(error); }
 
-std::string ErrorToString(cusparseStatus_t status) {
+std::string ErrorString(cusparseStatus_t status) {
   return cusparseGetErrorString(status);
 }
 
-std::string ErrorToString(cusolverStatus_t status) {
+std::string ErrorString(cusolverStatus_t status) {
   switch (status) {
     case CUSOLVER_STATUS_SUCCESS:
       return "cuSolver success.";
@@ -62,7 +60,7 @@ std::string ErrorToString(cusolverStatus_t status) {
   }
 }
 
-std::string ErrorToString(cublasStatus_t status) {
+std::string ErrorString(cublasStatus_t status) {
   switch (status) {
     case CUBLAS_STATUS_SUCCESS:
       return "cuBlas success";
@@ -90,47 +88,53 @@ std::string ErrorToString(cublasStatus_t status) {
 }
 
 template <typename T>
-void ThrowError(T status, const char* file, std::int64_t line,
-                const char* expr) {
-  throw std::runtime_error(absl::StrFormat("%s:%d: operation %s failed: %s",
-                                           file, line, expr,
-                                           ErrorToString(status)));
+std::string ErrorString(T status, const char* file, std::int64_t line,
+                        const char* expr) {
+  return absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr,
+                         ErrorString(status));
 }
 }  // namespace
 
-void ThrowIfError(cudaError_t error, const char* file, std::int64_t line,
-                  const char* expr) {
-  if (error != cudaSuccess) ThrowError(error, file, line, expr);
+absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
+                      const char* expr) {
+  if (error != cudaSuccess)
+    return absl::InternalError(ErrorString(error, file, line, expr));
+  return absl::OkStatus();
 }
 
-void ThrowIfError(cusolverStatus_t status, const char* file, std::int64_t line,
-                  const char* expr) {
-  if (status != CUSOLVER_STATUS_SUCCESS) ThrowError(status, file, line, expr);
+absl::Status AsStatus(cusolverStatus_t status, const char* file,
+                      std::int64_t line, const char* expr) {
+  if (status != CUSOLVER_STATUS_SUCCESS)
+    return absl::InternalError(ErrorString(status, file, line, expr));
+  return absl::OkStatus();
 }
 
-void ThrowIfError(cusparseStatus_t status, const char* file, std::int64_t line,
-                  const char* expr) {
-  if (status != CUSPARSE_STATUS_SUCCESS) ThrowError(status, file, line, expr);
+absl::Status AsStatus(cusparseStatus_t status, const char* file,
+                      std::int64_t line, const char* expr) {
+  if (status != CUSPARSE_STATUS_SUCCESS)
+    return absl::InternalError(ErrorString(status, file, line, expr));
+  return absl::OkStatus();
 }
 
-void ThrowIfError(cublasStatus_t status, const char* file, std::int64_t line,
-                  const char* expr) {
-  if (status != CUBLAS_STATUS_SUCCESS) ThrowError(status, file, line, expr);
+absl::Status AsStatus(cublasStatus_t status, const char* file,
+                      std::int64_t line, const char* expr) {
+  if (status != CUBLAS_STATUS_SUCCESS)
+    return absl::InternalError(ErrorString(status, file, line, expr));
+  return absl::OkStatus();
 }
 
-std::unique_ptr<void* []> MakeBatchPointers(cudaStream_t stream, void* buffer,
-                                            void* dev_ptrs, int batch,
-                                            int batch_elem_size) {
+absl::StatusOr<std::unique_ptr<void* []>> MakeBatchPointers(
+    cudaStream_t stream, void* buffer, void* dev_ptrs, int batch,
+    int batch_elem_size) {
   char* ptr = static_cast<char*>(buffer);
   auto host_ptrs = absl::make_unique<void*[]>(batch);
   for (int i = 0; i < batch; ++i) {
     host_ptrs[i] = ptr;
     ptr += batch_elem_size;
   }
-  JAX_THROW_IF_ERROR(cudaMemcpyAsync(dev_ptrs, host_ptrs.get(),
-                                     sizeof(void*) * batch,
-                                     cudaMemcpyHostToDevice, stream));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudaMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
                      cudaMemcpyHostToDevice, stream)));
   return host_ptrs;
 }
 }  // namespace jax

jaxlib/cuda_gpu_kernel_helpers.h
@@ -18,33 +18,48 @@ limitations under the License.
 
 #include <memory>
 
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "third_party/gpus/cuda/include/cublas_v2.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "third_party/gpus/cuda/include/cusolverDn.h"
 #include "third_party/gpus/cuda/include/cusparse.h"
 
-#define JAX_THROW_IF_ERROR(expr) \
-  jax::ThrowIfError(expr, __FILE__, __LINE__, #expr)
+#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
+
+#define JAX_THROW_IF_ERROR(expr)                                     \
+  {                                                                  \
+    auto s___ = (expr);                                              \
+    if (!s___.ok()) throw std::runtime_error(s___.error_message());  \
+  }
+
+#define JAX_RETURN_IF_ERROR(expr)  \
+  {                                \
+    auto s___ = (expr);            \
+    if (!s___.ok()) return s___;   \
+  }
 
 namespace jax {
 
-// Used via JAX_THROW_IF_ERROR(expr) macro.
-void ThrowIfError(cudaError_t error, const char* file, std::int64_t line,
-                  const char* expr);
-void ThrowIfError(cusolverStatus_t status, const char* file, std::int64_t line,
-                  const char* expr);
-void ThrowIfError(cusparseStatus_t status, const char* file, std::int64_t line,
-                  const char* expr);
-void ThrowIfError(cublasStatus_t status, const char* file, std::int64_t line,
-                  const char* expr);
+// Used via JAX_AS_STATUS(expr) macro.
+absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
+                      const char* expr);
+absl::Status AsStatus(cusolverStatus_t status, const char* file,
+                      std::int64_t line, const char* expr);
+absl::Status AsStatus(cusparseStatus_t status, const char* file,
+                      std::int64_t line, const char* expr);
+absl::Status AsStatus(cublasStatus_t status, const char* file,
+                      std::int64_t line, const char* expr);
 
 // Builds an array of pointers to each array in a batch, in device memory.
 // Caution: the return value must be kept alive (e.g., via a stream
 // synchronization) until the copy enqueued by MakeBatchPointers on `stream`
 // completes.
-std::unique_ptr<void*[]> MakeBatchPointers(cudaStream_t stream, void* buffer,
-                                           void* dev_ptrs, int batch,
-                                           int batch_elem_size);
+absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(cudaStream_t stream,
+                                                           void* buffer,
+                                                           void* dev_ptrs,
+                                                           int batch,
+                                                           int batch_elem_size);
 
 }  // namespace jax
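
The two macros split error handling by call context; a minimal sketch of the intended usage, assuming a CUDA call whose return code has been converted by JAX_AS_STATUS (both forms mirror call sites elsewhere in this change):

  // Host-side, pybind11-exposed descriptor builders may still surface
  // failures as C++ exceptions, which pybind11 turns into Python errors:
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));

  // Kernel bodies that return absl::Status must not throw across the C ABI,
  // so they propagate the status to the XlaCustomCallStatus wrapper instead:
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
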
jaxlib/cuda_linalg.py
@@ -58,4 +58,6 @@ def lu_pivots_to_permutation(c, pivots, *, permutation_size):
       operands=(pivots,),
       shape_with_layout=permutations_shape_with_layout,
       operand_shapes_with_layout=(pivots_shape_with_layout,),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
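
For context: requesting API_VERSION_STATUS_RETURNING changes the custom-call calling convention, so XLA passes one extra trailing argument to the registered target. The C++ side must match, which is why every kernel signature in this change gains that parameter; a sketch of the expected shape (target name hypothetical):

  void MyKernel(cudaStream_t stream, void** buffers, const char* opaque,
                std::size_t opaque_len, XlaCustomCallStatus* status);
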
jaxlib/cuda_lu_pivot_kernels.cc
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "jaxlib/cuda_gpu_kernel_helpers.h"
 #include "jaxlib/kernel_helpers.h"
+#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
 
 namespace jax {
 namespace {
@@ -73,13 +74,16 @@ std::string BuildCudaLuPivotsToPermutationDescriptor(
       batch_size, pivot_size, permutation_size});
 }
 
-void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
-                               const char* opaque, std::size_t opaque_len) {
+absl::Status CudaLuPivotsToPermutation_(cudaStream_t stream, void** buffers,
+                                        const char* opaque,
+                                        std::size_t opaque_len) {
   const std::int32_t* pivots =
       reinterpret_cast<const std::int32_t*>(buffers[0]);
   std::int32_t* permutation_out = reinterpret_cast<std::int32_t*>(buffers[1]);
-  const auto& descriptor =
-      *UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
+  auto s =
+      UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
+  JAX_RETURN_IF_ERROR(s.status());
+  const auto& descriptor = **s;
 
   const int block_dim = 128;
   const std::int64_t grid_dim = std::min<std::int64_t>(
@@ -89,7 +93,18 @@ void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
       /*dynamic_shared_mem_bytes=*/0, stream>>>(
       pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size,
       descriptor.permutation_size);
-  JAX_THROW_IF_ERROR(cudaGetLastError());
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
+  return absl::OkStatus();
+}
+
+void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
+                               const char* opaque, size_t opaque_len,
+                               XlaCustomCallStatus* status) {
+  auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
+  if (!s.ok()) {
+    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
+                                  s.error_message().length());
+  }
+}
 
 }  // namespace jax

jaxlib/cuda_lu_pivot_kernels.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include <string>
 
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
 
 namespace jax {
 
@@ -28,7 +29,8 @@ std::string BuildCudaLuPivotsToPermutationDescriptor(
     std::int32_t permutation_size);
 
 void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
-                               const char* opaque, std::size_t opaque_len);
+                               const char* opaque, std::size_t opaque_len,
+                               XlaCustomCallStatus* status);
 
 }  // namespace jax

jaxlib/cuda_prng.py
@@ -56,4 +56,6 @@ def threefry2x32(c, keys, data):
       operands=(keys[0], keys[1], data[0], data[1]),
       shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),
       operand_shapes_with_layout=(shape,) * 4,
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)

jaxlib/cuda_prng_kernels.cc
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "jaxlib/cuda_prng_kernels.h"
+
 #include <array>
 #include <cstddef>
 
-#include "jaxlib/cuda_prng_kernels.h"
 #include "jaxlib/cuda_gpu_kernel_helpers.h"
 #include "jaxlib/kernel_helpers.h"
+#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
 
 namespace jax {
 namespace {
@@ -106,8 +108,8 @@ std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
   return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
 }
 
-void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
-                      std::size_t opaque_len) {
+absl::Status CudaThreeFry2x32_(cudaStream_t stream, void** buffers,
+                               const char* opaque, std::size_t opaque_len) {
   std::array<const std::uint32_t*, 2> keys;
   keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
   keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
@@ -117,15 +119,26 @@ void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
   std::array<std::uint32_t*, 2> out;
   out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
   out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
-  const auto& descriptor =
-      *UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
+  auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
+  JAX_RETURN_IF_ERROR(s.status());
+  const auto& descriptor = **s;
   const int block_dim = 128;
   const std::int64_t grid_dim =
       std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
   ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
                        stream>>>(keys[0], keys[1], data[0], data[1], out[0],
                                  out[1], descriptor.n);
-  JAX_THROW_IF_ERROR(cudaGetLastError());
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
+  return absl::OkStatus();
+}
+
+void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
+                      size_t opaque_len, XlaCustomCallStatus* status) {
+  auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len);
+  if (!s.ok()) {
+    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
+                                  s.error_message().length());
+  }
+}
 
 }  // namespace jax

jaxlib/cuda_prng_kernels.h
@@ -20,13 +20,14 @@ limitations under the License.
 #include <string>
 
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
 
 namespace jax {
 
 std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n);
 
 void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
-                      std::size_t opaque_len);
+                      std::size_t opaque_len, XlaCustomCallStatus* status);
 
 }  // namespace jax

jaxlib/cusolver.cc: file diff suppressed because it is too large.

jaxlib/cusolver.py
@@ -97,7 +97,9 @@ def trsm(c, a, b, left_side=False, lower=False, trans_a=False, conj_a=False,
           _Shape.array_shape(dtype, a_shape.dimensions(), layout),
           _Shape.array_shape(dtype, b_shape.dimensions(), layout),
       ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return _ops.GetTupleElement(out, 0)
 
 
@@ -131,7 +133,9 @@ def potrf(c, a, lower):
       operand_shapes_with_layout=(_Shape.array_shape(
           dtype, batch_dims + (n, n),
          (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)
 
 
@@ -175,7 +179,9 @@ def getrf(c, a):
       operand_shapes_with_layout=(_Shape.array_shape(
          dtype, batch_dims + (m, n),
          (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
           _ops.GetTupleElement(out, 2))
 
@@ -213,7 +219,9 @@ def geqrf(c, a):
       operand_shapes_with_layout=(_Shape.array_shape(
          dtype, batch_dims + (m, n),
          (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
           _ops.GetTupleElement(out, 2))
 
@@ -257,7 +265,9 @@ def orgqr(c, a, tau):
           dtype, batch_dims + (k,),
           tuple(range(num_bd, -1, -1))),
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1))
 
 
@@ -302,7 +312,9 @@ def syevd(c, a, lower=False):
       operand_shapes_with_layout=(
          _Shape.array_shape(dtype, dims, layout),
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
           _ops.GetTupleElement(out, 2))
 
@@ -342,7 +354,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
       operand_shapes_with_layout=(
          _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   s = _ops.GetTupleElement(out, 1)
   u = _ops.GetTupleElement(out, 2)
   v = _ops.GetTupleElement(out, 3)
@@ -371,7 +385,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
       operand_shapes_with_layout=(
          _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   s = _ops.GetTupleElement(out, 1)
   vt = _ops.GetTupleElement(out, 2)
   u = _ops.GetTupleElement(out, 3)
@@ -398,7 +414,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
       operand_shapes_with_layout=(
          _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   s = _ops.GetTupleElement(out, 1)
   u = _ops.GetTupleElement(out, 2)
   vt = _ops.GetTupleElement(out, 3)
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
|
||||
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
|
||||
@ -148,19 +149,19 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
using SparseHandlePool = HandlePool<cusparseHandle_t, cudaStream_t>;
|
||||
|
||||
template <>
|
||||
/*static*/ SparseHandlePool::Handle SparseHandlePool::Borrow(
|
||||
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
|
||||
cudaStream_t stream) {
|
||||
SparseHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
cusparseHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_THROW_IF_ERROR(cusparseCreate(&handle));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_THROW_IF_ERROR(cusparseSetStream(handle, stream));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
@ -246,7 +247,9 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
|
||||
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
|
||||
@ -258,47 +261,60 @@ std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, empty,
|
||||
empty, empty, d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(cusparseSparseToDense_bufferSize(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size));
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const SparseMatDescriptor& d =
|
||||
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
|
||||
d.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3],
|
||||
d.value_type, CUSPARSE_ORDER_ROW));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrFromDense: Convert dense matrix to CSR matrix
|
||||
@ -307,7 +323,9 @@ void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
|
||||
@ -317,48 +335,61 @@ std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, empty,
|
||||
empty, empty, d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(cusparseDenseToSparse_bufferSize(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
&buffer_size));
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const SparseMatDescriptor& d =
|
||||
*UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
absl::Status CsrFromDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0],
|
||||
d.value_type, CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO,
|
||||
d.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseDenseToSparse_analysis(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
JAX_THROW_IF_ERROR(cusparseDenseToSparse_convert(
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4]));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b));
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrMatvec: Product of CSR matrix and dense vector.
|
||||
@ -374,7 +405,9 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz, bool transpose) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
DenseVecDescriptor x =
|
||||
@ -391,30 +424,35 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
|
||||
empty, empty, A.index_type, A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(y.type);
|
||||
CudaConst beta = CudaZero(y.type);
|
||||
JAX_THROW_IF_ERROR(cusparseSpMV_bufferSize(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
|
||||
CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
|
||||
}
|
||||
|
||||
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const CsrMatvecDescriptor& d =
|
||||
*UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
absl::Status CsrMatvec_(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const CsrMatvecDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
void* csr_values = buffers[0];
|
||||
void* csr_col_ind = buffers[1];
|
||||
@ -434,20 +472,32 @@ void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
|
||||
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
|
||||
&beta, vec_y, d.y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, buf));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrMatmat: Product of CSR matrix and dense matrix.
|
||||
@ -463,7 +513,9 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int BCcols, int nnz, bool transpose) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
DenseMatDescriptor B =
|
||||
@ -480,32 +532,37 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(cusparseCreateCsr(&mat_a, A.rows, A.cols, A.nnz, empty,
|
||||
empty, empty, A.index_type, A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, A.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, CUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(C.type);
|
||||
CudaConst beta = CudaZero(C.type);
|
||||
JAX_THROW_IF_ERROR(cusparseSpMM_bufferSize(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
|
||||
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
|
||||
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const CsrMatmatDescriptor& d =
|
||||
*UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
|
||||
auto handle = SparseHandlePool::Borrow(stream);
|
||||
absl::Status CsrMatmat_(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const CsrMatmatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
void* csr_values = buffers[0];
|
||||
void* csr_col_ind = buffers[1];
|
||||
@ -525,23 +582,33 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
|
||||
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(cusparseSpMM(
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CooToDense: Convert COO matrix to dense matrix
|
||||
@ -550,7 +617,9 @@ void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
|
||||
|
||||
@ -561,46 +630,60 @@ std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty,
|
||||
empty, empty, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type));
|
||||
JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type,
|
||||
CUSPARSE_ORDER_ROW));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(cusparseSparseToDense_bufferSize(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size));
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
|
||||
JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len) {
  const SparseMatDescriptor& d =
      *UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
absl::Status CooToDense_(cudaStream_t stream, void** buffers,
                         const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const SparseMatDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  cusparseSpMatDescr_t mat_a = 0;
  cusparseDnMatDescr_t mat_b = 0;
  JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
                                       /*cooRowInd=*/buffers[1],
                                       /*cooColInd=*/buffers[2],
                                       /*cooValues=*/buffers[0], d.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, d.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.rows, d.cols,
                                         /*ld=*/d.cols, buffers[3],
                                         d.value_type, CUSPARSE_ORDER_ROW));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
                                      /*cooRowInd=*/buffers[1],
                                      /*cooColInd=*/buffers[2],
                                      /*cooValues=*/buffers[0], d.index_type,
                                      CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_b, d.rows, d.cols,
      /*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));

  JAX_THROW_IF_ERROR(cusparseSparseToDense(handle.get(), mat_a, mat_b,
                                           CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
                                           buffers[4]));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cusparseSparseToDense(handle.get(), mat_a, mat_b,
                            CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));

  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  return absl::OkStatus();
}

void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = CooToDense_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}

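Every kernel in this change follows the same two-piece shape as CooToDense above: a status-returning implementation plus a thin exported wrapper that reports failures through XlaCustomCallStatusSetFailure. The adapter is repeated verbatim at each call site; purely as a design note, the repetition could be factored with a small helper — a hypothetical sketch, not part of this commit:

    // Hypothetical helper (not in this change): adapt a status-returning
    // kernel body to the status-reporting custom-call convention.
    template <typename Fn>
    void RunAsCustomCall(Fn&& body, XlaCustomCallStatus* status) {
      absl::Status s = std::forward<Fn>(body)();
      if (!s.ok()) {
        XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                      s.error_message().length());
      }
    }

    // Usage sketch:
    void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
                    size_t opaque_len, XlaCustomCallStatus* status) {
      RunAsCustomCall(
          [&] { return CooToDense_(stream, buffers, opaque, opaque_len); },
          status);
    }
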
// CooFromDense: Convert dense matrix to COO matrix

@ -609,7 +692,9 @@ void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
    int cols, int nnz) {
  auto handle = SparseHandlePool::Borrow();
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor d =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);

@ -619,47 +704,61 @@ std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
                                         /*ld=*/d.cols, empty, d.value_type,
                                         CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty,
                                       empty, empty, d.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, d.value_type));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_a, d.rows, d.cols,
      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
                        d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
  size_t buffer_size;
  JAX_THROW_IF_ERROR(cusparseDenseToSparse_bufferSize(
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
      &buffer_size));
      &buffer_size)));

  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));

  return {buffer_size, PackDescriptor(d)};
}

void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len) {
  const SparseMatDescriptor& d =
      *UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
absl::Status CooFromDense_(cudaStream_t stream, void** buffers,
                           const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const SparseMatDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  cusparseDnMatDescr_t mat_a = 0;
  cusparseSpMatDescr_t mat_b = 0;
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_a, d.rows, d.cols,
                                         /*ld=*/d.cols, buffers[0],
                                         d.value_type, CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
                                       /*cooRowInd=*/buffers[2],
                                       /*cooColInd=*/buffers[3],
                                       /*cooValues=*/buffers[1], d.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, d.value_type));
  JAX_THROW_IF_ERROR(cusparseDenseToSparse_analysis(
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_a, d.rows, d.cols,
      /*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
                                      /*cooRowInd=*/buffers[2],
                                      /*cooColInd=*/buffers[3],
                                      /*cooValues=*/buffers[1], d.index_type,
                                      CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
      buffers[4]));
  JAX_THROW_IF_ERROR(cusparseDenseToSparse_convert(
      buffers[4])));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
      buffers[4]));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_b));
      buffers[4])));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
  return absl::OkStatus();
}

void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
                  size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}

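For orientation, the buffer convention these conversion kernels assume, as read off the calls above (stated here as a reading aid, not new API): operands first, outputs next, and the cuSPARSE workspace sized by the matching Build*Descriptor last.

    // CooToDense:   buffers[0]=COO values, [1]=row indices, [2]=col indices,
    //               [3]=dense output, [4]=workspace (buffer_size bytes).
    // CooFromDense: buffers[0]=dense input, [1]=COO values out,
    //               [2]=row indices out, [3]=col indices out, [4]=workspace.
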
// CooMatvec: Product of COO matrix and dense vector.

@ -675,7 +774,9 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
    const py::dtype& data_dtype, const py::dtype& x_dtype,
    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
    int cols, int nnz, bool transpose) {
  auto handle = SparseHandlePool::Borrow();
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor A =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
  DenseVecDescriptor x =

@ -692,30 +793,35 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
                                       empty, empty, A.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, x.size, empty, x.type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, y.size, empty, y.type));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
                        A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
  size_t buffer_size;
  CudaConst alpha = CudaOne(y.type);
  CudaConst beta = CudaZero(y.type);
  JAX_THROW_IF_ERROR(cusparseSpMV_bufferSize(
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
      handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
      CUSPARSE_MV_ALG_DEFAULT, &buffer_size));
      CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));

  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));

  return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
}

void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len) {
  const CooMatvecDescriptor& d =
      *UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
absl::Status CooMatvec_(cudaStream_t stream, void** buffers, const char* opaque,
                        size_t opaque_len) {
  auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const CooMatvecDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  void* coo_values = buffers[0];
  void* coo_row_ind = buffers[1];

@ -735,19 +841,31 @@ void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
  cusparseDnVecDescr_t vec_x = 0;
  cusparseDnVecDescr_t vec_y = 0;

  JAX_THROW_IF_ERROR(cusparseCreateCoo(
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
      &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type));
  JAX_THROW_IF_ERROR(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type));
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));

  JAX_THROW_IF_ERROR(cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x,
                                  &beta, vec_y, d.y.type,
                                  CUSPARSE_MV_ALG_DEFAULT, buf));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
                   d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));

  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_x));
  JAX_THROW_IF_ERROR(cusparseDestroyDnVec(vec_y));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
  return absl::OkStatus();
}

void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}

// CooMatmat: Product of COO matrix and dense matrix.

@ -763,7 +881,9 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
    const py::dtype& data_dtype, const py::dtype& b_dtype,
    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
    int cols, int BCcols, int nnz, bool transpose) {
  auto handle = SparseHandlePool::Borrow();
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  SparseMatDescriptor A =
      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
  DenseMatDescriptor B =

@ -780,32 +900,37 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
  // bufferSize does not reference these pointers, but does error on NULL.
  int val = 0;
  void* empty = &val;
  JAX_THROW_IF_ERROR(cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty,
                                       empty, empty, A.index_type,
                                       CUSPARSE_INDEX_BASE_ZERO, A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
                                         empty, B.type, CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                         empty, C.type, CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
      cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
                        A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
                                        empty, B.type, CUSPARSE_ORDER_ROW)));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
                                        empty, C.type, CUSPARSE_ORDER_ROW)));
  size_t buffer_size;
  CudaConst alpha = CudaOne(C.type);
  CudaConst beta = CudaZero(C.type);
  JAX_THROW_IF_ERROR(cusparseSpMM_bufferSize(
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
      handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size));
      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));

  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));

  return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
}

void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len) {
  const CooMatmatDescriptor& d =
      *UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
  auto handle = SparseHandlePool::Borrow(stream);
absl::Status CooMatmat_(cudaStream_t stream, void** buffers, const char* opaque,
                        size_t opaque_len) {
  auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const CooMatmatDescriptor& d = **s;
  auto h = SparseHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  void* coo_values = buffers[0];
  void* coo_row_ind = buffers[1];

@ -825,22 +950,32 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
  cusparseDnMatDescr_t mat_b = 0;
  cusparseDnMatDescr_t mat_c = 0;

  JAX_THROW_IF_ERROR(cusparseCreateCoo(
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
      &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_b, d.B.rows, d.B.cols,
                                         /*ld=*/d.B.cols, Bbuf, d.B.type,
                                         CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseCreateDnMat(&mat_c, d.C.rows, d.C.cols,
                                         /*ld=*/d.C.cols, Cbuf, d.C.type,
                                         CUSPARSE_ORDER_ROW));
  JAX_THROW_IF_ERROR(cusparseSpMM(
      d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_b, d.B.rows, d.B.cols,
      /*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
      &mat_c, d.C.rows, d.C.cols,
      /*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
      handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
      mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf));
      mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));

  JAX_THROW_IF_ERROR(cusparseDestroySpMat(mat_a));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_b));
  JAX_THROW_IF_ERROR(cusparseDestroyDnMat(mat_c));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
  return absl::OkStatus();
}

void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}
#endif  // if JAX_CUSPARSE_11030

@ -853,12 +988,15 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
}

template <typename T, typename F>
void gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
           const char* opaque, std::size_t opaque_len) {
  auto handle = SparseHandlePool::Borrow();
absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
                   const char* opaque, std::size_t opaque_len) {
  auto h = SparseHandlePool::Borrow();
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  const Gtsv2Descriptor& descriptor =
      *UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
  auto s = UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const Gtsv2Descriptor& descriptor = **s;
  int m = descriptor.m;
  int n = descriptor.n;
  int ldb = descriptor.ldb;

@ -878,31 +1016,42 @@ void gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
  // TODO(b/182906199): Update the comment here once copy insertion is WAI.
  if (X != B) {
    size_t B_bytes = ldb * n * sizeof(T);
    JAX_THROW_IF_ERROR(
        cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream));
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
        cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)));
  }

  JAX_THROW_IF_ERROR(
      computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)));
  return absl::OkStatus();
}

void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len) {
  gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
               std::size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}

void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len) {
  gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
               std::size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}

template<typename F>
template <typename F>
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
  auto handle = SparseHandlePool::Borrow();
  auto h = SparseHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  size_t size;
  JAX_THROW_IF_ERROR(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
                       /*du=*/nullptr, /*B=*/nullptr, ldb, &size));
  JAX_THROW_IF_ERROR(
      JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
                      /*du=*/nullptr, /*B=*/nullptr, ldb, &size)));
  return size;
}

@ -59,6 +59,8 @@ def csr_todense(c, data, indices, indptr, *, shape):
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )
  return _ops.GetTupleElement(out, 0)

@ -86,6 +88,8 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )

  return tuple(_ops.GetTupleElement(out, i) for i in range(3))

@ -122,6 +126,8 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
          _Shape.array_shape(compute_dtype, (out_size,), (0,)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )
  return _ops.GetTupleElement(out, 0)

@ -158,6 +164,8 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
          _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )
  return _ops.GetTupleElement(out, 0)

@ -187,6 +195,8 @@ def coo_todense(c, data, row, col, *, shape):
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )
  return _ops.GetTupleElement(out, 0)

@ -214,6 +224,8 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
          _Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
      )),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )

  return tuple(_ops.GetTupleElement(out, i) for i in range(3))

@ -249,6 +261,8 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
          _Shape.array_shape(compute_dtype, (out_size,), (0,)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )
  return _ops.GetTupleElement(out, 0)

@ -285,6 +299,8 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
          _Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
          _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=opaque,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING,
  )
  return _ops.GetTupleElement(out, 0)

@ -306,5 +322,7 @@ def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
      (_Shape.array_shape(np.dtype(t), (ldb, n), (1, 0)),
       _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
      opaque=cusparse_kernels.build_gtsv2_descriptor(m, n, ldb),
      has_side_effect=False)
      has_side_effect=False,
      api_version=xla_client.ops.CustomCallApiVersion
      .API_VERSION_STATUS_RETURNING)
  return _ops.GetTupleElement(out, 0)
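All of these Python builders now pass api_version=API_VERSION_STATUS_RETURNING, which tells XLA to invoke the custom call with an extra trailing argument for reporting failures. The matching C++ signature, as used by the kernels in this change (MyGpuKernel is a placeholder name):

    #include "tensorflow/compiler/xla/service/custom_call_status.h"

    // Status-returning custom-call target: the old signature plus the
    // trailing XlaCustomCallStatus*. Leaving *status untouched means success.
    void MyGpuKernel(cudaStream_t stream, void** buffers, const char* opaque,
                     size_t opaque_len, XlaCustomCallStatus* status);
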
@ -20,6 +20,7 @@ limitations under the License.
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"

namespace jax {

@ -77,7 +78,7 @@ class HandlePool {

  // Borrows a handle from the pool. If 'stream' is non-null, sets the stream
  // associated with the handle.
  static Handle Borrow(StreamType stream = nullptr);
  static absl::StatusOr<Handle> Borrow(StreamType stream = nullptr);

 private:
  static HandlePool<HandleType, StreamType>* Instance();
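With Borrow returning absl::StatusOr<Handle>, callers unwrap it explicitly instead of relying on exceptions. The idiom used at every call site in this change:

    auto h = SparseHandlePool::Borrow(stream);  // absl::StatusOr<Handle>
    JAX_RETURN_IF_ERROR(h.status());            // propagate creation failures
    auto& handle = *h;                          // then use handle.get()
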
@ -21,6 +21,7 @@ limitations under the License.
#include <string>

#include "absl/base/casts.h"
#include "absl/status/statusor.h"

namespace jax {

@ -36,9 +37,10 @@ std::string PackDescriptorAsString(const T& descriptor) {

// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
absl::StatusOr<const T*> UnpackDescriptor(const char* opaque,
                                          std::size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid size for linalg operation descriptor.");
    return absl::InternalError("Invalid size for operation descriptor.");
  }
  return absl::bit_cast<const T*>(opaque);
}
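UnpackDescriptor likewise now reports a wrong-sized opaque string as a Status rather than throwing. A round-trip sketch, with FooDescriptor standing in for any trivially copyable descriptor struct:

    struct FooDescriptor { int rows, cols; };  // placeholder example type

    std::string opaque = PackDescriptorAsString(FooDescriptor{2, 3});
    auto d = UnpackDescriptor<FooDescriptor>(opaque.data(), opaque.size());
    JAX_RETURN_IF_ERROR(d.status());  // wrong size -> absl::InternalError
    const FooDescriptor& desc = **d;  // StatusOr<const T*>: dereference twice
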
@ -30,6 +30,7 @@ limitations under the License.
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/rocm_gpu_kernel_helpers.h"
#include "third_party/tensorflow/compiler/xla/service/custom_call_status.h"
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/rocblas.h"

@ -37,35 +38,36 @@ limitations under the License.

namespace jax {

absl::Status AsStatus(rocblas_status status) {
  switch (status) {
    case rocblas_status_success:
      return absl::OkStatus();
    default:
      return absl::InternalError(rocblas_status_to_string(status));
  }
}
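The hipMemcpyAsync and hipStreamSynchronize calls below are also wrapped in AsStatus, so a companion overload for hipError_t must exist alongside this rocblas_status one (presumably in rocm_gpu_kernel_helpers). A plausible sketch; the in-tree definition may differ:

    absl::Status AsStatus(hipError_t error) {
      if (error == hipSuccess) return absl::OkStatus();
      return absl::InternalError(hipGetErrorString(error));
    }
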
namespace {

namespace py = pybind11;

void ThrowIfErrorStatus(rocblas_status status) {
  switch (status) {
    case rocblas_status_success:
      return;
    default:
      throw std::runtime_error(rocblas_status_to_string(status));
  }
}

using rocBlasHandlePool = HandlePool<rocblas_handle, hipStream_t>;

template <>
/*static*/ rocBlasHandlePool::Handle rocBlasHandlePool::Borrow(
/*static*/ absl::StatusOr<rocBlasHandlePool::Handle> rocBlasHandlePool::Borrow(
    hipStream_t stream) {
  rocBlasHandlePool* pool = Instance();
  absl::MutexLock lock(&pool->mu_);
  rocblas_handle handle;
  if (pool->handles_[stream].empty()) {
    ThrowIfErrorStatus(rocblas_create_handle(&handle));
    JAX_RETURN_IF_ERROR(AsStatus(rocblas_create_handle(&handle)))
  } else {
    handle = pool->handles_[stream].back();
    pool->handles_[stream].pop_back();
  }
  if (stream) {
    ThrowIfErrorStatus(rocblas_set_stream(handle, stream));
    JAX_RETURN_IF_ERROR(AsStatus(rocblas_set_stream(handle, stream)))
  }
  return rocBlasHandlePool::Handle(pool, handle, stream);
}

@ -148,18 +150,21 @@ std::pair<size_t, py::bytes> BuildTrsmDescriptor(const py::dtype& dtype,
  return {lwork, PackDescriptor(desc)};
}

void Trsm(hipStream_t stream, void** buffers, const char* opaque,
          size_t opaque_len) {
  const TrsmDescriptor& d =
      *UnpackDescriptor<TrsmDescriptor>(opaque, opaque_len);
  auto handle = rocBlasHandlePool::Borrow(stream);
absl::Status Trsm_(hipStream_t stream, void** buffers, const char* opaque,
                   size_t opaque_len) {
  auto s = UnpackDescriptor<TrsmDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const TrsmDescriptor& d = **s;
  auto h = rocBlasHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  // b is INOUT, so we copy the input to the output and use that if they are not
  // already the same
  if (buffers[2] != buffers[1]) {
    ThrowIfError(hipMemcpyAsync(buffers[2], buffers[1],
                                SizeOfType(d.type) * d.batch * d.m * d.n,
                                hipMemcpyDeviceToDevice, stream));
    JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
        buffers[2], buffers[1], SizeOfType(d.type) * d.batch * d.m * d.n,
        hipMemcpyDeviceToDevice, stream)))
  }
  const int lda = d.side == rocblas_side_left ? d.m : d.n;
  const int ldb = d.m;

@ -170,18 +175,18 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
      float* a = static_cast<float*>(buffers[0]);
      float* b = static_cast<float*>(buffers[2]);
      const float alpha = 1.0f;
      ThrowIfErrorStatus(rocblas_strsm(handle.get(), d.side, d.uplo, d.trans,
                                       d.diag, d.m, d.n, &alpha,
                                       const_cast<float*>(a), lda, b, ldb));
      JAX_RETURN_IF_ERROR(AsStatus(
          rocblas_strsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m,
                        d.n, &alpha, const_cast<float*>(a), lda, b, ldb)))
      break;
    }
    case Type::F64: {
      double* a = static_cast<double*>(buffers[0]);
      double* b = static_cast<double*>(buffers[2]);
      const double alpha = 1.0;
      ThrowIfErrorStatus(rocblas_dtrsm(handle.get(), d.side, d.uplo, d.trans,
                                       d.diag, d.m, d.n, &alpha,
                                       const_cast<double*>(a), lda, b, ldb));
      JAX_RETURN_IF_ERROR(AsStatus(
          rocblas_dtrsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m,
                        d.n, &alpha, const_cast<double*>(a), lda, b, ldb)))
      break;
    }
    case Type::C64: {

@ -190,9 +195,9 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
      rocblas_float_complex* b =
          static_cast<rocblas_float_complex*>(buffers[2]);
      const rocblas_float_complex alpha = {1.0f, 0.0f};
      ThrowIfErrorStatus(rocblas_ctrsm(
      JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm(
          handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
          const_cast<rocblas_float_complex*>(a), lda, b, ldb));
          const_cast<rocblas_float_complex*>(a), lda, b, ldb)))
      break;
    }
    case Type::C128: {

@ -200,10 +205,10 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
          static_cast<rocblas_double_complex*>(buffers[0]);
      rocblas_double_complex* b =
          static_cast<rocblas_double_complex*>(buffers[2]);
      const rocblas_double_complex alpha = {1.0d, 0.0d};
      ThrowIfErrorStatus(rocblas_ztrsm(
      const rocblas_double_complex alpha = {1.0f, 0.0f};
      JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm(
          handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
          const_cast<rocblas_double_complex*>(a), lda, b, ldb));
          const_cast<rocblas_double_complex*>(a), lda, b, ldb)))
      break;
    }
  }

@ -211,33 +216,35 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
    auto a_batch_host =
        MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
                          SizeOfType(d.type) * lda * lda);
    JAX_RETURN_IF_ERROR(a_batch_host.status());
    auto b_batch_host =
        MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
                          SizeOfType(d.type) * d.m * d.n);
    JAX_RETURN_IF_ERROR(b_batch_host.status());
    // TODO(phawkins): ideally we would not need to synchronize here, but to
    // avoid it we need a way to keep the host-side buffer alive until the copy
    // completes.
    ThrowIfError(hipStreamSynchronize(stream));
    JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))

    switch (d.type) {
      case Type::F32: {
        float** a_batch_ptrs = static_cast<float**>(buffers[3]);
        float** b_batch_ptrs = static_cast<float**>(buffers[4]);
        const float alpha = 1.0f;
        ThrowIfErrorStatus(rocblas_strsm_batched(
        JAX_RETURN_IF_ERROR(AsStatus(rocblas_strsm_batched(
            handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
            const_cast<float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
            d.batch));
            d.batch)))
        break;
      }
      case Type::F64: {
        double** a_batch_ptrs = static_cast<double**>(buffers[3]);
        double** b_batch_ptrs = static_cast<double**>(buffers[4]);
        const double alpha = 1.0;
        ThrowIfErrorStatus(rocblas_dtrsm_batched(
        JAX_RETURN_IF_ERROR(AsStatus(rocblas_dtrsm_batched(
            handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
            const_cast<double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
            d.batch));
            d.batch)))
        break;
      }
      case Type::C64: {

@ -246,10 +253,10 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
        rocblas_float_complex** b_batch_ptrs =
            static_cast<rocblas_float_complex**>(buffers[4]);
        const rocblas_float_complex alpha = {1.0f, 0.0f};
        ThrowIfErrorStatus(rocblas_ctrsm_batched(
        JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm_batched(
            handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
            const_cast<rocblas_float_complex**>(a_batch_ptrs), lda,
            b_batch_ptrs, ldb, d.batch));
            b_batch_ptrs, ldb, d.batch)))
        break;
      }
      case Type::C128: {

@ -257,15 +264,25 @@ void Trsm(hipStream_t stream, void** buffers, const char* opaque,
            static_cast<rocblas_double_complex**>(buffers[3]);
        rocblas_double_complex** b_batch_ptrs =
            static_cast<rocblas_double_complex**>(buffers[4]);
        const rocblas_double_complex alpha = {1.0d, 0.0d};
        ThrowIfErrorStatus(rocblas_ztrsm_batched(
        const rocblas_double_complex alpha = {1.0f, 0.0f};
        JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm_batched(
            handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
            const_cast<rocblas_double_complex**>(a_batch_ptrs), lda,
            b_batch_ptrs, ldb, d.batch));
            b_batch_ptrs, ldb, d.batch)))
        break;
      }
    }
  }
  return absl::OkStatus();
}

void Trsm(hipStream_t stream, void** buffers, const char* opaque,
          size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = Trsm_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
  }
}

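Note that the batched path above checks MakeBatchPointers(...).status(), so that helper now returns a StatusOr as well; the stream synchronize afterwards keeps the host-side pointer array alive until the device copy completes (see the TODO). The interface implied by these call sites, sketched here since the declaration lives in a header outside this excerpt and may differ:

    // Sketch of the implied interface, not the verbatim declaration: builds
    // the array of per-batch device pointers on the host, enqueues a copy to
    // 'dev_ptrs', and returns the host buffer so the caller can keep it alive.
    absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
        hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
        size_t batch_elem_size);
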
//##########################
|
||||
@ -290,17 +307,20 @@ std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
|
||||
return {lwork, PackDescriptor(PotrfDescriptor{type, uplo, b, n})};
|
||||
}
|
||||
|
||||
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const PotrfDescriptor& d =
|
||||
*UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
|
||||
auto handle = rocBlasHandlePool::Borrow(stream);
|
||||
absl::Status Potrf_(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const PotrfDescriptor& d = **s;
|
||||
auto h = rocBlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
// a is INOUT, so we copy the input to the output and use that if they are not
|
||||
// already the same
|
||||
if (buffers[1] != buffers[0]) {
|
||||
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
|
||||
SizeOfType(d.type) * d.batch * d.n * d.n,
|
||||
hipMemcpyDeviceToDevice, stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)))
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[2]);
|
||||
@ -308,28 +328,28 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
switch (d.type) {
|
||||
case Type::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_spotrf(handle.get(), d.uplo, d.n, a, d.n, info));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_spotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
|
||||
break;
|
||||
}
|
||||
case Type::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_dpotrf(handle.get(), d.uplo, d.n, a, d.n, info));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_dpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
|
||||
break;
|
||||
}
|
||||
case Type::C64: {
|
||||
rocblas_float_complex* a =
|
||||
static_cast<rocblas_float_complex*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_cpotrf(handle.get(), d.uplo, d.n, a, d.n, info));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_cpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
|
||||
break;
|
||||
}
|
||||
case Type::C128: {
|
||||
rocblas_double_complex* a =
|
||||
static_cast<rocblas_double_complex*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_zpotrf(handle.get(), d.uplo, d.n, a, d.n, info));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_zpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -337,40 +357,51 @@ void Potrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
auto a_ptrs_host =
|
||||
MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
|
||||
SizeOfType(d.type) * d.n * d.n);
|
||||
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
ThrowIfError(hipStreamSynchronize(stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
|
||||
|
||||
switch (d.type) {
|
||||
case Type::F32: {
|
||||
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
|
||||
ThrowIfErrorStatus(rocsolver_spotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_spotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::F64: {
|
||||
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
|
||||
ThrowIfErrorStatus(rocsolver_dpotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dpotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::C64: {
|
||||
rocblas_float_complex** a_batch_ptrs =
|
||||
static_cast<rocblas_float_complex**>(buffers[3]);
|
||||
ThrowIfErrorStatus(rocsolver_cpotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cpotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::C128: {
|
||||
rocblas_double_complex** a_batch_ptrs =
|
||||
static_cast<rocblas_double_complex**>(buffers[3]);
|
||||
ThrowIfErrorStatus(rocsolver_zpotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zpotrf_batched(
|
||||
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Potrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// getrf: LU decomposition
|
||||
@ -389,18 +420,21 @@ std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
|
||||
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
|
||||
}
|
||||
|
||||
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const GetrfDescriptor& d =
|
||||
*UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
|
||||
auto handle = rocBlasHandlePool::Borrow(stream);
|
||||
absl::Status Getrf_(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GetrfDescriptor& d = **s;
|
||||
auto h = rocBlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
// a is INOUT, so we copy the input to the output and use that if they are not
|
||||
// already the same
|
||||
if (buffers[1] != buffers[0]) {
|
||||
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
|
||||
SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)))
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
@ -410,28 +444,28 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
switch (d.type) {
|
||||
case Type::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_sgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_sgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
|
||||
break;
|
||||
}
|
||||
case Type::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_dgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_dgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
|
||||
break;
|
||||
}
|
||||
case Type::C64: {
|
||||
rocblas_float_complex* a =
|
||||
static_cast<rocblas_float_complex*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_cgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_cgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
|
||||
break;
|
||||
}
|
||||
case Type::C128: {
|
||||
rocblas_double_complex* a =
|
||||
static_cast<rocblas_double_complex*>(buffers[1]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_zgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_zgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -439,44 +473,55 @@ void Getrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
auto a_ptrs_host =
|
||||
MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
|
||||
SizeOfType(d.type) * d.m * d.n);
|
||||
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
ThrowIfError(hipStreamSynchronize(stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
|
||||
|
||||
switch (d.type) {
|
||||
case Type::F32: {
|
||||
float** batch_ptrs = static_cast<float**>(buffers[4]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_sgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
ipiv, std::min(d.m, d.n), info, d.batch));
|
||||
ipiv, std::min(d.m, d.n), info, d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::F64: {
|
||||
double** batch_ptrs = static_cast<double**>(buffers[4]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_dgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
ipiv, std::min(d.m, d.n), info, d.batch));
|
||||
ipiv, std::min(d.m, d.n), info, d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::C64: {
|
||||
rocblas_float_complex** batch_ptrs =
|
||||
static_cast<rocblas_float_complex**>(buffers[4]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_cgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
ipiv, std::min(d.m, d.n), info, d.batch));
|
||||
ipiv, std::min(d.m, d.n), info, d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::C128: {
|
||||
rocblas_double_complex** batch_ptrs =
|
||||
static_cast<rocblas_double_complex**>(buffers[4]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_zgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
ipiv, std::min(d.m, d.n), info, d.batch));
|
||||
ipiv, std::min(d.m, d.n), info, d.batch)))
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Getrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// geqrf: QR decomposition
|
||||
@ -494,18 +539,21 @@ std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
|
||||
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n})};
|
||||
}
|
||||
|
||||
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const GeqrfDescriptor& d =
|
||||
*UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
|
||||
auto handle = rocBlasHandlePool::Borrow(stream);
|
||||
absl::Status Geqrf_(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GeqrfDescriptor& d = **s;
|
||||
auto h = rocBlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
// a is INOUT, so we copy the input to the output and use that if they are not
|
||||
// already the same
|
||||
if (buffers[1] != buffers[0]) {
|
||||
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
|
||||
SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)))
|
||||
}
|
||||
|
||||
// here tau is tau
|
||||
@ -515,15 +563,15 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
case Type::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
float* tau = static_cast<float*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_sgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_sgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
|
||||
break;
|
||||
}
|
||||
case Type::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
double* tau = static_cast<double*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_dgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_dgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
|
||||
break;
|
||||
}
|
||||
case Type::C64: {
|
||||
@ -531,8 +579,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
static_cast<rocblas_float_complex*>(buffers[1]);
|
||||
rocblas_float_complex* tau =
|
||||
static_cast<rocblas_float_complex*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_cgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_cgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
|
||||
break;
|
||||
}
|
||||
case Type::C128: {
|
||||
@ -540,8 +588,8 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
static_cast<rocblas_double_complex*>(buffers[1]);
|
||||
rocblas_double_complex* tau =
|
||||
static_cast<rocblas_double_complex*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_zgeqrf(handle.get(), d.m, d.n, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
AsStatus(rocsolver_zgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -549,26 +597,27 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
auto a_ptrs_host =
|
||||
MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
|
||||
SizeOfType(d.type) * d.m * d.n);
|
||||
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
ThrowIfError(hipStreamSynchronize(stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
|
||||
|
||||
switch (d.type) {
|
||||
case Type::F32: {
|
||||
float** batch_ptrs = static_cast<float**>(buffers[3]);
|
||||
float* tau = static_cast<float*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_sgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
tau, std::min(d.m, d.n), d.batch));
|
||||
tau, std::min(d.m, d.n), d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::F64: {
|
||||
double** batch_ptrs = static_cast<double**>(buffers[3]);
|
||||
double* tau = static_cast<double*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_dgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
tau, std::min(d.m, d.n), d.batch));
|
||||
tau, std::min(d.m, d.n), d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::C64: {
|
||||
@ -576,9 +625,9 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
static_cast<rocblas_float_complex**>(buffers[3]);
|
||||
rocblas_float_complex* tau =
|
||||
static_cast<rocblas_float_complex*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_cgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
tau, std::min(d.m, d.n), d.batch));
|
||||
tau, std::min(d.m, d.n), d.batch)))
|
||||
break;
|
||||
}
|
||||
case Type::C128: {
|
||||
@ -586,13 +635,23 @@ void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
static_cast<rocblas_double_complex**>(buffers[3]);
|
||||
rocblas_double_complex* tau =
|
||||
static_cast<rocblas_double_complex*>(buffers[2]);
|
||||
ThrowIfErrorStatus(
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_zgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
|
||||
tau, std::min(d.m, d.n), d.batch));
|
||||
tau, std::min(d.m, d.n), d.batch)))
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// orgqr/ungqr: apply elementary Householder transformations
|
||||
@ -608,18 +667,21 @@ std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
|
||||
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k})};
|
||||
}
|
||||
|
||||
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const OrgqrDescriptor& d =
|
||||
*UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
|
||||
auto handle = rocBlasHandlePool::Borrow(stream);
|
||||
absl::Status Orgqr_(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const OrgqrDescriptor& d = **s;
|
||||
auto h = rocBlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
// a is INOUT, so we copy the input to the output and use that if they are not
|
||||
// already the same
|
||||
if (buffers[2] != buffers[0]) {
|
||||
ThrowIfError(hipMemcpyAsync(buffers[2], buffers[0],
|
||||
SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
|
||||
buffers[2], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)))
|
||||
}
|
||||
|
||||
switch (d.type) {
|
||||
@ -629,8 +691,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
float* a = static_cast<float*>(buffers[2]);
|
||||
float* tau = static_cast<float*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_sorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_sorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
}
|
||||
@ -640,8 +702,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
double* a = static_cast<double*>(buffers[2]);
|
||||
double* tau = static_cast<double*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_dorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_dorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
}
|
||||
@ -656,8 +718,8 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
rocblas_float_complex* tau =
|
||||
static_cast<rocblas_float_complex*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_cungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_cungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
}
|
||||
@ -669,14 +731,24 @@ void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
rocblas_double_complex* tau =
|
||||
static_cast<rocblas_double_complex*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
ThrowIfErrorStatus(
|
||||
rocsolver_zungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(
|
||||
rocsolver_zungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
|
||||
s.error_message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
|
||||
@ -715,18 +787,21 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
|
||||
return {lwork, PackDescriptor(GesvdDescriptor{type, b, m, n, jobu, jobvt})};
|
||||
}
|
||||
|
||||
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
const GesvdDescriptor& d =
|
||||
*UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
|
||||
auto handle = rocBlasHandlePool::Borrow(stream);
|
||||
absl::Status Gesvd_(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GesvdDescriptor& d = **s;
|
||||
auto h = rocBlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
// a is INOUT, so we copy the input to the output and use that if they are not
|
||||
// already the same
|
||||
if (buffers[1] != buffers[0]) {
|
||||
ThrowIfError(hipMemcpyAsync(buffers[1], buffers[0],
|
||||
SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream));
|
||||
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)))
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[5]);
|
||||
@@ -743,9 +818,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
       float* u = static_cast<float*>(buffers[3]);
       float* vt = static_cast<float*>(buffers[4]);
       float* e = static_cast<float*>(buffers[6]);
-      ThrowIfErrorStatus(rocsolver_sgesvd(handle.get(), d.jobu, d.jobvt, d.m,
-                                          d.n, a, lda, s, u, ldu, vt, ldv, e,
-                                          rocblas_inplace, info));
+      JAX_RETURN_IF_ERROR(AsStatus(
+          rocsolver_sgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
+                           u, ldu, vt, ldv, e, rocblas_inplace, info)))
       break;
     }
     case Type::F64: {
@@ -754,9 +829,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
       double* u = static_cast<double*>(buffers[3]);
       double* vt = static_cast<double*>(buffers[4]);
       double* e = static_cast<double*>(buffers[6]);
-      ThrowIfErrorStatus(rocsolver_dgesvd(handle.get(), d.jobu, d.jobvt, d.m,
-                                          d.n, a, lda, s, u, ldu, vt, ldv, e,
-                                          rocblas_inplace, info));
+      JAX_RETURN_IF_ERROR(AsStatus(
+          rocsolver_dgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
+                           u, ldu, vt, ldv, e, rocblas_inplace, info)))
       break;
     }
     case Type::C64: {
@@ -768,9 +843,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
       rocblas_float_complex* vt =
          static_cast<rocblas_float_complex*>(buffers[4]);
       float* e = static_cast<float*>(buffers[6]);
-      ThrowIfErrorStatus(rocsolver_cgesvd(handle.get(), d.jobu, d.jobvt, d.m,
-                                          d.n, a, lda, s, u, ldu, vt, ldv, e,
-                                          rocblas_inplace, info));
+      JAX_RETURN_IF_ERROR(AsStatus(
+          rocsolver_cgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
+                           u, ldu, vt, ldv, e, rocblas_inplace, info)))
       break;
     }
     case Type::C128: {
@@ -782,9 +857,9 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
       rocblas_double_complex* vt =
           static_cast<rocblas_double_complex*>(buffers[4]);
       double* e = static_cast<double*>(buffers[6]);
-      ThrowIfErrorStatus(rocsolver_zgesvd(handle.get(), d.jobu, d.jobvt, d.m,
-                                          d.n, a, lda, s, u, ldu, vt, ldv, e,
-                                          rocblas_inplace, info));
+      JAX_RETURN_IF_ERROR(AsStatus(
+          rocsolver_zgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
+                           u, ldu, vt, ldv, e, rocblas_inplace, info)))
       break;
     }
   }
@@ -797,10 +872,11 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
     auto a_ptrs_host =
         MakeBatchPointers(stream, buffers[1], buffers[7], d.batch,
                           SizeOfType(d.type) * d.m * d.n);
+    JAX_RETURN_IF_ERROR(a_ptrs_host.status());
     // TODO(phawkins): ideally we would not need to synchronize here, but to
     // avoid it we need a way to keep the host-side buffer alive until the copy
     // completes.
-    ThrowIfError(hipStreamSynchronize(stream));
+    JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
 
     switch (d.type) {
       case Type::F32: {
@@ -809,10 +885,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
        float* u = static_cast<float*>(buffers[3]);
        float* vt = static_cast<float*>(buffers[4]);
        float* e = static_cast<float*>(buffers[6]);
-        ThrowIfErrorStatus(rocsolver_sgesvd_batched(
+        JAX_RETURN_IF_ERROR(AsStatus(rocsolver_sgesvd_batched(
            handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
            stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
-            rocblas_inplace, info, d.batch));
+            rocblas_inplace, info, d.batch)))
        break;
      }
      case Type::F64: {
@@ -821,10 +897,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
        double* u = static_cast<double*>(buffers[3]);
        double* vt = static_cast<double*>(buffers[4]);
        double* e = static_cast<double*>(buffers[6]);
-        ThrowIfErrorStatus(rocsolver_dgesvd_batched(
+        JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dgesvd_batched(
            handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
            stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
-            rocblas_inplace, info, d.batch));
+            rocblas_inplace, info, d.batch)))
        break;
      }
      case Type::C64: {
@@ -836,10 +912,10 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
        rocblas_float_complex* vt =
            static_cast<rocblas_float_complex*>(buffers[4]);
        float* e = static_cast<float*>(buffers[6]);
-        ThrowIfErrorStatus(rocsolver_cgesvd_batched(
+        JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cgesvd_batched(
            handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
            stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
-            rocblas_inplace, info, d.batch));
+            rocblas_inplace, info, d.batch)))
        break;
      }
      case Type::C128: {
@@ -851,14 +927,24 @@ void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
        rocblas_double_complex* vt =
            static_cast<rocblas_double_complex*>(buffers[4]);
        double* e = static_cast<double*>(buffers[6]);
-        ThrowIfErrorStatus(rocsolver_zgesvd_batched(
+        JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zgesvd_batched(
            handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
            stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
-            rocblas_inplace, info, d.batch));
+            rocblas_inplace, info, d.batch)))
        break;
      }
    }
  }
+  return absl::OkStatus();
 }
 
+void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
+           size_t opaque_len, XlaCustomCallStatus* status) {
+  auto s = Gesvd_(stream, buffers, opaque, opaque_len);
+  if (!s.ok()) {
+    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
+                                  s.error_message().length());
+  }
+}
+
 // Singular value decomposition using Jacobi algorithm: gesvdj
@@ -22,24 +22,26 @@ limitations under the License.
 
 namespace jax {
 
-void ThrowIfError(hipError_t error) {
+absl::Status AsStatus(hipError_t error) {
   if (error != hipSuccess) {
-    throw std::runtime_error(
+    return absl::InternalError(
         absl::StrCat("ROCm operation failed: ", hipGetErrorString(error)));
   }
+  return absl::OkStatus();
 }
 
-std::unique_ptr<void* []> MakeBatchPointers(hipStream_t stream, void* buffer,
-                                            void* dev_ptrs, int batch,
-                                            int batch_elem_size) {
+absl::StatusOr<std::unique_ptr<void* []>> MakeBatchPointers(
+    hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
+    int batch_elem_size) {
   char* ptr = static_cast<char*>(buffer);
   auto host_ptrs = absl::make_unique<void*[]>(batch);
   for (int i = 0; i < batch; ++i) {
     host_ptrs[i] = ptr;
     ptr += batch_elem_size;
   }
-  ThrowIfError(hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
-                              hipMemcpyHostToDevice, stream));
+  JAX_RETURN_IF_ERROR(
+      AsStatus(hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
+                              hipMemcpyHostToDevice, stream)));
   return host_ptrs;
 }
 }  // namespace jax
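MakeBatchPointers enqueues an asynchronous host-to-device copy of host_ptrs, which is why callers must keep the returned array alive until the copy lands (the Gesvd caller does so by synchronizing the stream). One hedged sketch of the alternative the TODO above alludes to, freeing the buffer from a stream callback instead of synchronizing; hipStreamAddCallback is a standard HIP API, but this helper is illustrative and not part of the commit:

  // Frees the host-side pointer array once all prior work on the stream,
  // including the hipMemcpyAsync above, has completed.
  static void FreeHostPtrs(hipStream_t /*stream*/, hipError_t /*status*/,
                           void* user_data) {
    delete[] static_cast<void**>(user_data);
  }

  // In MakeBatchPointers, after enqueueing the copy:
  hipStreamAddCallback(stream, FreeHostPtrs, host_ptrs.release(), /*flags=*/0);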
@@ -19,18 +19,28 @@ limitations under the License.
 #include <memory>
 
 #include "rocm/include/hip/hip_runtime_api.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+
+#define JAX_RETURN_IF_ERROR(expr) \
+  {                               \
+    auto s___ = (expr);           \
+    if (!s___.ok()) return s___;  \
+  }
 
 namespace jax {
 
-void ThrowIfError(hipError_t error);
+absl::Status AsStatus(hipError_t error);
 
 // Builds an array of pointers to each array in a batch, in device memory.
 // Caution: the return value must be kept alive (e.g., via a stream
 // synchronization) until the copy enqueued by MakeBatchPointers on `stream`
 // completes.
-std::unique_ptr<void*[]> MakeBatchPointers(hipStream_t stream, void* buffer,
-                                           void* dev_ptrs, int batch,
-                                           int batch_elem_size);
+absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(hipStream_t stream,
+                                                           void* buffer,
+                                                           void* dev_ptrs,
+                                                           int batch,
+                                                           int batch_elem_size);
 
 }  // namespace jax
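JAX_RETURN_IF_ERROR evaluates its argument once and early-returns the contained status on failure; combined with AsStatus it turns raw hipError_t codes into absl::Status propagation. A minimal usage sketch, assuming only the declarations above (CopyToDevice is a hypothetical caller, not part of this change):

  // Hypothetical helper: propagate a HIP error as absl::Status.
  absl::Status CopyToDevice(hipStream_t stream, void* dst, const void* src,
                            size_t n) {
    JAX_RETURN_IF_ERROR(AsStatus(
        hipMemcpyAsync(dst, src, n, hipMemcpyHostToDevice, stream)));
    return absl::OkStatus();
  }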
@@ -96,7 +96,9 @@ def trsm(c,
          _Shape.array_shape(dtype, a_shape.dimensions(), layout),  # buffers[0] (a)
          _Shape.array_shape(dtype, b_shape.dimensions(), layout),  # buffers[1] (b, IN)
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return _ops.GetTupleElement(out, 0)
 
 
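Passing api_version=API_VERSION_STATUS_RETURNING tells XLA that the registered custom-call target uses the status-returning ABI, i.e. it takes a trailing XlaCustomCallStatus* through which failures are reported, exactly as the C++ wrappers earlier in this diff do. Sketched for trsm (signature only, inferred from the Orgqr and Gesvd wrappers above, not shown in this hunk):

  void Trsm(hipStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

The same api_version argument is threaded through each of the wrappers below.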
@@ -133,7 +135,9 @@ def potrf(c, a, lower):
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
          ),  # buffers[0] (a, IN)
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)
 
 
@@ -171,7 +175,9 @@ def getrf(c, a):
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
          ),  # buffers[0] (a, IN)
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
           _ops.GetTupleElement(out, 2))
 
@@ -208,7 +214,9 @@ def geqrf(c, a):
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
          ),  # buffers[0] (a, IN)
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   # rocsolver geqrf does not return info
   return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), None)
 
@@ -247,7 +255,9 @@ def orgqr(c, a, tau):
          _Shape.array_shape(dtype, batch_dims + (k,),
                             tuple(range(num_bd, -1, -1))),  # buffers[1] (tau IN)
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   return (_ops.GetTupleElement(out, 0), None)  # ROCSolver orgqr does not return info
 
 
@@ -303,7 +313,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
          _Shape.array_shape(dtype, batch_dims + (m, n),
                             matrix_layout),  # buffers[0] (a, IN)
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   s = _ops.GetTupleElement(out, 1)
   vt = _ops.GetTupleElement(out, 2)
   u = _ops.GetTupleElement(out, 3)
@@ -338,7 +350,9 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
          _Shape.array_shape(dtype, batch_dims + (m, n),
                             matrix_layout),  # buffers[0] (a, IN)
      ),
-      opaque=opaque)
+      opaque=opaque,
+      api_version=xla_client.ops.CustomCallApiVersion
+      .API_VERSION_STATUS_RETURNING)
   s = _ops.GetTupleElement(out, 1)
   u = _ops.GetTupleElement(out, 2)
   vt = _ops.GetTupleElement(out, 3)