Merge CUDA and ROCM kernel code in jaxlib.

The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
Peter Hawkins 2022-10-25 07:23:07 -07:00 committed by jax authors
parent 621f06660d
commit a852710a09
48 changed files with 1918 additions and 4852 deletions
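
The shim referred to above works by giving the vendor libraries' types and functions common gpu-prefixed aliases, selected at compile time by a preprocessor define; the BUILD changes below set JAX_GPU_CUDA=1 for the CUDA build, and the ROCm build presumably sets an analogous flag. The real jaxlib/gpu/vendor.h is not shown in this diff, so the following is only a minimal sketch of the idea, with hypothetical alias names and include paths:

/* Minimal sketch of a vendor shim in the spirit of jaxlib/gpu/vendor.h.
   The gpu-prefixed names, the JAX_GPU_HIP define, and the ROCm include
   paths are illustrative assumptions, not the actual jaxlib definitions. */
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
typedef cudaStream_t gpuStream_t;
typedef cublasHandle_t gpublasHandle_t;
#define gpublasCreate cublasCreate
#define gpublasSetStream cublasSetStream
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuStreamSynchronize cudaStreamSynchronize
#else  // assumed: JAX_GPU_HIP for the ROCm build
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
typedef hipStream_t gpuStream_t;
typedef hipblasHandle_t gpublasHandle_t;
#define gpublasCreate hipblasCreate
#define gpublasSetStream hipblasSetStream
#define gpuMemcpyAsync hipMemcpyAsync
#define gpuStreamSynchronize hipStreamSynchronize
#endif

Shared sources then use only the gpu-prefixed names and compile unchanged for either backend.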

View File

@@ -184,25 +184,25 @@ def prepare_wheel(sources_path):
copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir)
cuda_dir = os.path.join(jaxlib_dir, "cuda")
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
os.makedirs(libdevice_dir)
copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir)
copy_file(f"__main__/jaxlib/cuda/_cusolver.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_cublas.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_cuda_linalg.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_cuda_prng.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir)
copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir)
rocm_dir = os.path.join(jaxlib_dir, "rocm")
if exists(f"__main__/jaxlib/rocm/_hipsolver.{pyext}"):
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
os.makedirs(rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hipsolver.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hipblas.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hip_linalg.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_hip_prng.{pyext}", dst_dir=rocm_dir)
if exists(f"__main__/jaxlib/cuda/_cusparse.{pyext}"):
copy_file(f"__main__/jaxlib/cuda/_cusparse.{pyext}", dst_dir=cuda_dir)
if exists(f"__main__/jaxlib/rocm/_hipsparse.{pyext}"):
copy_file(f"__main__/jaxlib/rocm/_hipsparse.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir)
copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir)
if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir)
if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir)
mlir_dir = os.path.join(jaxlib_dir, "mlir")

View File

@@ -24,15 +24,31 @@ licenses(["notice"])
package(default_visibility = ["//:__subpackages__"])
cc_library(
name = "cuda_vendor",
hdrs = [
"//jaxlib/gpu:vendor.h",
],
defines = ["JAX_GPU_CUDA=1"],
deps = [
"@local_config_cuda//cuda:cuda_headers",
],
)
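This new :cuda_vendor target anchors the shim: the pattern repeated throughout this file is that each CUDA rule now takes its sources from //jaxlib/gpu and adds a :cuda_vendor dep, so the shared code compiles in its CUDA flavor via the JAX_GPU_CUDA define.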
cc_library(
name = "cuda_gpu_kernel_helpers",
srcs = ["cuda_gpu_kernel_helpers.cc"],
hdrs = ["cuda_gpu_kernel_helpers.h"],
srcs = [
"//jaxlib/gpu:gpu_kernel_helpers.cc",
],
hdrs = [
"//jaxlib/gpu:gpu_kernel_helpers.h",
],
copts = [
"-fexceptions",
],
features = ["-use_header_modules"],
deps = [
":cuda_vendor",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
"@com_google_absl//absl/memory",
@@ -47,10 +63,11 @@ cc_library(
cc_library(
name = "cublas_kernels",
srcs = ["cublas_kernels.cc"],
hdrs = ["cublas_kernels.h"],
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@@ -71,31 +88,32 @@ cc_library(
)
pybind_extension(
name = "_cublas",
srcs = ["cublas.cc"],
name = "_blas",
srcs = ["//jaxlib/gpu:blas.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cublas",
module_name = "_blas",
deps = [
":cublas_kernels",
":cuda_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@pybind11",
],
)
cc_library(
name = "cusolver_kernels",
srcs = ["cusolver_kernels.cc"],
hdrs = ["cusolver_kernels.h"],
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@@ -108,16 +126,17 @@ cc_library(
)
pybind_extension(
name = "_cusolver",
srcs = ["cusolver.cc"],
name = "_solver",
srcs = ["//jaxlib/gpu:solver.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cusolver",
module_name = "_solver",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":cusolver_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
@@ -131,10 +150,11 @@ pybind_extension(
cc_library(
name = "cusparse_kernels",
srcs = ["cusparse_kernels.cc"],
hdrs = ["cusparse_kernels.h"],
srcs = ["//jaxlib/gpu:sparse_kernels.cc"],
hdrs = ["//jaxlib/gpu:sparse_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@@ -148,16 +168,17 @@ cc_library(
)
pybind_extension(
name = "_cusparse",
srcs = ["cusparse.cc"],
name = "_sparse",
srcs = ["//jaxlib/gpu:sparse.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cusparse",
module_name = "_sparse",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
":cusparse_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
@@ -179,12 +200,13 @@ pybind_extension(
cc_library(
name = "cuda_lu_pivot_kernels",
srcs = [
"cuda_lu_pivot_kernels.cc",
"//jaxlib/gpu:lu_pivot_kernels.cc",
],
hdrs = ["cuda_lu_pivot_kernels.h"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_lu_pivot_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
@@ -194,11 +216,12 @@ cc_library(
cuda_library(
name = "cuda_lu_pivot_kernels_impl",
srcs = [
"cuda_lu_pivot_kernels.cu.cc",
"//jaxlib/gpu:lu_pivot_kernels.cu.cc",
],
hdrs = ["cuda_lu_pivot_kernels.h"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
@@ -206,18 +229,19 @@ cuda_library(
)
pybind_extension(
name = "_cuda_linalg",
srcs = ["cuda_linalg.cc"],
name = "_linalg",
srcs = ["//jaxlib/gpu:linalg.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cuda_linalg",
module_name = "_linalg",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_lu_pivot_kernels",
":cuda_lu_pivot_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@local_config_cuda//cuda:cuda_headers",
@@ -228,12 +252,13 @@ pybind_extension(
cc_library(
name = "cuda_prng_kernels",
srcs = [
"cuda_prng_kernels.cc",
"//jaxlib/gpu:prng_kernels.cc",
],
hdrs = ["cuda_prng_kernels.h"],
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_prng_kernels_impl",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
@@ -243,11 +268,12 @@ cc_library(
cuda_library(
name = "cuda_prng_kernels_impl",
srcs = [
"cuda_prng_kernels.cu.cc",
"//jaxlib/gpu:prng_kernels.cu.cc",
],
hdrs = ["cuda_prng_kernels.h"],
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_cuda//cuda:cuda_headers",
@@ -255,14 +281,14 @@ cuda_library(
)
pybind_extension(
name = "_cuda_prng",
srcs = ["cuda_prng.cc"],
name = "_prng",
srcs = ["//jaxlib/gpu:prng.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_cuda_prng",
module_name = "_prng",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_prng_kernels",
@@ -275,12 +301,13 @@ pybind_extension(
cc_library(
name = "cuda_gpu_kernels",
srcs = ["cuda_gpu_kernels.cc"],
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
visibility = ["//visibility:public"],
deps = [
":cublas_kernels",
":cuda_lu_pivot_kernels",
":cuda_prng_kernels",
":cuda_vendor",
":cusolver_kernels",
":cusparse_kernels",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
@@ -291,10 +318,10 @@ cc_library(
py_library(
name = "cuda_gpu_support",
deps = [
":_cublas",
":_cuda_linalg",
":_cuda_prng",
":_cusolver",
":_cusparse",
":_blas",
":_linalg",
":_prng",
":_solver",
":_sparse",
],
)

View File

@@ -1,84 +0,0 @@
/* Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "jaxlib/cuda/cublas_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
namespace jax {
namespace {
namespace py = pybind11;
// Converts a NumPy dtype to a CublasType.
CublasType DtypeToCublasType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, CublasType>({
{{'f', 4}, CublasType::F32},
{{'f', 8}, CublasType::F64},
{{'c', 8}, CublasType::C64},
{{'c', 16}, CublasType::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
}
return it->second;
}
// Returns the descriptor for a GetrfBatched operation.
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
int b, int n) {
CublasType type = DtypeToCublasType(dtype);
size_t size = b * sizeof(void*);
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
}
// Returns the descriptor for a GeqrfBatched operation.
std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
int b, int m, int n) {
CublasType type = DtypeToCublasType(dtype);
size_t size = b * sizeof(void*);
return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})};
}
py::dict Registrations() {
py::dict dict;
dict["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict["cublas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
return dict;
}
PYBIND11_MODULE(_cublas, m) {
m.def("registrations", &Registrations);
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
}
} // namespace
} // namespace jax

View File

@@ -1,225 +0,0 @@
/* Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cublas_kernels.h"
#include <algorithm>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
using BlasHandlePool = HandlePool<cublasHandle_t, cudaStream_t>;
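// Note (inferred from the code below and its call sites): handles are pooled
// per stream. Borrow reuses a cublasHandle_t previously released for this
// stream, or creates a fresh one, and the returned RAII Handle gives it back
// to the pool when it goes out of scope.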
template <>
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
cudaStream_t stream) {
BlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cublasHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
namespace {
// Returns the size in bytes of a CublasType value.
int SizeOfCublasType(CublasType type) {
switch (type) {
case CublasType::F32:
return sizeof(float);
case CublasType::F64:
return sizeof(double);
case CublasType::C64:
return sizeof(cuComplex);
case CublasType::C128:
return sizeof(cuDoubleComplex);
}
}
} // namespace
// Batched LU decomposition: getrfbatched
static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GetrfBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
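// Buffer layout, as used below: buffers[0] holds the input matrices,
// buffers[1] the (possibly aliased) output LU factors, buffers[2] the pivot
// indices, buffers[3] the per-matrix info codes, and buffers[4] scratch for
// the device array of per-matrix pointers.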
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.n * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
SizeOfCublasType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
switch (d.type) {
case CublasType::F32: {
float** batch_ptrs = static_cast<float**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case CublasType::F64: {
double** batch_ptrs = static_cast<double**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case CublasType::C64: {
cuComplex** batch_ptrs = static_cast<cuComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case CublasType::C128: {
cuDoubleComplex** batch_ptrs = static_cast<cuDoubleComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
}
return absl::OkStatus();
}
void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Batched QR decomposition: geqrfbatched
static absl::Status GeqrfBatched_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GeqrfBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GeqrfBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream)));
}
std::vector<int> info(d.batch);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfCublasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
auto tau_ptrs_host =
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfCublasType(d.type) * std::min(d.m, d.n));
JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
switch (d.type) {
case CublasType::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** tau_batch_ptrs = static_cast<float**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case CublasType::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** tau_batch_ptrs = static_cast<double**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case CublasType::C64: {
cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
cuComplex** tau_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case CublasType::C128: {
cuDoubleComplex** a_batch_ptrs =
static_cast<cuDoubleComplex**>(buffers[3]);
cuDoubleComplex** tau_batch_ptrs =
static_cast<cuDoubleComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
}
auto it =
std::find_if(info.begin(), info.end(), [](int i) { return i != 0; });
if (it != info.end()) {
return absl::InvalidArgumentError(
absl::StrFormat("QR decomposition failed with status %d for batch "
"element %d",
*it, std::distance(info.begin(), it)));
}
return absl::OkStatus();
}
void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
} // namespace jax

View File

@@ -1,140 +0,0 @@
/* Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include <stdexcept>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
namespace jax {
namespace {
std::string ErrorString(cudaError_t error) { return cudaGetErrorString(error); }
std::string ErrorString(cusparseStatus_t status) {
return cusparseGetErrorString(status);
}
std::string ErrorString(cusolverStatus_t status) {
switch (status) {
case CUSOLVER_STATUS_SUCCESS:
return "cuSolver success.";
case CUSOLVER_STATUS_NOT_INITIALIZED:
return "cuSolver has not been initialized";
case CUSOLVER_STATUS_ALLOC_FAILED:
return "cuSolver allocation failed";
case CUSOLVER_STATUS_INVALID_VALUE:
return "cuSolver invalid value error";
case CUSOLVER_STATUS_ARCH_MISMATCH:
return "cuSolver architecture mismatch error";
case CUSOLVER_STATUS_MAPPING_ERROR:
return "cuSolver mapping error";
case CUSOLVER_STATUS_EXECUTION_FAILED:
return "cuSolver execution failed";
case CUSOLVER_STATUS_INTERNAL_ERROR:
return "cuSolver internal error";
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "cuSolver matrix type not supported error";
case CUSOLVER_STATUS_NOT_SUPPORTED:
return "cuSolver not supported error";
case CUSOLVER_STATUS_ZERO_PIVOT:
return "cuSolver zero pivot error";
case CUSOLVER_STATUS_INVALID_LICENSE:
return "cuSolver invalid license error";
default:
return absl::StrCat("Unknown cuSolver error: ", status);
}
}
std::string ErrorString(cublasStatus_t status) {
switch (status) {
case CUBLAS_STATUS_SUCCESS:
return "cuBlas success";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "cuBlas has not been initialized";
case CUBLAS_STATUS_ALLOC_FAILED:
return "cuBlas allocation failure";
case CUBLAS_STATUS_INVALID_VALUE:
return "cuBlas invalid value error";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "cuBlas architecture mismatch";
case CUBLAS_STATUS_MAPPING_ERROR:
return "cuBlas mapping error";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "cuBlas execution failed";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "cuBlas internal error";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "cuBlas not supported error";
case CUBLAS_STATUS_LICENSE_ERROR:
return "cuBlas license error";
default:
return "Unknown cuBlas error";
}
}
template <typename T>
std::string ErrorString(T status, const char* file, std::int64_t line,
const char* expr) {
return absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr,
ErrorString(status));
}
} // namespace
absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
const char* expr) {
if (error != cudaSuccess)
return absl::InternalError(ErrorString(error, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(cusolverStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != CUSOLVER_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(cusparseStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != CUSPARSE_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(cublasStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != CUBLAS_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<void* []>> MakeBatchPointers(
cudaStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
char* ptr = static_cast<char*>(buffer);
auto host_ptrs = absl::make_unique<void*[]>(batch);
for (int i = 0; i < batch; ++i) {
host_ptrs[i] = ptr;
ptr += batch_elem_size;
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudaMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
cudaMemcpyHostToDevice, stream)));
return std::move(host_ptrs);
}
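// Note: the host-side pointer array returned above must outlive the enqueued
// cudaMemcpyAsync; the call sites currently ensure this by synchronizing the
// stream (see the TODO comments at those call sites).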
} // namespace jax

View File

@@ -1,51 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
#include <string_view>
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace {
absl::Status CudaLuPivotsToPermutation_(cudaStream_t stream, void** buffers,
const char* opaque,
std::size_t opaque_len) {
auto s =
UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
LaunchLuPivotsToPermutationKernel(stream, buffers, **s);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
return absl::OkStatus();
}
} // namespace
void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status) {
auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
std::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
} // namespace jax

View File

@@ -1,77 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
#include <algorithm>
#include <array>
#include <iostream>
namespace jax {
namespace {
__device__ void ComputePermutation(const std::int32_t* pivots,
std::int32_t* permutation_out,
const std::int32_t pivot_size,
const std::int32_t permutation_size) {
for (int i = 0; i < permutation_size; ++i) {
permutation_out[i] = i;
}
// Compute the permutation from a sequence of transpositions encoded in the
// pivot array by applying the transpositions in order on the identity
// permutation.
for (int i = 0; i < pivot_size; ++i) {
if ((pivots[i] < 0) || (pivots[i] >= permutation_size)) {
continue;
}
std::int32_t swap_temporary = permutation_out[i];
permutation_out[i] = permutation_out[pivots[i]];
permutation_out[pivots[i]] = swap_temporary;
}
}
__global__ void LuPivotsToPermutationKernel(
const std::int32_t* pivots, std::int32_t* permutation_out,
const std::int64_t batch_size, const std::int32_t pivot_size,
const std::int32_t permutation_size) {
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < batch_size; idx += blockDim.x * gridDim.x) {
// Compute this batch element's permutation from its slice of the pivots.
ComputePermutation(pivots + idx * pivot_size,
permutation_out + idx * permutation_size, pivot_size,
permutation_size);
}
}
} // namespace
void LaunchLuPivotsToPermutationKernel(
cudaStream_t stream, void** buffers,
LuPivotsToPermutationDescriptor descriptor) {
const std::int32_t* pivots =
reinterpret_cast<const std::int32_t*>(buffers[0]);
std::int32_t* permutation_out = reinterpret_cast<std::int32_t*>(buffers[1]);
const int block_dim = 128;
const std::int64_t grid_dim = std::min<std::int64_t>(
1024, (descriptor.batch_size + block_dim - 1) / block_dim);
LuPivotsToPermutationKernel<<<grid_dim, block_dim,
/*dynamic_shared_mem_bytes=*/0, stream>>>(
pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size,
descriptor.permutation_size);
}
} // namespace jax
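
As a host-side illustration of the kernel above (not part of the diff; the function name is hypothetical), this standalone sketch applies the pivot sequence {2, 2} to the identity permutation over three elements, producing {2, 0, 1}:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>
// Host reference for LuPivotsToPermutationKernel: apply a LAPACK-style pivot
// sequence (a list of transpositions) to the identity permutation. Assumes
// pivots.size() <= permutation_size, as the kernel does.
std::vector<std::int32_t> PivotsToPermutation(
    const std::vector<std::int32_t>& pivots, std::int32_t permutation_size) {
  std::vector<std::int32_t> perm(permutation_size);
  for (std::int32_t i = 0; i < permutation_size; ++i) perm[i] = i;
  for (std::size_t i = 0; i < pivots.size(); ++i) {
    if (pivots[i] < 0 || pivots[i] >= permutation_size) continue;
    std::swap(perm[i], perm[pivots[i]]);
  }
  return perm;
}
int main() {
  for (std::int32_t p : PivotsToPermutation({2, 2}, 3)) std::printf("%d ", p);
  // prints: 2 0 1
}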

View File

@@ -1,43 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
#define JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
#include <cstddef>
#include <cstdint>
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
struct LuPivotsToPermutationDescriptor {
std::int64_t batch_size;
std::int32_t pivot_size;
std::int32_t permutation_size;
};
void LaunchLuPivotsToPermutationKernel(
cudaStream_t stream, void** buffers,
LuPivotsToPermutationDescriptor descriptor);
void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_CUDA_LU_PIVOT_KERNELS_H_

View File

@@ -1,618 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/gpus/cuda/include/cusparse.h"
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/cuda/cusparse_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
namespace py = pybind11;
namespace jax {
namespace {
cusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, cusparseIndexType_t>({
{{'u', 2}, CUSPARSE_INDEX_16U},
{{'i', 4}, CUSPARSE_INDEX_32I},
{{'i', 8}, CUSPARSE_INDEX_64I},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type)));
}
return it->second;
}
cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, cudaDataType>({
{{'f', 2}, CUDA_R_16F}, {{'c', 4}, CUDA_C_16F}, {{'f', 4}, CUDA_R_32F},
{{'c', 8}, CUDA_C_32F}, {{'f', 8}, CUDA_R_64F},
{{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
{{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
{{'u', 4}, CUDA_R_32U},
#if JAX_CUSPARSE_11300
{{'V', 2}, CUDA_R_16BF},
#endif
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type)));
}
return it->second;
}
// Returns the descriptor for a Sparse matrix.
SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
const py::dtype& index_dtype,
int rows, int cols, int nnz,
int batch_count,
int batch_stride) {
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
cusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
return SparseMatDescriptor{
value_type, index_type, rows, cols, nnz, batch_count, batch_stride};
}
// Returns the descriptor for a Dense matrix.
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
int rows, int cols, int batch_count,
int batch_stride) {
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
}
// Returns the descriptor for a Dense vector.
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
int size) {
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
return DenseVecDescriptor{value_type, size};
}
#if JAX_CUSPARSE_11300
// CsrToDense: Convert CSR matrix to dense matrix
// Returns the descriptor for a CsrToDense operation.
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
// buffer_size does not reference these pointers, but does error on NULL.
// TODO(jakevdp): check whether this is documented.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrFromDense: Convert dense matrix to CSR matrix
// Returns the descriptor for a CsrFromDense operation.
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
absl::Status CsrFromDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrMatvec: Product of CSR matrix and dense vector.
// Returns the descriptor for a CsrMatvec operation.
std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
DenseVecDescriptor x =
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
size_t buffer_size;
CudaConst alpha = CudaOne(y.type);
CudaConst beta = CudaZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
}
// CsrMatmat: Product of CSR matrix and dense matrix.
// Returns the descriptor for a CsrMatmat operation.
std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
const py::dtype& data_dtype, const py::dtype& b_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int BCcols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
DenseMatDescriptor B =
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
/*batch_count=*/1, /*batch_stride=*/0);
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols,
/*batch_count=*/1, /*batch_stride=*/0);
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, CUSPARSE_ORDER_ROW)));
size_t buffer_size;
CudaConst alpha = CudaOne(C.type);
CudaConst beta = CudaZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
}
// CooToDense: Convert COO matrix to dense matrix
// Returns the descriptor for a CooToDense operation.
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
// CooFromDense: Convert dense matrix to COO matrix
// Returns the descriptor for a CooFromDense operation.
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
// CooMatvec: Product of COO matrix and dense vector.
// Returns the descriptor for a CooMatvec operation.
std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
DenseVecDescriptor x =
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
size_t buffer_size;
CudaConst alpha = CudaOne(y.type);
CudaConst beta = CudaZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
}
// CooMatmat: Product of COO matrix and dense matrix.
// Returns the descriptor for a CooMatmat operation.
std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
const py::dtype& data_dtype, const py::dtype& b_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int BCcols, int nnz, bool transpose, int batch_count,
int lhs_batch_stride, int rhs_batch_stride) {
// Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
// C_i = A_i B_i, where `i` denotes the batch dimension.
// All three matrices A, B, and C must have the same batch count.
// Use batch stride to trigger individual mode, e.g.,
// `rhs_batch_stride = 0` for C_i = A_i B.
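// For example (illustrative numbers): with batch_count = 4, passing
// lhs_batch_stride = nnz and rhs_batch_stride = 0 computes C_i = A_i B with a
// single dense B shared across the four sparse factors A_i, while nonzero
// strides for both select C_i = A_i B_i.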
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
batch_count, lhs_batch_stride);
DenseMatDescriptor B =
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
batch_count, rhs_batch_stride);
int C_rows = transpose ? cols : rows;
// TODO(tianjianlu): enable the selection of batch stride.
// The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
// in cusparse library does not allow batch_stride = 0.
// int C_batch_stride = (batch_count > 1)? C_rows * BCcols : 0;
int C_batch_stride = C_rows * BCcols;
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, /*rows=*/C_rows, /*cols=*/BCcols,
batch_count, C_batch_stride);
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
: CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseCooSetStridedBatch(
mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(
mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(
mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
size_t buffer_size;
CudaConst alpha = CudaOne(C.type);
CudaConst beta = CudaZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
}
#endif // if JAX_CUSPARSE_11300
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
}
template <typename F>
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
size_t size;
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
/*du=*/nullptr, /*B=*/nullptr, ldb, &size)));
return size;
}
size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
return Gtsv2BufferSize(cusparseSgtsv2_bufferSizeExt, m, n, ldb);
}
size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
return Gtsv2BufferSize(cusparseDgtsv2_bufferSizeExt, m, n, ldb);
}
py::dict Registrations() {
py::dict dict;
#if JAX_CUSPARSE_11300
dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
dict["cusparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
dict["cusparse_coo_todense"] = EncapsulateFunction(CooToDense);
dict["cusparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
dict["cusparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
dict["cusparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
#endif
dict["cusparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
dict["cusparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
// TODO(tomhennigan): Add support for gtsv2 complex 32/64.
return dict;
}
PYBIND11_MODULE(_cusparse, m) {
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11300);
m.def("registrations", &Registrations);
#if JAX_CUSPARSE_11300
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
m.def("build_csr_matmat_descriptor", &BuildCsrMatmatDescriptor);
m.def("build_coo_todense_descriptor", &BuildCooToDenseDescriptor);
m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
#endif
m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
}
} // namespace
} // namespace jax

View File

@@ -1,618 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cusparse_kernels.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#if JAX_CUDA_11080
#include "third_party/gpus/cuda/include/cuda_fp8.h"
#endif
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
// cuSPARSE generic APIs are not supported on Windows until 11.0.
// cusparseIndexType_t is used in a very limited scope here, so defining it
// manually works around the compilation issue without harm.
#if defined(_WIN32) && (CUSPARSE_VERSION < 11000)
typedef enum {
CUSPARSE_INDEX_16U = 1,
CUSPARSE_INDEX_32I = 2,
CUSPARSE_INDEX_64I = 3
} cusparseIndexType_t;
#endif
namespace jax {
template <>
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
cudaStream_t stream) {
SparseHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cusparseHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
CudaConst CudaZero(cudaDataType type) {
CudaConst c;
std::memset(&c, 0, sizeof(c));
return c;
}
CudaConst CudaOne(cudaDataType type) {
CudaConst c;
std::memset(&c, 0, sizeof(c));
switch (type) {
#if JAX_CUSPARSE_11300
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
case CUDA_R_4I:
case CUDA_C_4I:
#endif
case CUDA_R_8I:
case CUDA_C_8I:
c.i8[0] = 1;
break;
#if JAX_CUSPARSE_11300
case CUDA_R_4U:
case CUDA_C_4U:
#endif
case CUDA_R_8U:
case CUDA_C_8U:
c.u8[0] = 1;
break;
#if JAX_CUSPARSE_11300
case CUDA_R_16I:
case CUDA_C_16I:
c.i16[0] = 1;
break;
case CUDA_R_16U:
case CUDA_C_16U:
c.u16[0] = 1;
break;
#endif
case CUDA_R_32I:
case CUDA_C_32I:
c.i32[0] = 1;
break;
case CUDA_R_32U:
case CUDA_C_32U:
c.u32[0] = 1;
break;
#if JAX_CUSPARSE_11300
case CUDA_R_64I:
case CUDA_C_64I:
c.i64[0] = 1;
break;
case CUDA_R_64U:
case CUDA_C_64U:
c.u64[0] = 1;
break;
#endif
#if JAX_CUDA_11080
case CUDA_R_8F_E4M3:
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E4M3);
break;
case CUDA_R_8F_E5M2:
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E5M2);
break;
#endif
// TODO(jakevdp): 16F/16BF here might break on big endian platforms.
case CUDA_R_16F:
case CUDA_C_16F:
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
break;
#if JAX_CUSPARSE_11300
case CUDA_R_16BF:
case CUDA_C_16BF:
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
break;
#endif
case CUDA_R_32F:
case CUDA_C_32F:
c.f32[0] = 1.0;
break;
case CUDA_R_64F:
case CUDA_C_64F:
c.f64[0] = 1.0;
break;
}
return c;
}
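// Example use, as in the descriptor builders below: CudaOne(CUDA_R_32F)
// returns a CudaConst whose f32[0] is 1.0f, and its address is passed as the
// alpha argument of cusparseSpMV/cusparseSpMM so that the scaling constant
// matches whatever value type was selected at runtime.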
#if JAX_CUSPARSE_11300
// CsrToDense: Convert CSR matrix to dense matrix
static absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrFromDense: Convert dense matrix to CSR matrix
static absl::Status CsrFromDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrMatvec: Product of CSR matrix and dense vector.
static absl::Status CsrMatvec_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CsrMatvecDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* csr_values = buffers[0];
void* csr_col_ind = buffers[1];
void* csr_row_offsets = buffers[2];
void* xbuf = buffers[3];
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.y.type);
CudaConst beta = CudaZero(d.y.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
return absl::OkStatus();
}
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
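// Aside: the host-pointer requirement noted above matches cuSPARSE's default
// pointer mode; an adaptation could make the expectation explicit via
//   cusparseSetPointerMode(handle.get(), CUSPARSE_POINTER_MODE_HOST)
// (a real cuSPARSE call, though this diff never changes the pointer mode).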
// CsrMatmat: Product of CSR matrix and dense matrix.
static absl::Status CsrMatmat_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CsrMatmatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* csr_values = buffers[0];
void* csr_col_ind = buffers[1];
void* csr_row_offsets = buffers[2];
void* Bbuf = buffers[3];
void* Cbuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matmat operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.C.type);
CudaConst beta = CudaZero(d.C.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
return absl::OkStatus();
}
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CooToDense: Convert COO matrix to dense matrix
static absl::Status CooToDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[1],
/*cooColInd=*/buffers[2],
/*cooValues=*/buffers[0], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSparseToDense(handle.get(), mat_a, mat_b,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CooFromDense: Convert dense matrix to COO matrix
static absl::Status CooFromDense_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cusparseDnMatDescr_t mat_a = 0;
cusparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[2],
/*cooColInd=*/buffers[3],
/*cooValues=*/buffers[1], d.index_type,
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CooMatvec: Product of COO matrix and dense vector.
static absl::Status CooMatvec_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CooMatvecDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* coo_values = buffers[0];
void* coo_row_ind = buffers[1];
void* coo_col_ind = buffers[2];
void* xbuf = buffers[3];
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.y.type);
CudaConst beta = CudaZero(d.y.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnVecDescr_t vec_x = 0;
cusparseDnVecDescr_t vec_y = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
return absl::OkStatus();
}
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CooMatmat: Product of COO matrix and dense matrix.
static absl::Status CooMatmat_(cudaStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CooMatmatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* coo_values = buffers[0];
void* coo_row_ind = buffers[1];
void* coo_col_ind = buffers[2];
void* Bbuf = buffers[3];
void* Cbuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matmat operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
CudaConst alpha = CudaOne(d.C.type);
CudaConst beta = CudaZero(d.C.type);
cusparseSpMatDescr_t mat_a = 0;
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
/*batchStride=*/d.A.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
/*batchStride=*/d.B.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
/*batchStride=*/d.C.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
return absl::OkStatus();
}
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
#endif // if JAX_CUSPARSE_11300
template <typename T, typename F>
static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
auto s = UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const Gtsv2Descriptor& descriptor = **s;
int m = descriptor.m;
int n = descriptor.n;
int ldb = descriptor.ldb;
const T* dl = (const T*)(buffers[0]);
const T* d = (const T*)(buffers[1]);
const T* du = (const T*)(buffers[2]);
const T* B = (const T*)(buffers[3]);
T* X = (T*)(buffers[4]);
void* buffer = buffers[5];
// The solution X is written in place to B, so we must copy the contents of B
// into the output buffer X and pass X into the kernel as B.
// Once copy insertion is supported for custom call aliasing, we could alias B
// with X and avoid the copy. The code below is written defensively in case B
// and X alias, but today we know they do not.
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
if (X != B) {
size_t B_bytes = ldb * n * sizeof(T);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)));
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)));
return absl::OkStatus();
}
void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
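// A hypothetical extension for illustration: the same template dispatch would
// wrap other element types. Only f32/f64 are registered in this diff, though
// cusparseCgtsv2 is a real cuSPARSE entry point with a matching signature.
void gtsv2_c64(cudaStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = gtsv2<cuComplex>(cusparseCgtsv2, stream, buffers, opaque,
                            opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
                                  s.message().length());
  }
}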
} // namespace jax

View File

@ -1,160 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_CUSPARSE_KERNELS_H_
#define JAXLIB_CUSPARSE_KERNELS_H_
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "jaxlib/handle_pool.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
#define JAX_CUSPARSE_11300 (CUSPARSE_VERSION >= 11300)
// CUDA-11.8 introduces FP8 E4M3/E5M2 types.
#define JAX_CUDA_11080 (CUDA_VERSION >= 11080)
namespace jax {
using SparseHandlePool = HandlePool<cusparseHandle_t, cudaStream_t>;
template <>
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
cudaStream_t stream);
union CudaConst {
int8_t i8[2];
int16_t i16[2];
int32_t i32[2];
int64_t i64[2];
uint8_t u8[2];
uint16_t u16[2];
uint32_t u32[2];
uint64_t u64[2];
float f32[2];
double f64[2];
};
CudaConst CudaZero(cudaDataType type);
CudaConst CudaOne(cudaDataType type);
struct SparseMatDescriptor {
cudaDataType value_type;
cusparseIndexType_t index_type;
int rows, cols, nnz;
int batch_count = 1;
int batch_stride = 0;
};
struct DenseMatDescriptor {
cudaDataType type;
int rows, cols;
int batch_count = 1;
int batch_stride = 0;
};
struct DenseVecDescriptor {
cudaDataType type;
int size;
};
#if JAX_CUSPARSE_11300
// CsrToDense: Convert CSR matrix to dense matrix
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrFromDense: Convert dense matrix to CSR matrix
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrMatvec: Product of CSR matrix and dense vector.
struct CsrMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
cusparseOperation_t op;
};
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrMatmat: Product of CSR matrix and dense matrix.
struct CsrMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
cusparseOperation_t op_A;
};
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooToDense: Convert COO matrix to dense matrix
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooFromDense: Convert dense matrix to COO matrix
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooMatvec: Product of COO matrix and dense vector.
struct CooMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
cusparseOperation_t op;
};
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooMatmat: Product of COO matrix and dense matrix.
struct CooMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
cusparseOperation_t op_A;
};
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#endif // if JAX_CUSPARSE_11300
struct Gtsv2Descriptor {
int m, n, ldb;
};
void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status);
void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_CUSPARSE_KERNELS_H_

jaxlib/gpu/BUILD Normal file
View File

@ -0,0 +1,43 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Shared CUDA/ROCM GPU kernels.
licenses(["notice"])
package(default_visibility = ["//:__subpackages__"])
exports_files(srcs = [
"blas.cc",
"blas_kernels.cc",
"blas_kernels.h",
"gpu_kernel_helpers.cc",
"gpu_kernel_helpers.h",
"gpu_kernels.cc",
"linalg.cc",
"lu_pivot_kernels.cc",
"lu_pivot_kernels.cu.cc",
"lu_pivot_kernels.h",
"prng.cc",
"prng_kernels.cc",
"prng_kernels.cu.cc",
"prng_kernels.h",
"solver.cc",
"solver_kernels.cc",
"solver_kernels.h",
"sparse.cc",
"sparse_kernels.cc",
"sparse_kernels.h",
"vendor.h",
])
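The jaxlib/gpu/vendor.h exported above supplies the vendor-neutral vocabulary (gpuStream_t, gpublasCreate, JAX_GPU_NAMESPACE, JAX_GPU_PREFIX, ...) used throughout the shared sources. An abridged sketch of its assumed shape follows; the real header aliases many more symbols, and only names that appear elsewhere in this diff are shown here:

#ifndef JAXLIB_GPU_VENDOR_H_
#define JAXLIB_GPU_VENDOR_H_
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#define JAX_GPU_NAMESPACE cuda
#define JAX_GPU_PREFIX "cu"
typedef cudaStream_t gpuStream_t;
typedef cudaError_t gpuError_t;
typedef cublasHandle_t gpublasHandle_t;
typedef cuComplex gpublasComplex;
typedef cuDoubleComplex gpublasDoubleComplex;
#define gpuSuccess cudaSuccess
#define gpuGetErrorString cudaGetErrorString
#define gpuGetLastError cudaGetLastError
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#define gpuStreamSynchronize cudaStreamSynchronize
#define gpublasCreate cublasCreate
#define gpublasSetStream cublasSetStream
#else  // ROCm
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#define JAX_GPU_NAMESPACE hip
#define JAX_GPU_PREFIX "hip"
typedef hipStream_t gpuStream_t;
typedef hipError_t gpuError_t;
typedef hipblasHandle_t gpublasHandle_t;
// ... remaining aliases mirror the CUDA branch with hip* equivalents.
#endif
#endif  // JAXLIB_GPU_VENDOR_H_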

View File

@ -1,4 +1,4 @@
/* Copyright 2021 The JAX Authors.
/* Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -20,28 +20,27 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/blas_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/rocm/hipblas_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
namespace py = pybind11;
// Converts a NumPy dtype to a Type.
HipblasType DtypeToHipblasType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, HipblasType>({
{{'f', 4}, HipblasType::F32},
{{'f', 8}, HipblasType::F64},
{{'c', 8}, HipblasType::C64},
{{'c', 16}, HipblasType::C128},
});
BlasType DtypeToBlasType(const py::dtype& np_type) {
static auto* types = new absl::flat_hash_map<std::pair<char, int>, BlasType>({
{{'f', 4}, BlasType::F32},
{{'f', 8}, BlasType::F64},
{{'c', 8}, BlasType::C64},
{{'c', 16}, BlasType::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
@ -53,7 +52,7 @@ HipblasType DtypeToHipblasType(const py::dtype& np_type) {
// Returns the descriptor for a GetrfBatched operation.
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
int b, int n) {
HipblasType type = DtypeToHipblasType(dtype);
BlasType type = DtypeToBlasType(dtype);
size_t size = b * sizeof(void*);
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
}
@ -61,23 +60,24 @@ std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
// Returns the descriptor for a GetrfBatched operation.
std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
int b, int m, int n) {
HipblasType type = DtypeToHipblasType(dtype);
BlasType type = DtypeToBlasType(dtype);
size_t size = b * sizeof(void*);
return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})};
}
py::dict Registrations() {
py::dict dict;
dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict["hipblas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
return dict;
}
PYBIND11_MODULE(_hipblas, m) {
PYBIND11_MODULE(_blas, m) {
m.def("registrations", &Registrations);
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
}
} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -1,4 +1,4 @@
/* Copyright 2021 The JAX Authors.
/* Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,60 +13,61 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hipblas_kernels.h"
#include "jaxlib/gpu/blas_kernels.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
using BlasHandlePool = HandlePool<hipblasHandle_t, hipStream_t>;
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
template <>
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
hipStream_t stream) {
gpuStream_t stream) {
BlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
hipblasHandle_t handle;
gpublasHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCreate(&handle)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSetStream(handle, stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
namespace JAX_GPU_NAMESPACE {
namespace {
// Converts a NumPy dtype to a CublasType.
// Converts a NumPy dtype to a BlasType.
int SizeOfHipblasType(HipblasType type) {
int SizeOfBlasType(BlasType type) {
switch (type) {
case HipblasType::F32:
case BlasType::F32:
return sizeof(float);
case HipblasType::F64:
case BlasType::F64:
return sizeof(double);
case HipblasType::C64:
return sizeof(hipComplex);
case HipblasType::C128:
return sizeof(hipDoubleComplex);
case BlasType::C64:
return sizeof(gpublasComplex);
case BlasType::C128:
return sizeof(gpublasDoubleComplex);
}
}
@ -74,7 +75,7 @@ int SizeOfHipblasType(HipblasType type) {
// Batched LU decomposition: getrfbatched
static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -83,43 +84,43 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n,
gpuMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
SizeOfHipblasType(d.type) * d.n * d.n);
SizeOfBlasType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
switch (d.type) {
case HipblasType::F32: {
case BlasType::F32: {
float** batch_ptrs = static_cast<float**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSgetrfBatched(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case HipblasType::F64: {
case BlasType::F64: {
double** batch_ptrs = static_cast<double**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDgetrfBatched(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case HipblasType::C64: {
hipblasComplex** batch_ptrs = static_cast<hipblasComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCgetrfBatched(
case BlasType::C64: {
gpublasComplex** batch_ptrs = static_cast<gpublasComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case HipblasType::C128: {
hipblasDoubleComplex** batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZgetrfBatched(
case BlasType::C128: {
gpublasDoubleComplex** batch_ptrs =
static_cast<gpublasDoubleComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
@ -127,7 +128,7 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
return absl::OkStatus();
}
void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -138,7 +139,7 @@ void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
// Batched QR decomposition: geqrfbatched
static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GeqrfBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -147,56 +148,56 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n,
gpuMemcpyDeviceToDevice, stream)));
}
std::vector<int> info(d.batch);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfHipblasType(d.type) * d.m * d.n);
SizeOfBlasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
auto tau_ptrs_host =
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfHipblasType(d.type) * std::min(d.m, d.n));
SizeOfBlasType(d.type) * std::min(d.m, d.n));
JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
switch (d.type) {
case HipblasType::F32: {
case BlasType::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** tau_batch_ptrs = static_cast<float**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipblasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case HipblasType::F64: {
case BlasType::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** tau_batch_ptrs = static_cast<double**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipblasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case HipblasType::C64: {
hipblasComplex** a_batch_ptrs = static_cast<hipblasComplex**>(buffers[3]);
hipblasComplex** tau_batch_ptrs =
static_cast<hipblasComplex**>(buffers[4]);
case BlasType::C64: {
gpublasComplex** a_batch_ptrs = static_cast<gpublasComplex**>(buffers[3]);
gpublasComplex** tau_batch_ptrs =
static_cast<gpublasComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipblasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
case HipblasType::C128: {
hipblasDoubleComplex** a_batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[3]);
hipblasDoubleComplex** tau_batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[4]);
case BlasType::C128: {
gpublasDoubleComplex** a_batch_ptrs =
static_cast<gpublasDoubleComplex**>(buffers[3]);
gpublasDoubleComplex** tau_batch_ptrs =
static_cast<gpublasDoubleComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipblasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
tau_batch_ptrs, info.data(), d.batch)));
break;
}
@ -214,7 +215,7 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
return absl::OkStatus();
}
void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque,
void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -223,4 +224,5 @@ void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque,
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,20 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_CUBLAS_KERNELS_H_
#define JAXLIB_CUBLAS_KERNELS_H_
#ifndef JAXLIB_GPU_BLAS_KERNELS_H_
#define JAXLIB_GPU_BLAS_KERNELS_H_
#include <cstddef>
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
// Set of types known to Cusolver.
enum class CublasType {
enum class BlasType {
F32,
F64,
C64,
@ -36,25 +35,24 @@ enum class CublasType {
// Batched LU decomposition: getrfbatched
struct GetrfBatchedDescriptor {
CublasType type;
BlasType type;
int batch, n;
};
void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Batched QR decomposition: geqrfbatched
struct GeqrfBatchedDescriptor {
CublasType type;
BlasType type;
int batch, m, n;
};
void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_CUBLAS_KERNELS_H_
#endif // JAXLIB_GPU_BLAS_KERNELS_H_

View File

@ -1,4 +1,4 @@
/* Copyright 2021 The JAX Authors.
/* Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,17 +13,83 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include <stdexcept>
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
std::string ErrorString(hipError_t error) { return hipGetErrorString(error); }
std::string ErrorString(gpuError_t error) { return gpuGetErrorString(error); }
#ifdef JAX_GPU_CUDA
std::string ErrorString(gpusparseStatus_t status) {
return cusparseGetErrorString(status);
}
std::string ErrorString(gpusolverStatus_t status) {
switch (status) {
case CUSOLVER_STATUS_SUCCESS:
return "cuSolver success.";
case CUSOLVER_STATUS_NOT_INITIALIZED:
return "cuSolver has not been initialized";
case CUSOLVER_STATUS_ALLOC_FAILED:
return "cuSolver allocation failed";
case CUSOLVER_STATUS_INVALID_VALUE:
return "cuSolver invalid value error";
case CUSOLVER_STATUS_ARCH_MISMATCH:
return "cuSolver architecture mismatch error";
case CUSOLVER_STATUS_MAPPING_ERROR:
return "cuSolver mapping error";
case CUSOLVER_STATUS_EXECUTION_FAILED:
return "cuSolver execution failed";
case CUSOLVER_STATUS_INTERNAL_ERROR:
return "cuSolver internal error";
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "cuSolver matrix type not supported error";
case CUSOLVER_STATUS_NOT_SUPPORTED:
return "cuSolver not supported error";
case CUSOLVER_STATUS_ZERO_PIVOT:
return "cuSolver zero pivot error";
case CUSOLVER_STATUS_INVALID_LICENSE:
return "cuSolver invalid license error";
default:
return absl::StrCat("Unknown cuSolver error: ", status);
}
}
std::string ErrorString(gpublasStatus_t status) {
switch (status) {
case CUBLAS_STATUS_SUCCESS:
return "cuBlas success";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "cuBlas has not been initialized";
case CUBLAS_STATUS_ALLOC_FAILED:
return "cuBlas allocation failure";
case CUBLAS_STATUS_INVALID_VALUE:
return "cuBlas invalid value error";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "cuBlas architecture mismatch";
case CUBLAS_STATUS_MAPPING_ERROR:
return "cuBlas mapping error";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "cuBlas execution failed";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "cuBlas internal error";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "cuBlas not supported error";
case CUBLAS_STATUS_LICENSE_ERROR:
return "cuBlas license error";
default:
return "Unknown cuBlas error";
}
}
#else
std::string ErrorString(hipsparseStatus_t status) {
// TODO(reza): check and see if we can use hipify
@ -115,6 +181,8 @@ std::string ErrorString(hipblasStatus_t status) {
}
}
#endif
template <typename T>
std::string ErrorString(T status, const char* file, std::int64_t line,
const char* expr) {
@ -123,37 +191,37 @@ std::string ErrorString(T status, const char* file, std::int64_t line,
}
} // namespace
absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line,
absl::Status AsStatus(gpuError_t error, const char* file, std::int64_t line,
const char* expr) {
if (error != hipSuccess)
if (error != gpuSuccess)
return absl::InternalError(ErrorString(error, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(hipsolverStatus_t status, const char* file,
absl::Status AsStatus(gpusolverStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != HIPSOLVER_STATUS_SUCCESS)
if (status != GPUSOLVER_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(hipsparseStatus_t status, const char* file,
absl::Status AsStatus(gpusparseStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != HIPSPARSE_STATUS_SUCCESS)
if (status != GPUSPARSE_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(hipblasStatus_t status, const char* file,
absl::Status AsStatus(gpublasStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != HIPBLAS_STATUS_SUCCESS)
if (status != GPUBLAS_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<void* []>>
MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
gpuStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
char* ptr = static_cast<char*>(buffer);
auto host_ptrs = absl::make_unique<void*[]>(batch);
for (int i = 0; i < batch; ++i) {
@ -161,8 +229,10 @@ MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
ptr += batch_elem_size;
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
hipMemcpyHostToDevice, stream)));
gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
gpuMemcpyHostToDevice, stream)));
return std::move(host_ptrs);
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
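The synchronization caveat that GetrfBatched_ and GeqrfBatched_ observe above is worth spelling out. A sketch of the intended calling pattern, where device_matrices, dev_ptrs, batch, and bytes_per_matrix are placeholder names:

auto a_ptrs_host = MakeBatchPointers(stream, device_matrices, dev_ptrs,
                                     batch, bytes_per_matrix);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// The host array backs the copy enqueued on `stream`, so it must stay alive
// until that copy completes; synchronizing before *a_ptrs_host is destroyed
// is the simple (if costly) way the callers above guarantee this.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));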

View File

@ -13,19 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_
#define JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_
#ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_
#define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "jaxlib/gpu/vendor.h"
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
#define JAX_AS_STATUS(expr) \
jax::JAX_GPU_NAMESPACE::AsStatus(expr, __FILE__, __LINE__, #expr)
#define JAX_THROW_IF_ERROR(expr) \
{ \
@ -40,27 +38,29 @@ limitations under the License.
}
namespace jax {
namespace JAX_GPU_NAMESPACE {
// Used via JAX_AS_STATUS(expr) macro.
absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
absl::Status AsStatus(gpuError_t error, const char* file, std::int64_t line,
const char* expr);
absl::Status AsStatus(cusolverStatus_t status, const char* file,
absl::Status AsStatus(gpusolverStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(cusparseStatus_t status, const char* file,
absl::Status AsStatus(gpusparseStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(cublasStatus_t status, const char* file,
absl::Status AsStatus(gpublasStatus_t status, const char* file,
std::int64_t line, const char* expr);
// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
// completes.
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(cudaStream_t stream,
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(gpuStream_t stream,
void* buffer,
void* dev_ptrs,
int batch,
int batch_elem_size);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_
#endif // JAXLIB_GPU_GPU_KERNEL_HELPERS_H_
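Taken together, JAX_AS_STATUS and JAX_RETURN_IF_ERROR give the status-macro idiom used by every kernel in this change. A minimal, hypothetical helper showing the shape (CopyDeviceToDevice is not a jaxlib function; gpuMemcpyAsync and gpuMemcpyDeviceToDevice are shim aliases this diff does use):

absl::Status CopyDeviceToDevice(gpuStream_t stream, void* dst, const void* src,
                                std::size_t len) {
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      gpuMemcpyAsync(dst, src, len, gpuMemcpyDeviceToDevice, stream)));
  return absl::OkStatus();
}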

View File

@ -16,21 +16,23 @@ limitations under the License.
// This file is not used by JAX itself, but exists to assist with running
// JAX-generated HLO code from outside of JAX.
#include "jaxlib/cuda/cublas_kernels.h"
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
#include "jaxlib/cuda/cuda_prng_kernels.h"
#include "jaxlib/cuda/cusolver_kernels.h"
#include "jaxlib/cuda/cusparse_kernels.h"
#include "jaxlib/gpu/blas_kernels.h"
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include "jaxlib/gpu/prng_kernels.h"
#include "jaxlib/gpu/solver_kernels.h"
#include "jaxlib/gpu/sparse_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation",
CudaLuPivotsToPermutation, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_threefry2x32", CudaThreeFry2x32,
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_lu_pivots_to_permutation",
LuPivotsToPermutation, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
"CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_potrf", Potrf, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
@ -66,4 +68,5 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64,
"CUDA");
} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
std::string BuildCudaLuPivotsToPermutationDescriptor(
std::string BuildLuPivotsToPermutationDescriptor(
std::int64_t batch_size, std::int32_t pivot_size,
std::int32_t permutation_size) {
return PackDescriptorAsString(LuPivotsToPermutationDescriptor{
@ -31,21 +31,22 @@ std::string BuildCudaLuPivotsToPermutationDescriptor(
pybind11::dict Registrations() {
pybind11::dict dict;
dict["cuda_lu_pivots_to_permutation"] =
EncapsulateFunction(CudaLuPivotsToPermutation);
dict[JAX_GPU_PREFIX "_lu_pivots_to_permutation"] =
EncapsulateFunction(LuPivotsToPermutation);
return dict;
}
PYBIND11_MODULE(_cuda_linalg, m) {
PYBIND11_MODULE(_linalg, m) {
m.def("registrations", &Registrations);
m.def("lu_pivots_to_permutation_descriptor",
[](std::int64_t batch_size, std::int32_t pivot_size,
std::int32_t permutation_size) {
std::string result = BuildCudaLuPivotsToPermutationDescriptor(
std::string result = BuildLuPivotsToPermutationDescriptor(
batch_size, pivot_size, permutation_size);
return pybind11::bytes(result);
});
}
} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,38 +13,41 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hip_lu_pivot_kernels.h"
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include <string_view>
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
absl::Status HipLuPivotsToPermutation_(hipStream_t stream, void** buffers,
const char* opaque,
std::size_t opaque_len) {
absl::Status LuPivotsToPermutation_(gpuStream_t stream, void** buffers,
const char* opaque,
std::size_t opaque_len) {
auto s =
UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
LaunchLuPivotsToPermutationKernel(stream, buffers, **s);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError()));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
return absl::OkStatus();
}
} // namespace
void HipLuPivotsToPermutation(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status) {
auto s = HipLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
void LuPivotsToPermutation(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status) {
auto s = LuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
std::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hip_lu_pivot_kernels.h"
#include "jaxlib/gpu/lu_pivot_kernels.h"
#include <array>
#include <iostream>
#include "jaxlib/gpu/vendor.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
__device__ void ComputePermutation(const std::int32_t* pivots,
@ -58,7 +61,7 @@ __global__ void LuPivotsToPermutationKernel(
} // namespace
void LaunchLuPivotsToPermutationKernel(
hipStream_t stream, void** buffers,
gpuStream_t stream, void** buffers,
LuPivotsToPermutationDescriptor descriptor) {
const std::int32_t* pivots =
reinterpret_cast<const std::int32_t*>(buffers[0]);
@ -74,4 +77,5 @@ void LaunchLuPivotsToPermutationKernel(
descriptor.permutation_size);
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,16 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIP_LU_PIVOT_KERNELS_H_
#define JAXLIB_HIP_LU_PIVOT_KERNELS_H_
#ifndef JAXLIB_GPU_LU_PIVOT_KERNELS_H_
#define JAXLIB_GPU_LU_PIVOT_KERNELS_H_
#include <cstddef>
#include <string>
#include "rocm/include/hip/hip_runtime_api.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
struct LuPivotsToPermutationDescriptor {
std::int64_t batch_size;
@ -31,13 +32,14 @@ struct LuPivotsToPermutationDescriptor {
};
void LaunchLuPivotsToPermutationKernel(
hipStream_t stream, void** buffers,
gpuStream_t stream, void** buffers,
LuPivotsToPermutationDescriptor descriptor);
void HipLuPivotsToPermutation(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status);
void LuPivotsToPermutation(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_HIP_LU_PIVOT_KERNELS_H_
#endif // JAXLIB_GPU_LU_PIVOT_KERNELS_H_

View File

@ -13,31 +13,32 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cuda_prng_kernels.h"
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/prng_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
std::string BuildThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
}
pybind11::dict Registrations() {
pybind11::dict dict;
dict["cuda_threefry2x32"] = EncapsulateFunction(CudaThreeFry2x32);
dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
return dict;
}
PYBIND11_MODULE(_cuda_prng, m) {
PYBIND11_MODULE(_prng, m) {
m.def("registrations", &Registrations);
m.def("threefry2x32_descriptor", [](std::int64_t n) {
std::string result = BuildCudaThreeFry2x32Descriptor(n);
std::string result = BuildThreeFry2x32Descriptor(n);
return pybind11::bytes(result);
});
}
} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,35 +13,37 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cuda_prng_kernels.h"
#include "jaxlib/gpu/prng_kernels.h"
#include <string_view>
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
absl::Status CudaThreeFry2x32_(cudaStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
LaunchThreeFry2x32Kernel(stream, buffers, **s);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
return absl::OkStatus();
}
} // namespace
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len);
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
std::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda/cuda_prng_kernels.h"
#include "jaxlib/gpu/prng_kernels.h"
#include <array>
#include <cstddef>
#include "jaxlib/gpu/vendor.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
__global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
@ -96,7 +99,7 @@ __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
} // namespace
void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers,
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor) {
std::array<const std::uint32_t*, 2> keys;
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
@ -115,4 +118,5 @@ void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers,
out[1], descriptor.n);
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -13,27 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_CUDA_PRNG_KERNELS_H_
#define JAXLIB_CUDA_PRNG_KERNELS_H_
#ifndef JAXLIB_GPU_PRNG_KERNELS_H_
#define JAXLIB_GPU_PRNG_KERNELS_H_
#include <cstddef>
#include <string>
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
struct ThreeFry2x32Descriptor {
std::int64_t n;
};
void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers,
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor);
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_CUDA_PRNG_KERNELS_H_
#endif // JAXLIB_GPU_PRNG_KERNELS_H_

View File

@ -21,28 +21,27 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
#include "jaxlib/cuda/cusolver_kernels.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/solver_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
namespace py = pybind11;
// Converts a NumPy dtype to a solver type.
CusolverType DtypeToCusolverType(const py::dtype& np_type) {
SolverType DtypeToSolverType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, CusolverType>({
{{'f', 4}, CusolverType::F32},
{{'f', 8}, CusolverType::F64},
{{'c', 8}, CusolverType::C64},
{{'c', 16}, CusolverType::C128},
new absl::flat_hash_map<std::pair<char, int>, SolverType>({
{{'f', 4}, SolverType::F32},
{{'f', 8}, SolverType::F64},
{{'c', 8}, SolverType::C64},
{{'c', 16}, SolverType::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
@ -57,48 +56,87 @@ CusolverType DtypeToCusolverType(const py::dtype& np_type) {
// Returns the workspace size and a descriptor for a potrf operation.
std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
std::int64_t workspace_size;
cublasFillMode_t uplo =
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
gpusolverFillMode_t uplo =
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
if (b == 1) {
switch (type) {
case CusolverType::F32:
case SolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnSpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
JAX_AS_STATUS(gpusolverDnSpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(float);
break;
case CusolverType::F64:
case SolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnDpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
JAX_AS_STATUS(gpusolverDnDpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(double);
break;
case CusolverType::C64:
case SolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnCpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(cuComplex);
JAX_AS_STATUS(gpusolverDnCpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(gpuComplex);
break;
case CusolverType::C128:
case SolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnZpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(cuDoubleComplex);
JAX_AS_STATUS(gpusolverDnZpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(gpuDoubleComplex);
break;
}
} else {
#ifdef JAX_GPU_CUDA
// We use the workspace buffer for our own scratch space.
workspace_size = sizeof(void*) * b;
#else
// TODO(rocm): remove this once CUDA and HIP share the same batched potrf
// API. hipsolver's batched potrf has a different API from CUDA's: on HIP we
// still need to create the workspace, plus additional space into which the
// batch array pointers are copied.
switch (type) {
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverSpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(float)) + (b * sizeof(float*));
break;
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverDpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(double)) + (b * sizeof(double*));
break;
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverCpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size =
(lwork * sizeof(hipComplex)) + (b * sizeof(hipComplex*));
break;
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverZpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(hipDoubleComplex)) +
(b * sizeof(hipDoubleComplex*));
break;
}
#endif // JAX_GPU_CUDA
}
return {workspace_size,
PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})};
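
The CUDA branch can get away with sizeof(void*) * b because cusolverDn*potrfBatched takes an array of per-matrix device pointers rather than a scratch buffer, so the workspace only has to hold those b pointers. A sketch of deriving such a pointer array from one contiguous batch allocation (illustrative only; the real kernel populates the array on the device):

#include <cstddef>
#include <vector>

std::vector<void*> MakeBatchPointers(char* base, int b,
                                     std::size_t bytes_per_matrix) {
  std::vector<void*> ptrs(b);
  for (int i = 0; i < b; ++i) {
    // Matrix i starts i * bytes_per_matrix into the contiguous batch.
    ptrs[i] = base + static_cast<std::size_t>(i) * bytes_per_matrix;
  }
  return ptrs;  // copied into the workspace before the batched solver call
}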
@ -109,38 +147,38 @@ std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
// Returns the workspace size and a descriptor for a getrf operation.
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case CusolverType::F32:
case SolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnSgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case CusolverType::F64:
case SolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnDgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case CusolverType::C64:
case SolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnCgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case CusolverType::C128:
case SolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnZgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
}
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})};
}
// geqrf: QR decomposition
@ -148,91 +186,91 @@ std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
// Returns the workspace size and a descriptor for a geqrf operation.
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case CusolverType::F32:
case SolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnSgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case CusolverType::F64:
case SolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnDgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case CusolverType::C64:
case SolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnCgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case CusolverType::C128:
case SolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnZgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
}
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})};
}
#ifdef JAX_GPU_CUDA
// csrlsvqr: Linear system solve via Sparse QR
// Returns a descriptor for a csrlsvqr operation.
py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA,
int reorder, double tol) {
CusolverType type = DtypeToCusolverType(dtype);
auto h = SpSolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SolverType type = DtypeToSolverType(dtype);
return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol});
}
#endif // JAX_GPU_CUDA
// orgqr/ungqr: apply elementary Householder transformations
// Returns the workspace size and a descriptor for a geqrf operation.
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
int m, int n, int k) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case CusolverType::F32:
case SolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnSorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case CusolverType::F64:
case SolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnDorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case CusolverType::C64:
case SolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnCungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case CusolverType::C128:
case SolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusolverDnZungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
}
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})};
@ -243,32 +281,32 @@ std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
// Returns the workspace size and a descriptor for a syevd operation.
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
cublasFillMode_t uplo =
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
gpusolverFillMode_t uplo =
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
switch (type) {
case CusolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevd_bufferSize(
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork)));
break;
case CusolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevd_bufferSize(
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork)));
break;
case CusolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevd_bufferSize(
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork)));
break;
case CusolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevd_bufferSize(
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork)));
break;
@ -282,60 +320,60 @@ std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
// Returns the workspace size and a descriptor for a syevj_batched operation.
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
bool lower, int batch, int n) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
syevjInfo_t params;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(&params)));
std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); });
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
cublasFillMode_t uplo =
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
gpuSyevjInfo_t params;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(&params)));
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
gpusolverFillMode_t uplo =
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
if (batch == 1) {
switch (type) {
case CusolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevj_bufferSize(
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case CusolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevj_bufferSize(
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case CusolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevj_bufferSize(
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case CusolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevj_bufferSize(
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
}
} else {
switch (type) {
case CusolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevjBatched_bufferSize(
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case CusolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevjBatched_bufferSize(
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case CusolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevjBatched_bufferSize(
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case CusolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevjBatched_bufferSize(
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
@ -350,29 +388,11 @@ std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
int m, int n, bool compute_uv,
bool full_matrices) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case CusolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork)));
break;
case CusolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork)));
break;
case CusolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork)));
break;
case CusolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork)));
break;
}
signed char jobu, jobvt;
if (compute_uv) {
if (full_matrices) {
@ -383,51 +403,71 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
} else {
jobu = jobvt = 'N';
}
switch (type) {
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd_bufferSize(
handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd_bufferSize(
handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd_bufferSize(
handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd_bufferSize(
handle.get(), jobu, jobvt, m, n, &lwork)));
break;
}
return {lwork,
PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
}
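
For reference, jobu/jobvt follow the LAPACK/gesvd convention: 'A' requests all singular vectors, 'S' the economical leading min(m, n), and 'N' none. A helper equivalent to the branch above, assuming the hunk-elided assignments are 'A' and 'S':

signed char GesvdJob(bool compute_uv, bool full_matrices) {
  if (!compute_uv) return 'N';       // singular values only
  return full_matrices ? 'A' : 'S';  // full vs. thin singular vectors
}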
#ifdef JAX_GPU_CUDA
// Singular value decomposition using Jacobi algorithm: gesvdj
// Returns the workspace size and a descriptor for a gesvdj operation.
std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
int batch, int m, int n,
bool compute_uv, int econ) {
CusolverType type = DtypeToCusolverType(dtype);
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
cusolverEigMode_t jobz =
compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
gpusolverEigMode_t jobz =
compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
gesvdjInfo_t params;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(&params)));
std::unique_ptr<gesvdjInfo, void (*)(gesvdjInfo*)> params_cleanup(
params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); });
if (batch == 1) {
switch (type) {
case CusolverType::F32:
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize(
handle.get(), jobz, econ, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
/*ldv=*/n, &lwork, params)));
break;
case CusolverType::F64:
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize(
handle.get(), jobz, econ, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
/*ldv=*/n, &lwork, params)));
break;
case CusolverType::C64:
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize(
handle.get(), jobz, econ, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
/*ldv=*/n, &lwork, params)));
break;
case CusolverType::C128:
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize(
handle.get(), jobz, econ, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
@ -437,28 +477,28 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
}
} else {
switch (type) {
case CusolverType::F32:
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize(
handle.get(), jobz, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
/*ldv=*/n, &lwork, params, batch)));
break;
case CusolverType::F64:
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched_bufferSize(
handle.get(), jobz, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
/*ldv=*/n, &lwork, params, batch)));
break;
case CusolverType::C64:
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched_bufferSize(
handle.get(), jobz, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
/*ldv=*/n, &lwork, params, batch)));
break;
case CusolverType::C128:
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched_bufferSize(
handle.get(), jobz, m, n,
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
@ -471,32 +511,40 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})};
}
#endif // JAX_GPU_CUDA
py::dict Registrations() {
py::dict dict;
dict["cusolver_potrf"] = EncapsulateFunction(Potrf);
dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
dict["cusolver_geqrf"] = EncapsulateFunction(Geqrf);
dict[JAX_GPU_PREFIX "solver_potrf"] = EncapsulateFunction(Potrf);
dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf);
dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf);
dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr);
dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd);
dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj);
dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd);
#ifdef JAX_GPU_CUDA
dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr);
dict["cusolver_orgqr"] = EncapsulateFunction(Orgqr);
dict["cusolver_syevd"] = EncapsulateFunction(Syevd);
dict["cusolver_syevj"] = EncapsulateFunction(Syevj);
dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd);
dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj);
#endif // JAX_GPU_CUDA
return dict;
}
PYBIND11_MODULE(_cusolver, m) {
PYBIND11_MODULE(_solver, m) {
m.def("registrations", &Registrations);
m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
#ifdef JAX_GPU_CUDA
m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor);
#endif // JAX_GPU_CUDA
}
} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
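
The prefixed registration keys work because JAX_GPU_PREFIX expands to a string literal and adjacent literals concatenate at compile time, yielding e.g. "cusolver_potrf" on CUDA and "hipsolver_potrf" on ROCm. A self-contained check of that mechanism (the "cu" value here is an assumption for the demo, not a quote from vendor.h):

#define DEMO_GPU_PREFIX "cu"  // would be "hip" in a ROCm build
static_assert(sizeof(DEMO_GPU_PREFIX "solver_potrf") ==
                  sizeof("cusolver_potrf"),
              "adjacent string literals concatenate before compilation proper");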

View File

@ -17,28 +17,36 @@ limitations under the License.
#define JAXLIB_CUSOLVER_KERNELS_H_
#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusolverSp.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/handle_pool.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"
#endif // JAX_GPU_CUDA
namespace jax {
using SolverHandlePool = HandlePool<cusolverDnHandle_t, cudaStream_t>;
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, cudaStream_t>;
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
template <>
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
cudaStream_t stream);
gpuStream_t stream);
#ifdef JAX_GPU_CUDA
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t>;
template <>
absl::StatusOr<SpSolverHandlePool::Handle> SpSolverHandlePool::Borrow(
cudaStream_t stream);
gpuStream_t stream);
#endif // JAX_GPU_CUDA
namespace JAX_GPU_NAMESPACE {
// Set of types known to Cusolver.
enum class CusolverType {
enum class SolverType {
F32,
F64,
C64,
@ -48,105 +56,114 @@ enum class CusolverType {
// potrf: Cholesky decomposition
struct PotrfDescriptor {
CusolverType type;
cublasFillMode_t uplo;
SolverType type;
gpusolverFillMode_t uplo;
std::int64_t batch, n;
int lwork;
};
void Potrf(cudaStream_t stream, void** buffers, const char* opaque,
void Potrf(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// getrf: LU decomposition
struct GetrfDescriptor {
CusolverType type;
int batch, m, n;
SolverType type;
int batch, m, n, lwork;
};
void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
void Getrf(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// geqrf: QR decomposition
struct GeqrfDescriptor {
CusolverType type;
SolverType type;
int batch, m, n, lwork;
};
void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
void Geqrf(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#ifdef JAX_GPU_CUDA
// csrlsvqr: Linear system solve via Sparse QR
struct CsrlsvqrDescriptor {
CusolverType type;
SolverType type;
int n, nnz, reorder;
double tol;
};
void Csrlsvqr(cudaStream_t stream, void** buffers, const char* opaque,
void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#endif // JAX_GPU_CUDA
// orgqr/ungqr: apply elementary Householder transformations
struct OrgqrDescriptor {
CusolverType type;
SolverType type;
int batch, m, n, k, lwork;
};
void Orgqr(cudaStream_t stream, void** buffers, const char* opaque,
void Orgqr(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
struct SyevdDescriptor {
CusolverType type;
cublasFillMode_t uplo;
SolverType type;
gpusolverFillMode_t uplo;
int batch, n;
int lwork;
};
void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
void Syevd(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.
struct SyevjDescriptor {
CusolverType type;
cublasFillMode_t uplo;
SolverType type;
gpusolverFillMode_t uplo;
int batch, n;
int lwork;
};
void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
void Syevj(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Singular value decomposition using QR algorithm: gesvd
struct GesvdDescriptor {
CusolverType type;
SolverType type;
int batch, m, n;
int lwork;
signed char jobu, jobvt;
};
void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
void Gesvd(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#ifdef JAX_GPU_CUDA
// Singular value decomposition using Jacobi algorithm: gesvdj
struct GesvdjDescriptor {
CusolverType type;
SolverType type;
int batch, m, n;
int lwork;
cusolverEigMode_t jobz;
gpusolverEigMode_t jobz;
int econ;
};
void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#endif // JAX_GPU_CUDA
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_CUSOLVER_KERNELS_H_
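
SolverHandlePool above specializes the generic HandlePool from jaxlib/handle_pool.h. A single-threaded sketch of the underlying idea (the real pool also guards its state with an absl::Mutex and rebinds each borrowed handle to the requesting stream, as the Borrow specializations do):

#include <functional>
#include <map>
#include <utility>
#include <vector>

template <typename Handle, typename Stream>
class TinyHandlePool {
 public:
  explicit TinyHandlePool(std::function<Handle()> create)
      : create_(std::move(create)) {}

  // Reuse a cached handle for this stream if one exists; otherwise create
  // one (e.g. via a wrapper around gpusolverDnCreate).
  Handle Borrow(Stream stream) {
    auto& free = free_[stream];
    if (free.empty()) return create_();
    Handle h = free.back();
    free.pop_back();
    return h;
  }

  void Return(Stream stream, Handle h) { free_[stream].push_back(h); }

 private:
  std::function<Handle()> create_;
  std::map<Stream, std::vector<Handle>> free_;
};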

View File

@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "rocm/include/hipsparse.h"
#include <algorithm>
#include <cstdint>
#include <stdexcept>
@ -26,10 +24,9 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/rocm/hipsparse_kernels.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/sparse_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
@ -38,14 +35,15 @@ limitations under the License.
namespace py = pybind11;
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
hipsparseIndexType_t DtypeToHipSparseIndexType(const py::dtype& np_type) {
gpusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, hipsparseIndexType_t>({
{{'u', 2}, HIPSPARSE_INDEX_16U},
{{'i', 4}, HIPSPARSE_INDEX_32I},
{{'i', 8}, HIPSPARSE_INDEX_64I},
new absl::flat_hash_map<std::pair<char, int>, gpusparseIndexType_t>({
{{'u', 2}, GPUSPARSE_INDEX_16U},
{{'i', 4}, GPUSPARSE_INDEX_32I},
{{'i', 8}, GPUSPARSE_INDEX_64I},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
@ -55,16 +53,20 @@ hipsparseIndexType_t DtypeToHipSparseIndexType(const py::dtype& np_type) {
return it->second;
}
// TODO(rocm): add more hip data types when supported
hipDataType DtypeToHipDataType(const py::dtype& np_type) {
gpuDataType DtypeToCudaDataType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, hipDataType>(
{{{'f', 2}, HIP_R_16F},
{{'c', 4}, HIP_C_16F},
{{'f', 4}, HIP_R_32F},
{{'c', 8}, HIP_C_32F},
{{'f', 8}, HIP_R_64F},
{{'c', 16}, HIP_C_64F}});
new absl::flat_hash_map<std::pair<char, int>, gpuDataType>({
{{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F},
{{'c', 8}, GPU_C_32F}, {{'f', 8}, GPU_R_64F},
{{'c', 16}, GPU_C_64F},
#ifdef JAX_GPU_CUDA
{{'i', 1}, CUDA_R_8I}, {{'u', 1}, CUDA_R_8U},
{{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U},
#if JAX_GPU_HAVE_SPARSE
{{'V', 2}, CUDA_R_16BF},
#endif // JAX_GPU_HAVE_SPARSE
#endif // JAX_GPU_CUDA
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
@ -78,28 +80,28 @@ SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
int rows, int cols, int nnz,
int batch_count,
int batch_stride) {
hipDataType value_type = DtypeToHipDataType(data_dtype);
hipsparseIndexType_t index_type = DtypeToHipSparseIndexType(index_dtype);
return SparseMatDescriptor{
value_type, index_type, rows, cols, nnz, batch_count, batch_stride};
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
return SparseMatDescriptor{value_type, index_type, rows, cols,
nnz, batch_count, batch_stride};
}
// Returns the descriptor for a Dense matrix.
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
int rows, int cols, int batch_count,
int batch_stride) {
hipDataType value_type = DtypeToHipDataType(data_dtype);
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
}
// Returns the descriptor for a Dense vector.
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
int size) {
hipDataType value_type = DtypeToHipDataType(data_dtype);
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
return DenseVecDescriptor{value_type, size};
}
#if JAX_GPU_HAVE_SPARSE
// CsrToDense: Convert CSR matrix to dense matrix
// Returns the descriptor for a Sparse matrix.
@ -111,35 +113,35 @@ std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count*/1, /*batch_stride*/0);
/*batch_count*/ 1, /*batch_stride*/ 0);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
// buffer_size does not reference these pointers, but does error on NULL.
// TODO(jakevdp): check whether this is documented.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT,
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
absl::Status CsrToDense_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
@ -147,28 +149,28 @@ absl::Status CsrToDense_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
gpusparseSparseToDense(handle.get(), mat_a, mat_b,
GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -190,30 +192,30 @@ std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
gpusparseDnMatDescr_t mat_a = 0;
gpusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
absl::Status CsrFromDense_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -222,29 +224,29 @@ absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
gpusparseDnMatDescr_t mat_a = 0;
gpusparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -271,32 +273,32 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnVecDescr_t vec_x = 0;
gpusparseDnVecDescr_t vec_y = 0;
gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
size_t buffer_size;
HipConst alpha = HipOne(y.type);
HipConst beta = HipZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize(
SparseConst alpha = ConstOne(y.type);
SparseConst beta = ConstZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
HIPSPARSE_MV_ALG_DEFAULT, &buffer_size)));
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
}
@ -320,35 +322,35 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols,
/*batch_count=*/1, /*batch_stride=*/0);
hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
gpusparseOperation_t op_A = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
gpusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, HIPSPARSE_ORDER_ROW)));
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, GPUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, HIPSPARSE_ORDER_ROW)));
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, GPUSPARSE_ORDER_ROW)));
size_t buffer_size;
HipConst alpha = HipOne(C.type);
HipConst beta = HipZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize(
handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
SparseConst alpha = ConstOne(C.type);
SparseConst beta = ConstZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
}
@ -366,26 +368,26 @@ std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT,
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
@ -403,25 +405,25 @@ std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
/*batch_count=*/1, /*batch_stride=*/0);
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
gpusparseDnMatDescr_t mat_a = 0;
gpusparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
@ -444,32 +446,32 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnVecDescr_t vec_x = 0;
gpusparseDnVecDescr_t vec_y = 0;
gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
size_t buffer_size;
HipConst alpha = HipOne(y.type);
HipConst beta = HipZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize(
SparseConst alpha = ConstOne(y.type);
SparseConst beta = ConstZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
HIPSPARSE_MV_ALG_DEFAULT, &buffer_size)));
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
}
@ -490,63 +492,61 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
batch_count, lhs_batch_stride);
DenseMatDescriptor B =
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
batch_count, rhs_batch_stride);
SparseMatDescriptor A = BuildSparseMatDescriptor(
data_dtype, index_dtype, rows, cols, nnz, batch_count, lhs_batch_stride);
DenseMatDescriptor B = BuildDenseMatDescriptor(
b_dtype, transpose ? rows : cols, BCcols, batch_count, rhs_batch_stride);
int C_rows = transpose ? cols : rows;
// TODO(tianjianlu): enable the selection of batch stride.
// The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
// The issue
// (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
// in the cusparse library does not allow batch_stride = 0.
// int C_batch_stride = (batch_count > 1)? C_rows * BCcols : 0;
int C_batch_stride = C_rows * BCcols;
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, /*rows=*/C_rows, /*cols=*/BCcols,
batch_count, C_batch_stride);
hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
gpusparseOperation_t op_A = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
gpusparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCooSetStridedBatch(
mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCooSetStridedBatch(
mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseDnMatSetStridedBatch(
mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, GPUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseDnMatSetStridedBatch(
mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, GPUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
size_t buffer_size;
HipConst alpha = HipOne(C.type);
HipConst beta = HipZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize(
handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
SparseConst alpha = ConstOne(C.type);
SparseConst beta = ConstZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
}
#endif // if JAX_GPU_HAVE_SPARSE
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
@ -565,32 +565,37 @@ size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
}
size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
return Gtsv2BufferSize(hipsparseSgtsv2_bufferSizeExt, m, n, ldb);
return Gtsv2BufferSize(gpusparseSgtsv2_bufferSizeExt, m, n, ldb);
}
size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
return Gtsv2BufferSize(hipsparseDgtsv2_bufferSizeExt, m, n, ldb);
return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb);
}
py::dict Registrations() {
py::dict dict;
dict["hipsparse_csr_todense"] = EncapsulateFunction(CsrToDense);
dict["hipsparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
dict["hipsparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
dict["hipsparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
dict["hipsparse_coo_todense"] = EncapsulateFunction(CooToDense);
dict["hipsparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
dict["hipsparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
dict["hipsparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
dict["hipsparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
dict["hipsparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
#if JAX_GPU_HAVE_SPARSE
dict[JAX_GPU_PREFIX "sparse_csr_todense"] = EncapsulateFunction(CsrToDense);
dict[JAX_GPU_PREFIX "sparse_csr_fromdense"] =
EncapsulateFunction(CsrFromDense);
dict[JAX_GPU_PREFIX "sparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
dict[JAX_GPU_PREFIX "sparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
dict[JAX_GPU_PREFIX "sparse_coo_todense"] = EncapsulateFunction(CooToDense);
dict[JAX_GPU_PREFIX "sparse_coo_fromdense"] =
EncapsulateFunction(CooFromDense);
dict[JAX_GPU_PREFIX "sparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
dict[JAX_GPU_PREFIX "sparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
#endif
dict[JAX_GPU_PREFIX "sparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
dict[JAX_GPU_PREFIX "sparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
// TODO(tomhennigan): Add support for gtsv2 complex 32/64.
return dict;
}
PYBIND11_MODULE(_hipsparse, m) {
m.attr("hipsparse_supported") = py::bool_(true);
PYBIND11_MODULE(_sparse, m) {
m.attr("sparse_supported") = py::bool_(JAX_GPU_HAVE_SPARSE);
m.def("registrations", &Registrations);
#if JAX_GPU_HAVE_SPARSE
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
@ -599,10 +604,12 @@ PYBIND11_MODULE(_hipsparse, m) {
m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
#endif
m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
}
} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
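
Gtsv2BufferSizeF32/F64 share one template whose body the hunk elides; the pattern is passing the vendor's bufferSizeExt entry point in as a parameter, so the two wrappers differ only in which library function they forward. A simplified sketch under a pared-down callable signature (an assumption for illustration; the real cusparse/hipsparse functions also take a handle and the diagonal and right-hand-side pointers):

#include <cstddef>

// F is any callable shaped like the pared-down FakeBufferSizeExt below.
template <typename F>
std::size_t Gtsv2BufferSizeSketch(F f, int m, int n, int ldb) {
  std::size_t size = 0;
  f(m, n, ldb, &size);  // the real code wraps the call in JAX_AS_STATUS
  return size;
}

// Hypothetical stand-in for a *gtsv2_bufferSizeExt entry point.
int FakeBufferSizeExt(int m, int n, int ldb, std::size_t* out) {
  *out = static_cast<std::size_t>(m) * n * sizeof(float) + ldb;
  return 0;
}

// usage: std::size_t sz = Gtsv2BufferSizeSketch(FakeBufferSizeExt, 128, 4, 128);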

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hipsparse_kernels.h"
#include "jaxlib/gpu/sparse_kernels.h"
#include <algorithm>
#include <cstdint>
@ -24,62 +24,128 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
template <>
/*static*/ absl::StatusOr<SparseHandlePool::Handle>
SparseHandlePool::Borrow(hipStream_t stream) {
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
gpuStream_t stream) {
SparseHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
hipsparseHandle_t handle;
gpusparseHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreate(&handle)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSetStream(handle, stream)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
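Borrow keeps one free list of handles per stream behind a mutex, creating a new handle only when the list is empty and rebinding the handle to the requesting stream. A minimal, self-contained sketch of that pooling shape; Handle, Stream, and HandlePoolSketch are illustrative placeholders rather than jaxlib types, and the comments name what the real code calls instead:

#include <map>
#include <mutex>
#include <vector>

using Handle = int;  // placeholder for gpusparseHandle_t
using Stream = int;  // placeholder for gpuStream_t

class HandlePoolSketch {
 public:
  Handle Borrow(Stream stream) {
    std::lock_guard<std::mutex> lock(mu_);
    auto& free_list = free_[stream];
    if (free_list.empty()) return ++last_;  // real code: gpusparseCreate
    Handle h = free_list.back();
    free_list.pop_back();
    return h;  // real code also rebinds: gpusparseSetStream(h, stream)
  }

  void Return(Stream stream, Handle h) {
    std::lock_guard<std::mutex> lock(mu_);
    free_[stream].push_back(h);  // recycled by the next Borrow on this stream
  }

 private:
  std::mutex mu_;
  std::map<Stream, std::vector<Handle>> free_;
  Handle last_ = 0;
};

int main() {
  HandlePoolSketch pool;
  Handle h = pool.Borrow(/*stream=*/7);
  pool.Return(7, h);
  return pool.Borrow(7) == h ? 0 : 1;  // the handle is reused
}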
HipConst HipZero(hipDataType type) {
HipConst c;
namespace JAX_GPU_NAMESPACE {
SparseConst ConstZero(gpuDataType type) {
SparseConst c;
std::memset(&c, 0, sizeof(c));
return c;
}
HipConst HipOne(hipDataType type) {
HipConst c;
SparseConst ConstOne(gpuDataType type) {
SparseConst c;
std::memset(&c, 0, sizeof(c));
// TODO(rocm): add more data type if new rocm support
switch (type) {
#ifdef JAX_GPU_CUDA
#if JAX_GPU_HAVE_SPARSE
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
case CUDA_R_4I:
case CUDA_C_4I:
#endif
case CUDA_R_8I:
case CUDA_C_8I:
c.i8[0] = 1;
break;
#if JAX_GPU_HAVE_SPARSE
case CUDA_R_4U:
case CUDA_C_4U:
#endif
case CUDA_R_8U:
case CUDA_C_8U:
c.u8[0] = 1;
break;
#if JAX_GPU_HAVE_SPARSE
case CUDA_R_16I:
case CUDA_C_16I:
c.i16[0] = 1;
break;
case CUDA_R_16U:
case CUDA_C_16U:
c.u16[0] = 1;
break;
#endif
case CUDA_R_32I:
case CUDA_C_32I:
c.i32[0] = 1;
break;
case CUDA_R_32U:
case CUDA_C_32U:
c.u32[0] = 1;
break;
#if JAX_GPU_HAVE_SPARSE
case CUDA_R_64I:
case CUDA_C_64I:
c.i64[0] = 1;
break;
case CUDA_R_64U:
case CUDA_C_64U:
c.u64[0] = 1;
break;
#endif
#if JAX_CUDA_11080
case CUDA_R_8F_E4M3:
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E4M3);
break;
case CUDA_R_8F_E5M2:
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E5M2);
break;
#endif
#if JAX_GPU_HAVE_SPARSE
case CUDA_R_16BF:
case CUDA_C_16BF:
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
break;
#endif
#endif // JAX_GPU_CUDA
// TODO(rocm): add more data types if new rocm supports them.
// TODO(jakevdp): 16F/16BF here might break on big endian platforms.
case HIP_R_16F:
case HIP_C_16F:
case GPU_R_16F:
case GPU_C_16F:
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
break;
case HIP_R_32F:
case HIP_C_32F:
case GPU_R_32F:
case GPU_C_32F:
c.f32[0] = 1.0;
break;
case HIP_R_64F:
case HIP_C_64F:
case GPU_R_64F:
case GPU_C_64F:
c.f64[0] = 1.0;
break;
}
return c;
}
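The 16-bit constants above are spelled as raw bit patterns because C++ has no portable half-precision literal: 1.0 in IEEE binary16 is 0x3C00 and 1.0 in bfloat16 is 0x3F80, which are exactly the binary literals the switch stores. A tiny standalone check of those values:

// 1.0 in binary16: sign 0, exponent 01111 (bias 15), mantissa 0 -> 0x3C00.
static_assert(0b11110000000000 == 0x3C00, "binary16 one");
// 1.0 in bfloat16: sign 0, exponent 01111111 (bias 127), mantissa 0 -> 0x3F80.
static_assert(0b11111110000000 == 0x3F80, "bfloat16 one");
int main() { return 0; }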
static absl::Status CsrToDense_(hipStream_t stream, void** buffers,
#if JAX_GPU_HAVE_SPARSE
// CsrToDense: Convert CSR matrix to dense matrix
static absl::Status CsrToDense_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -88,28 +154,28 @@ static absl::Status CsrToDense_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
gpusparseSparseToDense(handle.get(), mat_a, mat_b,
GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -120,7 +186,7 @@ void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
// CsrFromDense: Convert dense matrix to CSR matrix
static absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -129,29 +195,29 @@ static absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
gpusparseDnMatDescr_t mat_a = 0;
gpusparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -162,7 +228,7 @@ void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
// CsrMatvec: Product of CSR matrix and dense vector.
static absl::Status CsrMatvec_(hipStream_t stream, void** buffers,
static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -178,38 +244,37 @@ static absl::Status CsrMatvec_(hipStream_t stream, void** buffers,
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(rocm): check the following statement for rocm
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.y.type);
HipConst beta = HipZero(d.y.type);
SparseConst alpha = ConstOne(d.y.type);
SparseConst beta = ConstZero(d.y.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnVecDescr_t vec_x = 0;
gpusparseDnVecDescr_t vec_y = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind,
csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO,
csr_values, d.A.index_type, d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO,
d.A.value_type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf)));
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
return absl::OkStatus();
}
void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -220,7 +285,7 @@ void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
// CsrMatmat: Product of CSR matrix and dense matrix.
static absl::Status CsrMatmat_(hipStream_t stream, void** buffers,
static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -240,34 +305,34 @@ static absl::Status CsrMatmat_(hipStream_t stream, void** buffers,
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.C.type);
HipConst beta = HipZero(d.C.type);
SparseConst alpha = ConstOne(d.C.type);
SparseConst beta = ConstZero(d.C.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
gpusparseDnMatDescr_t mat_c = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind,
csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO,
csr_values, d.A.index_type, d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO,
d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
/*ld=*/d.B.cols, Bbuf, d.B.type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
/*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
return absl::OkStatus();
}
void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -278,7 +343,7 @@ void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
// CooToDense: Convert COO matrix to dense matrix
static absl::Status CooToDense_(hipStream_t stream, void** buffers,
static absl::Status CooToDense_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -287,28 +352,28 @@ static absl::Status CooToDense_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
gpusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[1],
/*cooColInd=*/buffers[2],
/*cooValues=*/buffers[0], d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
gpusparseSparseToDense(handle.get(), mat_a, mat_b,
GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
void CooToDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -319,7 +384,7 @@ void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
// CooFromDense: Convert dense matrix to COO matrix
static absl::Status CooFromDense_(hipStream_t stream, void** buffers,
static absl::Status CooFromDense_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -328,29 +393,29 @@ static absl::Status CooFromDense_(hipStream_t stream, void** buffers,
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
gpusparseDnMatDescr_t mat_a = 0;
gpusparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
gpusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[2],
/*cooColInd=*/buffers[3],
/*cooValues=*/buffers[1], d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -361,7 +426,7 @@ void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
// CooMatvec: Product of COO matrix and dense vector.
static absl::Status CooMatvec_(hipStream_t stream, void** buffers,
static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -377,37 +442,36 @@ static absl::Status CooMatvec_(hipStream_t stream, void** buffers,
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(rocm): check the following statement for rocm
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.y.type);
HipConst beta = HipZero(d.y.type);
SparseConst alpha = ConstOne(d.y.type);
SparseConst beta = ConstZero(d.y.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnVecDescr_t vec_x = 0;
gpusparseDnVecDescr_t vec_y = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf)));
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
return absl::OkStatus();
}
void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -418,7 +482,7 @@ void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
// CooMatmat: Product of COO matrix and dense matrix.
static absl::Status CooMatmat_(hipStream_t stream, void** buffers,
static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
@ -434,47 +498,46 @@ static absl::Status CooMatmat_(hipStream_t stream, void** buffers,
void* Cbuf = buffers[4];
void* buf = buffers[5];
// TODO(rocm): check the following statement for rocm
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.C.type);
HipConst beta = HipZero(d.C.type);
SparseConst alpha = ConstOne(d.C.type);
SparseConst beta = ConstZero(d.C.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
gpusparseSpMatDescr_t mat_a = 0;
gpusparseDnMatDescr_t mat_b = 0;
gpusparseDnMatDescr_t mat_c = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
/*batchStride=*/d.A.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
gpusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
/*batchStride=*/d.A.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.B.cols, Bbuf, d.B.type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
/*batchStride=*/d.B.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
gpusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
/*batchStride=*/d.B.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
/*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
/*batchStride=*/d.C.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
gpusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
/*batchStride=*/d.C.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
return absl::OkStatus();
}
void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
@ -482,9 +545,10 @@ void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
s.message().length());
}
}
#endif // if JAX_GPU_HAVE_SPARSE
template <typename T, typename F>
static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto h = SparseHandlePool::Borrow();
JAX_RETURN_IF_ERROR(h.status());
@ -513,7 +577,7 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
if (X != B) {
size_t B_bytes = ldb * n * sizeof(T);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(X, B, B_bytes, hipMemcpyDeviceToDevice, stream)));
gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream)));
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
@ -521,22 +585,23 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
return absl::OkStatus();
}
void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<float>(hipsparseSgtsv2, stream, buffers, opaque, opaque_len);
auto s = gtsv2<float>(gpusparseSgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<double>(hipsparseDgtsv2, stream, buffers, opaque, opaque_len);
auto s = gtsv2<double>(gpusparseDgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
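gtsv2 is an in-place solver over its right-hand-side buffer, which is why the wrapper above copies B into the output buffer X (via gpuMemcpyAsync on the stream) whenever XLA hands it distinct buffers. A self-contained sketch of that copy-then-solve-in-place shape, with SolveInPlace as a made-up stand-in for gpusparse{S,D}gtsv2:

#include <algorithm>
#include <cstddef>

// Placeholder for gpusparse{S,D}gtsv2: the real solver overwrites the
// right-hand-side buffer with the solution in place.
void SolveInPlace(double* x, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) x[i] *= 0.5;
}

// Shape of the gtsv2 wrapper: copy B into the output buffer X unless the two
// already alias, then run the in-place solve on X.
void Gtsv2Like(const double* b, double* x, std::size_t n) {
  if (x != b) std::copy(b, b + n, x);
  SolveInPlace(x, n);
}

int main() {
  double b[3] = {2, 4, 6};
  double x[3] = {0, 0, 0};
  Gtsv2Like(b, x, 3);
  return x[0] == 1.0 ? 0 : 1;
}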


@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIPSPARSE_KERNELS_H_
#define JAXLIB_HIPSPARSE_KERNELS_H_
#ifndef JAXLIB_GPU_SPARSE_KERNELS_H_
#define JAXLIB_GPU_SPARSE_KERNELS_H_
#include <algorithm>
#include <cstdint>
@ -23,23 +23,21 @@ limitations under the License.
#include <vector>
#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/handle_pool.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipsparse.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
namespace jax {
using SparseHandlePool = HandlePool<hipsparseHandle_t, hipStream_t>;
using SparseHandlePool = HandlePool<gpusparseHandle_t, gpuStream_t>;
template <>
/*static*/ absl::StatusOr<SparseHandlePool::Handle>
SparseHandlePool::Borrow(hipStream_t stream);
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
gpuStream_t stream);
union HipConst {
namespace JAX_GPU_NAMESPACE {
union SparseConst {
int8_t i8[2];
int16_t i16[2];
int32_t i32[2];
@ -52,37 +50,38 @@ union HipConst {
double f64[2];
};
HipConst HipZero(hipDataType type);
HipConst HipOne(hipDataType type);
SparseConst ConstZero(gpuDataType type);
SparseConst ConstOne(gpuDataType type);
struct SparseMatDescriptor {
hipDataType value_type;
hipsparseIndexType_t index_type;
gpuDataType value_type;
gpusparseIndexType_t index_type;
int rows, cols, nnz;
int batch_count = 1;
int batch_stride = 0;
};
struct DenseMatDescriptor {
hipDataType type;
gpuDataType type;
int rows, cols;
int batch_count = 1;
int batch_stride = 0;
};
struct DenseVecDescriptor {
hipDataType type;
gpuDataType type;
int size;
};
#if JAX_GPU_HAVE_SPARSE
// CsrToDense: Convert CSR matrix to dense matrix
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrFromDense: Convert dense matrix to CSR matrix
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrMatvec: Product of CSR matrix and dense vector.
@ -90,10 +89,10 @@ void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
struct CsrMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
hipsparseOperation_t op;
gpusparseOperation_t op;
};
void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrMatmat: Product of CSR matrix and dense matrix.
@ -101,20 +100,20 @@ void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
struct CsrMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
hipsparseOperation_t op_A;
gpusparseOperation_t op_A;
};
void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooToDense: Convert COO matrix to dense matrix
void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
void CooToDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooFromDense: Convert dense matrix to COO matrix
void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooMatvec: Product of COO matrix and dense vector.
@ -122,10 +121,10 @@ void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
struct CooMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
hipsparseOperation_t op;
gpusparseOperation_t op;
};
void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooMatmat: Product of COO matrix and dense matrix.
@ -133,22 +132,24 @@ void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
struct CooMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
hipsparseOperation_t op_A;
gpusparseOperation_t op_A;
};
void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#endif // JAX_GPU_HAVE_SPARSE
struct Gtsv2Descriptor {
int m, n, ldb;
};
void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status);
void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_HIPSPARSE_KERNELS_H_
#endif // JAXLIB_GPU_SPARSE_KERNELS_H_

jaxlib/gpu/vendor.h (new file)

@ -0,0 +1,442 @@
/* Copyright 2022 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header is a shim that manages differences between CUDA and ROCM APIs.
// Jaxlib GPU kernels can be compiled for either CUDA or ROCM by defining
// JAX_GPU_CUDA or JAX_GPU_HIP respectively.
#ifndef JAXLIB_GPU_VENDOR_H_
#define JAXLIB_GPU_VENDOR_H_
#if defined(JAX_GPU_CUDA)
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusparse.h"
// Some sparse functionality is only available in CUSPARSE 11.3 or newer.
#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
// CUDA 11.8 introduces the FP8 E4M3/E5M2 types.
#define JAX_GPU_HAVE_FP8 (CUDA_VERSION >= 11080)
#if JAX_GPU_HAVE_FP8
#include "third_party/gpus/cuda/include/cuda_fp8.h"
#endif
// cuSPARSE generic APIs are not supported on Windows until 11.0
// cusparseIndexType_t is used in very limited scope, so defining it manually
// here works around the compilation issue without harm.
#if defined(_WIN32) && (CUSPARSE_VERSION < 11000)
typedef enum {
CUSPARSE_INDEX_16U = 1,
CUSPARSE_INDEX_32I = 2,
CUSPARSE_INDEX_64I = 3
} cusparseIndexType_t;
#endif
#define JAX_GPU_NAMESPACE cuda
#define JAX_GPU_PREFIX "cu"
typedef cuComplex gpuComplex;
typedef cuDoubleComplex gpuDoubleComplex;
typedef cuComplex gpublasComplex;
typedef cuDoubleComplex gpublasDoubleComplex;
typedef cublasFillMode_t gpusolverFillMode_t;
typedef cublasStatus_t gpublasStatus_t;
typedef cublasHandle_t gpublasHandle_t;
typedef cudaDataType gpuDataType;
typedef cudaStream_t gpuStream_t;
typedef cudaError_t gpuError_t;
typedef cusolverDnHandle_t gpusolverDnHandle_t;
typedef cusolverStatus_t gpusolverStatus_t;
typedef cusolverEigMode_t gpusolverEigMode_t;
typedef syevjInfo gpuSyevjInfo;
typedef syevjInfo_t gpuSyevjInfo_t;
typedef cusparseIndexType_t gpusparseIndexType_t;
typedef cusparseHandle_t gpusparseHandle_t;
typedef cusparseOperation_t gpusparseOperation_t;
typedef cusparseStatus_t gpusparseStatus_t;
typedef cusparseSpMatDescr_t gpusparseSpMatDescr_t;
typedef cusparseDnMatDescr_t gpusparseDnMatDescr_t;
typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPU_C_16F CUDA_C_16F
#define GPU_R_16F CUDA_R_16F
#define GPU_C_32F CUDA_C_32F
#define GPU_R_32F CUDA_R_32F
#define GPU_C_64F CUDA_C_64F
#define GPU_R_64F CUDA_R_64F
#define gpublasCreate cublasCreate
#define gpublasSetStream cublasSetStream
#define gpublasSgeqrfBatched cublasSgeqrfBatched
#define gpublasDgeqrfBatched cublasDgeqrfBatched
#define gpublasCgeqrfBatched cublasCgeqrfBatched
#define gpublasZgeqrfBatched cublasZgeqrfBatched
#define gpublasSgetrfBatched cublasSgetrfBatched
#define gpublasDgetrfBatched cublasDgetrfBatched
#define gpublasCgetrfBatched cublasCgetrfBatched
#define gpublasZgetrfBatched cublasZgetrfBatched
#define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS
#define gpusolverDnCreate cusolverDnCreate
#define gpusolverDnSetStream cusolverDnSetStream
#define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo
#define gpusolverDnDestroySyevjInfo cusolverDnDestroySyevjInfo
#define gpusolverDnSpotrf cusolverDnSpotrf
#define gpusolverDnDpotrf cusolverDnDpotrf
#define gpusolverDnCpotrf cusolverDnCpotrf
#define gpusolverDnZpotrf cusolverDnZpotrf
#define gpusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
cusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
#define gpusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
cusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
#define gpusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
cusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
#define gpusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
cusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
#define gpusolverDnSpotrf_bufferSize cusolverDnSpotrf_bufferSize
#define gpusolverDnDpotrf_bufferSize cusolverDnDpotrf_bufferSize
#define gpusolverDnCpotrf_bufferSize cusolverDnCpotrf_bufferSize
#define gpusolverDnZpotrf_bufferSize cusolverDnZpotrf_bufferSize
#define gpusolverDnSgeqrf cusolverDnSgeqrf
#define gpusolverDnDgeqrf cusolverDnDgeqrf
#define gpusolverDnCgeqrf cusolverDnCgeqrf
#define gpusolverDnZgeqrf cusolverDnZgeqrf
#define gpusolverDnSgeqrf_bufferSize cusolverDnSgeqrf_bufferSize
#define gpusolverDnDgeqrf_bufferSize cusolverDnDgeqrf_bufferSize
#define gpusolverDnCgeqrf_bufferSize cusolverDnCgeqrf_bufferSize
#define gpusolverDnZgeqrf_bufferSize cusolverDnZgeqrf_bufferSize
#define gpusolverDnSgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
cusolverDnSgetrf(h, m, n, a, lda, work, ipiv, info)
#define gpusolverDnDgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
cusolverDnDgetrf(h, m, n, a, lda, work, ipiv, info)
#define gpusolverDnCgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
cusolverDnCgetrf(h, m, n, a, lda, work, ipiv, info)
#define gpusolverDnZgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
cusolverDnZgetrf(h, m, n, a, lda, work, ipiv, info)
#define gpusolverDnSgetrf_bufferSize cusolverDnSgetrf_bufferSize
#define gpusolverDnDgetrf_bufferSize cusolverDnDgetrf_bufferSize
#define gpusolverDnCgetrf_bufferSize cusolverDnCgetrf_bufferSize
#define gpusolverDnZgetrf_bufferSize cusolverDnZgetrf_bufferSize
#define gpusolverDnSorgqr cusolverDnSorgqr
#define gpusolverDnDorgqr cusolverDnDorgqr
#define gpusolverDnCungqr cusolverDnCungqr
#define gpusolverDnZungqr cusolverDnZungqr
#define gpusolverDnSorgqr_bufferSize cusolverDnSorgqr_bufferSize
#define gpusolverDnDorgqr_bufferSize cusolverDnDorgqr_bufferSize
#define gpusolverDnCungqr_bufferSize cusolverDnCungqr_bufferSize
#define gpusolverDnZungqr_bufferSize cusolverDnZungqr_bufferSize
#define gpusolverDnSsyevd cusolverDnSsyevd
#define gpusolverDnDsyevd cusolverDnDsyevd
#define gpusolverDnCheevd cusolverDnCheevd
#define gpusolverDnZheevd cusolverDnZheevd
#define gpusolverDnSsyevd_bufferSize cusolverDnSsyevd_bufferSize
#define gpusolverDnDsyevd_bufferSize cusolverDnDsyevd_bufferSize
#define gpusolverDnCheevd_bufferSize cusolverDnCheevd_bufferSize
#define gpusolverDnZheevd_bufferSize cusolverDnZheevd_bufferSize
#define gpusolverDnSsyevj cusolverDnSsyevj
#define gpusolverDnDsyevj cusolverDnDsyevj
#define gpusolverDnCheevj cusolverDnCheevj
#define gpusolverDnZheevj cusolverDnZheevj
#define gpusolverDnSsyevj_bufferSize cusolverDnSsyevj_bufferSize
#define gpusolverDnDsyevj_bufferSize cusolverDnDsyevj_bufferSize
#define gpusolverDnCheevj_bufferSize cusolverDnCheevj_bufferSize
#define gpusolverDnZheevj_bufferSize cusolverDnZheevj_bufferSize
#define gpusolverDnSsyevjBatched cusolverDnSsyevjBatched
#define gpusolverDnDsyevjBatched cusolverDnDsyevjBatched
#define gpusolverDnCheevjBatched cusolverDnCheevjBatched
#define gpusolverDnZheevjBatched cusolverDnZheevjBatched
#define gpusolverDnSsyevjBatched_bufferSize cusolverDnSsyevjBatched_bufferSize
#define gpusolverDnDsyevjBatched_bufferSize cusolverDnDsyevjBatched_bufferSize
#define gpusolverDnCheevjBatched_bufferSize cusolverDnCheevjBatched_bufferSize
#define gpusolverDnZheevjBatched_bufferSize cusolverDnZheevjBatched_bufferSize
#define gpusolverDnSgesvd cusolverDnSgesvd
#define gpusolverDnDgesvd cusolverDnDgesvd
#define gpusolverDnCgesvd cusolverDnCgesvd
#define gpusolverDnZgesvd cusolverDnZgesvd
#define gpusolverDnSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
cusolverDnSgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
cusolverDnDgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
cusolverDnCgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
cusolverDnZgesvd_bufferSize(h, m, n, lwork)
#define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER
#define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR
#define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS
#define gpusparseCooSetStridedBatch cusparseCooSetStridedBatch
#define gpusparseCreate cusparseCreate
#define gpusparseCreateCoo cusparseCreateCoo
#define gpusparseCreateCsr cusparseCreateCsr
#define gpusparseCreateDnMat cusparseCreateDnMat
#define gpusparseCreateDnVec cusparseCreateDnVec
#define gpusparseDenseToSparse_analysis cusparseDenseToSparse_analysis
#define gpusparseDenseToSparse_bufferSize cusparseDenseToSparse_bufferSize
#define gpusparseDenseToSparse_convert cusparseDenseToSparse_convert
#define gpusparseDestroySpMat cusparseDestroySpMat
#define gpusparseDestroyDnMat cusparseDestroyDnMat
#define gpusparseDestroyDnVec cusparseDestroyDnVec
#define gpusparseDnMatSetStridedBatch cusparseDnMatSetStridedBatch
#define gpusparseSetStream cusparseSetStream
#define gpusparseSparseToDense cusparseSparseToDense
#define gpusparseSparseToDense_bufferSize cusparseSparseToDense_bufferSize
#define gpusparseSpMM cusparseSpMM
#define gpusparseSpMM_bufferSize cusparseSpMM_bufferSize
#define gpusparseSpMV cusparseSpMV
#define gpusparseSpMV_bufferSize cusparseSpMV_bufferSize
#define gpusparseSgtsv2 cusparseSgtsv2
#define gpusparseDgtsv2 cusparseDgtsv2
#define gpusparseSgtsv2_bufferSizeExt cusparseSgtsv2_bufferSizeExt
#define gpusparseDgtsv2_bufferSizeExt cusparseDgtsv2_bufferSizeExt
#define GPUSPARSE_INDEX_16U CUSPARSE_INDEX_16U
#define GPUSPARSE_INDEX_32I CUSPARSE_INDEX_32I
#define GPUSPARSE_INDEX_64I CUSPARSE_INDEX_64I
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
#define GPUSPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO
#define GPUSPARSE_MV_ALG_DEFAULT CUSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS
#define gpuGetLastError cudaGetLastError
#define gpuGetErrorString cudaGetErrorString
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#define gpuStreamSynchronize cudaStreamSynchronize
#define gpuSuccess cudaSuccess
#elif defined(JAX_GPU_HIP)
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "rocm/include/hipsolver.h"
#include "rocm/include/hipsparse.h"
#define JAX_GPU_NAMESPACE hip
#define JAX_GPU_PREFIX "hip"
#define JAX_GPU_HAVE_SPARSE 1
#define JAX_GPU_HAVE_FP8 0
typedef hipFloatComplex gpuComplex;
typedef hipDoubleComplex gpuDoubleComplex;
typedef hipblasComplex gpublasComplex;
typedef hipblasDoubleComplex gpublasDoubleComplex;
typedef hipsolverHandle_t gpusolverDnHandle_t;
typedef hipblasFillMode_t gpublasFillMode_t;
typedef hipsolverFillMode_t gpusolverFillMode_t;
typedef hipblasHandle_t gpublasHandle_t;
typedef hipblasStatus_t gpublasStatus_t;
typedef hipDataType gpuDataType;
typedef hipStream_t gpuStream_t;
typedef hipError_t gpuError_t;
typedef void gpuSyevjInfo;
typedef hipsolverSyevjInfo_t gpuSyevjInfo_t;
typedef hipsolverEigMode_t gpusolverEigMode_t;
typedef hipsolverStatus_t gpusolverStatus_t;
typedef hipsparseIndexType_t gpusparseIndexType_t;
typedef hipsparseHandle_t gpusparseHandle_t;
typedef hipsparseOperation_t gpusparseOperation_t;
typedef hipsparseStatus_t gpusparseStatus_t;
typedef hipsparseSpMatDescr_t gpusparseSpMatDescr_t;
typedef hipsparseDnMatDescr_t gpusparseDnMatDescr_t;
typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPU_C_16F HIP_C_16F
#define GPU_R_16F HIP_R_16F
#define GPU_C_32F HIP_C_32F
#define GPU_R_32F HIP_R_32F
#define GPU_C_64F HIP_C_64F
#define GPU_R_64F HIP_R_64F
#define gpublasCreate hipblasCreate
#define gpublasSetStream hipblasSetStream
#define gpublasSgeqrfBatched hipblasSgeqrfBatched
#define gpublasDgeqrfBatched hipblasDgeqrfBatched
#define gpublasCgeqrfBatched hipblasCgeqrfBatched
#define gpublasZgeqrfBatched hipblasZgeqrfBatched
#define gpublasSgetrfBatched hipblasSgetrfBatched
#define gpublasDgetrfBatched hipblasDgetrfBatched
#define gpublasCgetrfBatched hipblasCgetrfBatched
#define gpublasZgetrfBatched hipblasZgetrfBatched
#define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define gpusolverDnCreate hipsolverCreate
#define gpusolverDnSetStream hipsolverSetStream
#define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo
#define gpusolverDnDestroySyevjInfo hipsolverDestroySyevjInfo
#define gpusolverDnSpotrf hipsolverSpotrf
#define gpusolverDnDpotrf hipsolverDpotrf
#define gpusolverDnCpotrf hipsolverCpotrf
#define gpusolverDnZpotrf hipsolverZpotrf
#define gpusolverDnSpotrf_bufferSize hipsolverSpotrf_bufferSize
#define gpusolverDnDpotrf_bufferSize hipsolverDpotrf_bufferSize
#define gpusolverDnCpotrf_bufferSize hipsolverCpotrf_bufferSize
#define gpusolverDnZpotrf_bufferSize hipsolverZpotrf_bufferSize
#define gpusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
hipsolverSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
#define gpusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
hipsolverDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
#define gpusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
hipsolverCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
#define gpusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
batch) \
hipsolverZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
#define gpusolverDnSgeqrf hipsolverSgeqrf
#define gpusolverDnDgeqrf hipsolverDgeqrf
#define gpusolverDnCgeqrf hipsolverCgeqrf
#define gpusolverDnZgeqrf hipsolverZgeqrf
#define gpusolverDnSgeqrf_bufferSize hipsolverSgeqrf_bufferSize
#define gpusolverDnDgeqrf_bufferSize hipsolverDgeqrf_bufferSize
#define gpusolverDnCgeqrf_bufferSize hipsolverCgeqrf_bufferSize
#define gpusolverDnZgeqrf_bufferSize hipsolverZgeqrf_bufferSize
#define gpusolverDnSgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
hipsolverSgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
#define gpusolverDnDgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
hipsolverDgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
#define gpusolverDnCgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
hipsolverCgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
#define gpusolverDnZgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
hipsolverZgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
#define gpusolverDnSgetrf_bufferSize hipsolverSgetrf_bufferSize
#define gpusolverDnDgetrf_bufferSize hipsolverDgetrf_bufferSize
#define gpusolverDnCgetrf_bufferSize hipsolverCgetrf_bufferSize
#define gpusolverDnZgetrf_bufferSize hipsolverZgetrf_bufferSize
#define gpusolverDnSorgqr hipsolverSorgqr
#define gpusolverDnDorgqr hipsolverDorgqr
#define gpusolverDnCungqr hipsolverCungqr
#define gpusolverDnZungqr hipsolverZungqr
#define gpusolverDnSorgqr_bufferSize hipsolverSorgqr_bufferSize
#define gpusolverDnDorgqr_bufferSize hipsolverDorgqr_bufferSize
#define gpusolverDnCungqr_bufferSize hipsolverCungqr_bufferSize
#define gpusolverDnZungqr_bufferSize hipsolverZungqr_bufferSize
#define gpusolverDnSsyevd hipsolverSsyevd
#define gpusolverDnDsyevd hipsolverDsyevd
#define gpusolverDnCheevd hipsolverCheevd
#define gpusolverDnZheevd hipsolverZheevd
#define gpusolverDnSsyevd_bufferSize hipsolverSsyevd_bufferSize
#define gpusolverDnDsyevd_bufferSize hipsolverDsyevd_bufferSize
#define gpusolverDnCheevd_bufferSize hipsolverCheevd_bufferSize
#define gpusolverDnZheevd_bufferSize hipsolverZheevd_bufferSize
#define gpusolverDnSsyevj hipsolverSsyevj
#define gpusolverDnDsyevj hipsolverDsyevj
#define gpusolverDnCheevj hipsolverCheevj
#define gpusolverDnZheevj hipsolverZheevj
#define gpusolverDnSsyevj_bufferSize hipsolverSsyevj_bufferSize
#define gpusolverDnDsyevj_bufferSize hipsolverDsyevj_bufferSize
#define gpusolverDnCheevj_bufferSize hipsolverCheevj_bufferSize
#define gpusolverDnZheevj_bufferSize hipsolverZheevj_bufferSize
#define gpusolverDnSsyevjBatched hipsolverSsyevjBatched
#define gpusolverDnDsyevjBatched hipsolverDsyevjBatched
#define gpusolverDnCheevjBatched hipsolverCheevjBatched
#define gpusolverDnZheevjBatched hipsolverZheevjBatched
#define gpusolverDnSsyevjBatched_bufferSize hipsolverSsyevjBatched_bufferSize
#define gpusolverDnDsyevjBatched_bufferSize hipsolverDsyevjBatched_bufferSize
#define gpusolverDnCheevjBatched_bufferSize hipsolverCheevjBatched_bufferSize
#define gpusolverDnZheevjBatched_bufferSize hipsolverZheevjBatched_bufferSize
#define gpusolverDnSgesvd hipsolverSgesvd
#define gpusolverDnDgesvd hipsolverDgesvd
#define gpusolverDnCgesvd hipsolverCgesvd
#define gpusolverDnZgesvd hipsolverZgesvd
#define gpusolverDnSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
hipsolverSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define gpusolverDnDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
hipsolverDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define gpusolverDnCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
hipsolverCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
hipsolverZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER
#define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR
#define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS
#define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch
#define gpusparseCreate hipsparseCreate
#define gpusparseSetStream hipsparseSetStream
#define gpusparseCreateCoo hipsparseCreateCoo
#define gpusparseCreateCsr hipsparseCreateCsr
#define gpusparseCreateDnMat hipsparseCreateDnMat
#define gpusparseCreateDnVec hipsparseCreateDnVec
#define gpusparseDenseToSparse_analysis hipsparseDenseToSparse_analysis
#define gpusparseDenseToSparse_bufferSize hipsparseDenseToSparse_bufferSize
#define gpusparseDenseToSparse_convert hipsparseDenseToSparse_convert
#define gpusparseDestroySpMat hipsparseDestroySpMat
#define gpusparseDestroyDnMat hipsparseDestroyDnMat
#define gpusparseDestroyDnVec hipsparseDestroyDnVec
#define gpusparseDnMatSetStridedBatch hipsparseDnMatSetStridedBatch
#define gpusparseSparseToDense hipsparseSparseToDense
#define gpusparseSparseToDense_bufferSize hipsparseSparseToDense_bufferSize
#define gpusparseSpMM hipsparseSpMM
#define gpusparseSpMM_bufferSize hipsparseSpMM_bufferSize
#define gpusparseSpMV hipsparseSpMV
#define gpusparseSpMV_bufferSize hipsparseSpMV_bufferSize
#define gpusparseSgtsv2 hipsparseSgtsv2
#define gpusparseDgtsv2 hipsparseDgtsv2
#define gpusparseSgtsv2_bufferSizeExt hipsparseSgtsv2_bufferSizeExt
#define gpusparseDgtsv2_bufferSizeExt hipsparseDgtsv2_bufferSizeExt
#define GPUSPARSE_INDEX_16U HIPSPARSE_INDEX_16U
#define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I
#define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT
#define GPUSPARSE_MV_ALG_DEFAULT HIPSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO
#define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE
#define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE
#define GPUSPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
#define gpuGetLastError hipGetLastError
#define gpuGetErrorString hipGetErrorString
#define gpuMemcpyAsync hipMemcpyAsync
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#define gpuStreamSynchronize hipStreamSynchronize
#define gpuSuccess hipSuccess
#else // defined(GPU vendor)
#error "Either JAX_GPU_CUDA or JAX_GPU_HIP must be defined"
#endif // defined(GPU vendor)
#endif // JAXLIB_GPU_VENDOR_H_
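vendor.h is one large instance of the vendor-shim technique: kernel sources are written once against gpu*-named typedefs, macros, and constants, and the backend is selected at compile time by JAX_GPU_CUDA or JAX_GPU_HIP. A self-contained sketch of the technique, with made-up createHandle functions standing in for the real CUDA/ROCm runtime entry points:

#include <cstdio>

#define JAX_GPU_CUDA 1  // or JAX_GPU_HIP, chosen by the build

#if defined(JAX_GPU_CUDA)
namespace cudaish { inline int createHandle() { return 1; } }
#define gpuCreateHandle cudaish::createHandle
#elif defined(JAX_GPU_HIP)
namespace hipish { inline int createHandle() { return 2; } }
#define gpuCreateHandle hipish::createHandle
#else
#error "Either JAX_GPU_CUDA or JAX_GPU_HIP must be defined"
#endif

int main() {
  // Kernel code is written once against the gpu* alias; the macro layer
  // picks the vendor implementation at compile time.
  std::printf("handle = %d\n", gpuCreateHandle());
  return 0;
}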


@ -23,14 +23,14 @@ from .mhlo_helpers import custom_call
from jaxlib import xla_client
try:
from .cuda import _cuda_linalg
from .cuda import _linalg as _cuda_linalg
for _name, _value in _cuda_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cuda_linalg = None
try:
from .rocm import _hip_linalg
from .rocm import _linalg as _hip_linalg
for _name, _value in _hip_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
@ -65,7 +65,7 @@ def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_
operand_layouts=[pivots_layout],
result_layouts=[permutations_layout])
cuda_lu_pivots_to_permutation = partial(
_lu_pivots_to_permutation_mhlo, "cuda", _cuda_linalg)
cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_mhlo, "cu",
_cuda_linalg)
hip_lu_pivots_to_permutation = partial(
_lu_pivots_to_permutation_mhlo, "hip", _hip_linalg)


@ -25,14 +25,14 @@ from jaxlib import xla_client
from .mhlo_helpers import custom_call
try:
from .cuda import _cuda_prng
from .cuda import _prng as _cuda_prng
for _name, _value in _cuda_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cuda_prng = None
try:
from .rocm import _hip_prng
from .rocm import _prng as _hip_prng
for _name, _value in _hip_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
@ -64,5 +64,6 @@ def _threefry2x32_lowering(prng, platform, keys, data):
operand_layouts=[layout] * 4,
result_layouts=[layout] * 2)
cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cuda")
cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu")
rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip")


@ -27,14 +27,14 @@ from jaxlib import xla_client
from .mhlo_helpers import custom_call
try:
from .cuda import _cublas
from .cuda import _blas as _cublas
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
_cublas = None
try:
from .cuda import _cusolver
from .cuda import _solver as _cusolver
for _name, _value in _cusolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
@ -42,14 +42,14 @@ except ImportError:
try:
from .rocm import _hipblas
from .rocm import _blas as _hipblas
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
_hipblas = None
try:
from .rocm import _hipsolver
from .rocm import _solver as _hipsolver
for _name, _value in _hipsolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:


@ -26,7 +26,7 @@ from jaxlib import xla_client
from .mhlo_helpers import custom_call
try:
from .cuda import _cusparse
from .cuda import _sparse as _cusparse
except ImportError:
_cusparse = None
else:
@ -34,7 +34,7 @@ else:
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
try:
from .rocm import _hipsparse
from .rocm import _sparse as _hipsparse
except ImportError:
_hipsparse = None
else:
@ -42,8 +42,8 @@ else:
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
cuda_is_supported : bool = _cusparse and _cusparse.cusparse_supported
rocm_is_supported : bool = _hipsparse and _hipsparse.hipsparse_supported
cuda_is_supported : bool = _cusparse and _cusparse.sparse_supported
rocm_is_supported : bool = _hipsparse and _hipsparse.sparse_supported
def _validate_csr_mhlo(data, indices, indptr, shape):


@ -25,16 +25,27 @@ licenses(["notice"])
package(default_visibility = ["//:__subpackages__"])
cc_library(
name = "hip_vendor",
hdrs = [
"//jaxlib/gpu:vendor.h",
],
defines = ["JAX_GPU_HIP=1"],
deps = [
"@local_config_rocm//rocm:rocm_headers",
],
)
cc_library(
name = "hip_gpu_kernel_helpers",
srcs = if_rocm_is_configured(["hip_gpu_kernel_helpers.cc"]),
hdrs = if_rocm_is_configured(["hip_gpu_kernel_helpers.h"]),
srcs = if_rocm_is_configured(["//jaxlib/gpu:gpu_kernel_helpers.cc"]),
hdrs = if_rocm_is_configured(["//jaxlib/gpu:gpu_kernel_helpers.h"]),
copts = [
"-fexceptions",
],
features = ["-use_header_modules"],
deps = [
":hip_vendor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -46,11 +57,12 @@ cc_library(
cc_library(
name = "hipblas_kernels",
srcs = ["hipblas_kernels.cc"],
hdrs = ["hipblas_kernels.h"],
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
deps = [
"//jaxlib:handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
@ -68,15 +80,16 @@ cc_library(
)
pybind_extension(
name = "_hipblas",
srcs = ["hipblas.cc"],
name = "_blas",
srcs = ["//jaxlib/gpu:blas.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hipblas",
module_name = "_blas",
deps = [
":hip_vendor",
":hipblas_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@com_google_absl//absl/container:flat_hash_map",
@ -89,11 +102,12 @@ pybind_extension(
cc_library(
name = "hipsolver_kernels",
srcs = ["hipsolver_kernels.cc"],
hdrs = ["hipsolver_kernels.h"],
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
deps = [
"//jaxlib:handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -105,16 +119,17 @@ cc_library(
)
pybind_extension(
name = "_hipsolver",
srcs = ["hipsolver.cc"],
name = "_solver",
srcs = ["//jaxlib/gpu:solver.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hipsolver",
module_name = "_solver",
deps = [
":hip_gpu_kernel_helpers",
":hip_vendor",
":hipsolver_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@com_google_absl//absl/container:flat_hash_map",
@ -127,11 +142,12 @@ pybind_extension(
cc_library(
name = "hipsparse_kernels",
srcs = ["hipsparse_kernels.cc"],
hdrs = ["hipsparse_kernels.h"],
srcs = ["//jaxlib/gpu:sparse_kernels.cc"],
hdrs = ["//jaxlib/gpu:sparse_kernels.h"],
deps = [
"//jaxlib:handle_pool",
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -143,16 +159,17 @@ cc_library(
)
pybind_extension(
name = "_hipsparse",
srcs = ["hipsparse.cc"],
name = "_sparse",
srcs = ["//jaxlib/gpu:sparse.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hipsparse",
module_name = "_sparse",
deps = [
":hip_gpu_kernel_helpers",
":hip_vendor",
":hipsparse_kernels",
"//jaxlib:kernel_pybind11_helpers",
"@com_google_absl//absl/algorithm:container",
@ -172,13 +189,12 @@ pybind_extension(
cc_library(
name = "hip_lu_pivot_kernels",
srcs = [
"hip_lu_pivot_kernels.cc",
],
hdrs = ["hip_lu_pivot_kernels.h"],
srcs = ["//jaxlib/gpu:lu_pivot_kernels.cc"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_lu_pivot_kernels_impl",
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@ -187,12 +203,11 @@ cc_library(
rocm_library(
name = "hip_lu_pivot_kernels_impl",
srcs = [
"hip_lu_pivot_kernels.hip.cc",
],
hdrs = ["hip_lu_pivot_kernels.h"],
srcs = ["//jaxlib/gpu:lu_pivot_kernels.cu.cc"],
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@ -200,18 +215,19 @@ rocm_library(
)
pybind_extension(
name = "_hip_linalg",
srcs = ["hip_linalg.cc"],
name = "_linalg",
srcs = ["//jaxlib/gpu:linalg.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hip_linalg",
module_name = "_linalg",
deps = [
":hip_gpu_kernel_helpers",
":hip_lu_pivot_kernels",
":hip_lu_pivot_kernels_impl",
":hip_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@pybind11",
@ -220,13 +236,12 @@ pybind_extension(
cc_library(
name = "hip_prng_kernels",
srcs = [
"hip_prng_kernels.cc",
],
hdrs = ["hip_prng_kernels.h"],
srcs = ["//jaxlib/gpu:prng_kernels.cc"],
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_prng_kernels_impl",
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@ -235,12 +250,11 @@ cc_library(
rocm_library(
name = "hip_prng_kernels_impl",
srcs = [
"hip_prng_kernels.hip.cc",
],
hdrs = ["hip_prng_kernels.h"],
srcs = ["//jaxlib/gpu:prng_kernels.cu.cc"],
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_vendor",
"//jaxlib:kernel_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@ -248,17 +262,18 @@ rocm_library(
)
pybind_extension(
name = "_hip_prng",
srcs = ["hip_prng.cc"],
name = "_prng",
srcs = ["//jaxlib/gpu:prng.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hip_prng",
module_name = "_prng",
deps = [
":hip_gpu_kernel_helpers",
":hip_prng_kernels",
":hip_vendor",
"//jaxlib:kernel_pybind11_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@pybind11",
@ -268,11 +283,10 @@ pybind_extension(
py_library(
name = "rocm_gpu_support",
deps = [
":_hip_linalg",
":_hip_prng",
":_hipblas",
":_hipsolver",
":_hipsparse",
":_blas",
":_linalg",
":_prng",
":_solver",
":_sparse",
],
)
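
Both the existing `cuda_vendor` target and this new `hip_vendor` target export the same `//jaxlib/gpu:vendor.h`; only the `JAX_GPU_CUDA=1` / `JAX_GPU_HIP=1` define and the vendor header dependency differ, which is what lets the kernel sources above be shared verbatim. A sketch of the shape such a shim header takes (the alias names below are illustrative assumptions, not the actual contents of vendor.h):

// vendor_sketch.h: one vendor-neutral alias set, selected by a build define.
#if defined(JAX_GPU_CUDA)
#include <cuda_runtime_api.h>
typedef cudaStream_t gpuStream_t;
typedef cudaError_t gpuError_t;
#define gpuGetLastError cudaGetLastError
#define gpuStreamSynchronize cudaStreamSynchronize
#elif defined(JAX_GPU_HIP)
#include <hip/hip_runtime_api.h>
typedef hipStream_t gpuStream_t;
typedef hipError_t gpuError_t;
#define gpuGetLastError hipGetLastError
#define gpuStreamSynchronize hipStreamSynchronize
#else
#error "Exactly one of JAX_GPU_CUDA or JAX_GPU_HIP must be defined"
#endif

Shared kernel code then writes `gpuStream_t` and `gpuGetLastError()` once, and the per-vendor BUILD targets pick the implementation via the define.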

View File

@ -1,66 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
#define JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "rocm/include/hipsolver.h"
#include "rocm/include/hipsparse.h"
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
#define JAX_THROW_IF_ERROR(expr) \
{ \
auto s___ = (expr); \
if (!s___.ok()) \
throw std::runtime_error(std::string(s___.message())); \
}
#define JAX_RETURN_IF_ERROR(expr) \
{ \
auto s___ = (expr); \
if (!s___.ok()) \
return s___; \
}
namespace jax {
// Used via JAX_AS_STATUS(expr) macro.
absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line,
const char* expr);
absl::Status AsStatus(hipsolverStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(hipsparseStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(hipblasStatus_t status, const char* file,
std::int64_t line, const char* expr);
// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
// completes.
absl::StatusOr<std::unique_ptr<void*[]>>
MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size);
} // namespace jax
#endif // JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
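
The caution on MakeBatchPointers is load-bearing: the returned host array backs an asynchronous host-to-device copy enqueued on `stream`, so it must stay alive until that copy completes. A sketch of the intended call pattern against the declarations above (the function and buffer names are hypothetical):

#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"

absl::Status BatchedOpSketch(hipStream_t stream, void* mats, void* dev_ptrs,
                             int batch, int bytes_per_mat) {
  // Enqueues the copy of the pointer array into dev_ptrs; the returned host
  // staging buffer must outlive that copy.
  auto host_ptrs = jax::MakeBatchPointers(stream, mats, dev_ptrs, batch,
                                          bytes_per_mat);
  JAX_RETURN_IF_ERROR(host_ptrs.status());
  // ... enqueue the batched kernel that consumes dev_ptrs here ...
  // Synchronize before *host_ptrs is destroyed, per the header comment.
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
  return absl::OkStatus();
}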

View File

@ -1,51 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/rocm/hip_lu_pivot_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
namespace jax {
namespace {
std::string
BuildHipLuPivotsToPermutationDescriptor(std::int64_t batch_size,
std::int32_t pivot_size,
std::int32_t permutation_size) {
return PackDescriptorAsString(LuPivotsToPermutationDescriptor{
batch_size, pivot_size, permutation_size});
}
pybind11::dict Registrations() {
pybind11::dict dict;
dict["hip_lu_pivots_to_permutation"] =
EncapsulateFunction(HipLuPivotsToPermutation);
return dict;
}
PYBIND11_MODULE(_hip_linalg, m) {
m.def("registrations", &Registrations);
m.def("lu_pivots_to_permutation_descriptor",
[](std::int64_t batch_size, std::int32_t pivot_size,
std::int32_t permutation_size) {
std::string result = BuildHipLuPivotsToPermutationDescriptor(
batch_size, pivot_size, permutation_size);
return pybind11::bytes(result);
});
}
} // namespace
} // namespace jax
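
PackDescriptorAsString above serializes a trivially copyable struct into the opaque byte string that XLA hands back to the custom call, where UnpackDescriptor reverses it. Reduced to standard C++, the round trip is just a memcpy in each direction (a sketch of the mechanism, not the real jaxlib helpers):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>

struct Descriptor {
  std::int64_t batch_size;
  std::int32_t pivot_size, permutation_size;
};

std::string Pack(const Descriptor& d) {
  return std::string(reinterpret_cast<const char*>(&d), sizeof(d));
}

Descriptor Unpack(const char* opaque, std::size_t opaque_len) {
  assert(opaque_len == sizeof(Descriptor));
  Descriptor d;
  std::memcpy(&d, opaque, sizeof(d));
  return d;
}

int main() {
  std::string opaque = Pack(Descriptor{8, 4, 4});
  Descriptor d = Unpack(opaque.data(), opaque.size());
  assert(d.batch_size == 8 && d.pivot_size == 4 && d.permutation_size == 4);
}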

View File

@ -1,43 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hip_prng_kernels.h"
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
namespace {
std::string BuildHipThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
}
pybind11::dict Registrations() {
pybind11::dict dict;
dict["hip_threefry2x32"] = EncapsulateFunction(HipThreeFry2x32);
return dict;
}
PYBIND11_MODULE(_hip_prng, m) {
m.def("registrations", &Registrations);
m.def("threefry2x32_descriptor", [](std::int64_t n) {
std::string result = BuildHipThreeFry2x32Descriptor(n);
return pybind11::bytes(result);
});
}
} // namespace
} // namespace jax

View File

@ -1,47 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hip_prng_kernels.h"
#include <string_view>
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace {
absl::Status HipThreeFry2x32_(hipStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
LaunchThreeFry2x32Kernel(stream, buffers, **s);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError()));
return absl::OkStatus();
}
} // namespace
void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = HipThreeFry2x32_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
std::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
} // namespace jax
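
HipThreeFry2x32 repeats the status-forwarding shape shared by every entry point in these kernels: run the absl::Status implementation, then copy any failure message into the XlaCustomCallStatus. Factored out, the idiom would look like this (a hypothetical helper, not part of jaxlib):

#include <string_view>
#include <utility>
#include "absl/status/status.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"

// Runs `impl(args...)` and forwards a non-OK status to XLA.
template <typename F, typename... Args>
void RunAsCustomCall(XlaCustomCallStatus* status, F&& impl, Args&&... args) {
  absl::Status s = std::forward<F>(impl)(std::forward<Args>(args)...);
  if (!s.ok()) {
    std::string_view message = s.message();
    XlaCustomCallStatusSetFailure(status, message.data(), message.length());
  }
}

With it, HipThreeFry2x32 would reduce to RunAsCustomCall(status, HipThreeFry2x32_, stream, buffers, opaque, opaque_len).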

View File

@ -1,39 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIP_PRNG_KERNELS_H_
#define JAXLIB_HIP_PRNG_KERNELS_H_
#include <cstddef>
#include <string>
#include "rocm/include/hip/hip_runtime_api.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
struct ThreeFry2x32Descriptor {
std::int64_t n;
};
void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor);
void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIP_PRNG_KERNELS_H_

View File

@ -1,116 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hip_prng_kernels.h"
#include <array>
#include <cstddef>
namespace jax {
namespace {
__global__ void
ThreeFry2x32Kernel(const std::uint32_t* key0, const std::uint32_t* key1,
const std::uint32_t* data0, const std::uint32_t* data1,
std::uint32_t* out0, std::uint32_t* out1, std::int64_t n) {
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
idx += blockDim.x * gridDim.x) {
// Rotation distances specified by the Threefry2x32 algorithm.
std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
std::uint32_t x[2];
std::uint32_t ks[3];
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
ks[2] = 0x1BD11BDA;
ks[0] = key0[idx];
x[0] = data0[idx];
ks[2] = ks[2] ^ key0[idx];
ks[1] = key1[idx];
x[1] = data1[idx];
ks[2] = ks[2] ^ key1[idx];
auto rotate_left = [](std::uint32_t v, std::uint32_t distance) {
return (v << distance) | (v >> (32 - distance));
};
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
auto round = [&](std::uint32_t* v, std::uint32_t rotation) {
v[0] += v[1];
v[1] = rotate_left(v[1], rotation);
v[1] ^= v[0];
};
// There are no known statistical flaws with 13 rounds of Threefry2x32.
// We are conservative and use 20 rounds.
x[0] = x[0] + ks[0];
x[1] = x[1] + ks[1];
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[1];
x[1] = x[1] + ks[2] + 1u;
for (int i = 4; i < 8; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[2];
x[1] = x[1] + ks[0] + 2u;
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[0];
x[1] = x[1] + ks[1] + 3u;
for (int i = 4; i < 8; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[1];
x[1] = x[1] + ks[2] + 4u;
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
out0[idx] = x[0] + ks[2];
out1[idx] = x[1] + ks[0] + 5u;
}
}
} // namespace
void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor) {
std::array<const std::uint32_t*, 2> keys;
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
std::array<const std::uint32_t*, 2> data;
data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
std::array<std::uint32_t*, 2> out;
out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
const int block_dim = 128;
const std::int64_t grid_dim =
std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
out[1], descriptor.n);
}
} // namespace jax
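
For checking the kernel's arithmetic, the cipher collapses to a few lines of portable C++: five groups of four rounds, with a key-schedule injection after each group, exactly as unrolled above. A standalone host-side sketch (the expected output is the Threefry2x32 known-answer vector for all-zero key and counter, the same value JAX's own tests check against):

#include <cstdint>
#include <cstdio>
#include <utility>

std::pair<std::uint32_t, std::uint32_t> ThreeFry2x32Reference(
    std::uint32_t key0, std::uint32_t key1,
    std::uint32_t data0, std::uint32_t data1) {
  const std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
  // Key schedule: ks[2] is the parity constant XORed with both key words.
  std::uint32_t ks[3] = {key0, key1, 0x1BD11BDA ^ key0 ^ key1};
  std::uint32_t x[2] = {data0 + ks[0], data1 + ks[1]};
  auto rotate_left = [](std::uint32_t v, std::uint32_t d) {
    return (v << d) | (v >> (32 - d));
  };
  for (int group = 0; group < 5; ++group) {
    // Odd groups use the second half of the rotation table.
    for (int i = 0; i < 4; ++i) {
      std::uint32_t r = rotations[(group % 2) * 4 + i];
      x[0] += x[1];
      x[1] = rotate_left(x[1], r);
      x[1] ^= x[0];
    }
    x[0] += ks[(group + 1) % 3];
    x[1] += ks[(group + 2) % 3] + static_cast<std::uint32_t>(group + 1);
  }
  return {x[0], x[1]};
}

int main() {
  auto [out0, out1] = ThreeFry2x32Reference(0, 0, 0, 0);
  // Known-answer vector: 6b200159 99ba4efe.
  std::printf("%08x %08x\n", static_cast<unsigned>(out0),
              static_cast<unsigned>(out1));
}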

View File

@ -1,57 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIPBLAS_KERNELS_H_
#define JAXLIB_HIPBLAS_KERNELS_H_
#include <cstddef>
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
// Set of types known to Hipblas.
enum class HipblasType {
F32,
F64,
C64,
C128,
};
// Batched LU decomposition: getrfbatched
struct GetrfBatchedDescriptor {
HipblasType type;
int batch, n;
};
void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Batched QR decomposition: geqrfbatched
struct GeqrfBatchedDescriptor {
HipblasType type;
int batch, m, n;
};
void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIPBLAS_KERNELS_H_

View File

@ -1,435 +0,0 @@
/* Copyright 2019 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/rocm/hipsolver_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipsolver.h"
namespace jax {
namespace {
namespace py = pybind11;
// Converts a NumPy dtype to a Type.
HipsolverType DtypeToHipsolverType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, HipsolverType>({
{{'f', 4}, HipsolverType::F32},
{{'f', 8}, HipsolverType::F64},
{{'c', 8}, HipsolverType::C64},
{{'c', 16}, HipsolverType::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
}
return it->second;
}
// potrf: Cholesky decomposition
// Returns the workspace size and a descriptor for a potrf operation.
std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
std::int64_t workspace_size;
hipsolverFillMode_t uplo =
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
if (b == 1) {
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(float);
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(double);
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(hipComplex);
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(hipDoubleComplex);
break;
}
} else {
// TODO(rocm): remove this once CUDA and HIP share the same API for batched
// potrf. The batched potrf API differs from CUDA's: in HIP we must size the
// workspace plus additional space to copy the batch array pointers.
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(float)) + (b * sizeof(float*));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(double)) + (b * sizeof(double*));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(hipComplex)) + (b * sizeof(hipComplex*));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(hipDoubleComplex)) + (b * sizeof(hipDoubleComplex*));
break;
}
}
return {workspace_size,
PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})};
}
// getrf: LU decomposition
// Returns the workspace size and a descriptor for a getrf operation.
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
}
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})};
}
// geqrf: QR decomposition
// Returns the workspace size and a descriptor for a geqrf operation.
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
}
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})};
}
// orgqr/ungqr: apply elementary Householder transformations
// Returns the workspace size and a descriptor for a geqrf operation.
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
int m, int n, int k) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
}
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})};
}
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
// Returns the workspace size and a descriptor for a syevd operation.
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
hipsolverFillMode_t uplo =
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverSsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
/*lda=*/n, /*W=*/nullptr, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverDsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
/*lda=*/n, /*W=*/nullptr, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverCheevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
/*lda=*/n, /*W=*/nullptr, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork)));
break;
}
return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
}
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.
// Returns the workspace size and a descriptor for a syevj_batched operation.
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
bool lower, int batch, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
hipsolverSyevjInfo_t params;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
hipsolverFillMode_t uplo =
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
if (batch == 1) {
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
}
} else {
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
}
}
return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})};
}
// Singular value decomposition using QR algorithm: gesvd
// Returns the workspace size and a descriptor for a gesvd operation.
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
int m, int n, bool compute_uv,
bool full_matrices) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
signed char jobu, jobvt;
if (compute_uv) {
if (full_matrices) {
jobu = jobvt = 'A';
} else {
jobu = jobvt = 'S';
}
} else {
jobu = jobvt = 'N';
}
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverSgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverDgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverCgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverZgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
}
return {lwork,
PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
}
py::dict Registrations() {
py::dict dict;
dict["hipsolver_potrf"] = EncapsulateFunction(Potrf);
dict["hipsolver_getrf"] = EncapsulateFunction(Getrf);
dict["hipsolver_geqrf"] = EncapsulateFunction(Geqrf);
dict["hipsolver_orgqr"] = EncapsulateFunction(Orgqr);
dict["hipsolver_syevd"] = EncapsulateFunction(Syevd);
dict["hipsolver_syevj"] = EncapsulateFunction(Syevj);
dict["hipsolver_gesvd"] = EncapsulateFunction(Gesvd);
// dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); not supported by
// ROCm yet
return dict;
}
PYBIND11_MODULE(_hipsolver, m) {
m.def("registrations", &Registrations);
m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
// m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); not supported by
// ROCm yet
}
} // namespace
} // namespace jax
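
The batched-potrf workspace sized in BuildPotrfDescriptor is one allocation that Potrf_ later splits in two: the first `b` entries hold the per-matrix device pointers that MakeBatchPointers fills in, and the `lwork` elements of solver scratch follow (`static_cast<float**>(workspace) + d.batch`). In numbers, with a hypothetical lwork:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical values: suppose hipsolverSpotrfBatched_bufferSize reported
  // lwork = 512 floats and we factor a batch of b = 8 matrices.
  int lwork = 512, b = 8;
  std::size_t pointers = b * sizeof(float*);    // batch pointer array (front)
  std::size_t scratch = lwork * sizeof(float);  // solver scratch (after it)
  std::printf("workspace = %zu + %zu = %zu bytes\n", pointers, scratch,
              pointers + scratch);
}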

View File

@ -1,721 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm/hipsolver_kernels.h"
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipsolver.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
template <>
/*static*/ absl::StatusOr<SolverHandlePool::Handle>
SolverHandlePool::Borrow(hipStream_t stream) {
SolverHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
hipsolverHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
static int SizeOfHipsolverType(HipsolverType type) {
switch (type) {
case HipsolverType::F32:
return sizeof(float);
case HipsolverType::F64:
return sizeof(double);
case HipsolverType::C64:
return sizeof(hipFloatComplex);
case HipsolverType::C128:
return sizeof(hipDoubleComplex);
}
}
// potrf: Cholesky decomposition
static absl::Status Potrf_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const PotrfDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[2]);
void* workspace = buffers[3];
if (d.batch == 1) {
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSpotrf(handle.get(), d.uplo, d.n, a, d.n,
static_cast<float*>(workspace), d.lwork, info)));
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDpotrf(handle.get(), d.uplo, d.n, a, d.n,
static_cast<double*>(workspace), d.lwork, info)));
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrf(
handle.get(), d.uplo, d.n, a, d.n,
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrf(
handle.get(), d.uplo, d.n, a, d.n,
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
break;
}
}
} else {
auto buffer_ptrs_host =
MakeBatchPointers(stream, buffers[1], workspace, d.batch,
SizeOfHipsolverType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(buffer_ptrs_host.status());
// Make sure that accesses to buffer_ptrs_host complete before we delete it.
// TODO(phawkins): avoid synchronization here.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
switch (d.type) {
case HipsolverType::F32: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<float**>(workspace), d.n,
reinterpret_cast<float*>(static_cast<float**>(workspace) + d.batch),
d.lwork, info, d.batch)));
break;
}
case HipsolverType::F64: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<double**>(workspace), d.n,
reinterpret_cast<double*>(static_cast<double**>(workspace) + d.batch),
d.lwork, info, d.batch)));
break;
}
case HipsolverType::C64: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<hipFloatComplex**>(workspace), d.n,
reinterpret_cast<hipFloatComplex*>(static_cast<hipFloatComplex**>(workspace) +
d.batch), d.lwork, info, d.batch)));
break;
}
case HipsolverType::C128: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<hipDoubleComplex**>(workspace), d.n,
reinterpret_cast<hipDoubleComplex*>(static_cast<hipDoubleComplex**>(workspace) +
d.batch), d.lwork, info, d.batch)));
break;
}
}
}
return absl::OkStatus();
}
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Potrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// getrf: LU decomposition
static absl::Status Getrf_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GetrfDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
void* workspace = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<float*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<double*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverCgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<hipFloatComplex*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgetrf(
handle.get(), d.m, d.n, a, d.m,
static_cast<hipDoubleComplex*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
}
return absl::OkStatus();
}
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Getrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// geqrf: QR decomposition
static absl::Status Geqrf_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GeqrfDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
void* workspace = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* tau = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSgeqrf(handle.get(), d.m, d.n, a, d.m, tau,
static_cast<float*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* tau = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDgeqrf(handle.get(), d.m, d.n, a, d.m, tau,
static_cast<double*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgeqrf(
handle.get(), d.m, d.n, a, d.m, tau,
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgeqrf(
handle.get(), d.m, d.n, a, d.m, tau,
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
}
return absl::OkStatus();
}
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// orgqr/ungqr: apply elementary Householder transformations
static absl::Status Orgqr_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const OrgqrDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[2], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
void* workspace = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[2]);
float* tau = static_cast<float*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<float*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[2]);
double* tau = static_cast<double*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<double*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[2]);
hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCungqr(
handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[2]);
hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZungqr(
handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
}
return absl::OkStatus();
}
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
static absl::Status Syevd_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SyevdDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SyevdDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
int* info = static_cast<int*>(buffers[3]);
void* work = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<float*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<double*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipFloatComplex*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipDoubleComplex*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
}
return absl::OkStatus();
}
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Syevd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.
absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<SyevjDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SyevjDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
hipsolverSyevjInfo_t params;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
int* info = static_cast<int*>(buffers[3]);
void* work = buffers[4];
if (d.batch == 1) {
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<float*>(work), d.lwork, info, params)));
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<double*>(work), d.lwork, info, params)));
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipFloatComplex*>(work), d.lwork, info, params)));
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipDoubleComplex*>(work), d.lwork, info, params)));
break;
}
}
} else {
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<float*>(work), d.lwork, info, params, d.batch)));
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<double*>(work), d.lwork, info, params, d.batch)));
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipFloatComplex*>(work), d.lwork, info, params, d.batch)));
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipDoubleComplex*>(work),
d.lwork, info, params, d.batch)));
break;
}
}
}
return absl::OkStatus();
}
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Syevj_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Singular value decomposition using QR algorithm: gesvd
static absl::Status Gesvd_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GesvdDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
int* info = static_cast<int*>(buffers[5]);
void* work = buffers[6];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
float* u = static_cast<float*>(buffers[3]);
float* vt = static_cast<float*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
u, d.m, vt, d.n, static_cast<float*>(work), d.lwork,
/*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
double* u = static_cast<double*>(buffers[3]);
double* vt = static_cast<double*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDgesvd(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<double*>(work), d.lwork,
/*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
hipFloatComplex* u = static_cast<hipFloatComplex*>(buffers[3]);
hipFloatComplex* vt = static_cast<hipFloatComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgesvd(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<hipFloatComplex*>(work), d.lwork, /*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
hipDoubleComplex* u = static_cast<hipDoubleComplex*>(buffers[3]);
hipDoubleComplex* vt = static_cast<hipDoubleComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgesvd(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<hipDoubleComplex*>(work), d.lwork,
/*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
}
return absl::OkStatus();
}
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// TODO(rocm): add the Gesvdj_ APIs once hipsolver support is ready
} // namespace jax
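
The Borrow specialization above is one instantiation of jaxlib's generic HandlePool: a mutex-guarded map from stream to a free list of handles, so expensive hipsolverCreate calls are amortized across custom calls on the same stream. Stripped to its essentials (a simplified standalone sketch, not the real //jaxlib:handle_pool):

#include <map>
#include <mutex>
#include <vector>

template <typename Handle, typename Stream>
class HandlePoolSketch {
 public:
  // Returns a cached handle for `stream`, creating one if the free list is
  // empty. `create` stands in for hipsolverCreate and friends.
  template <typename Create>
  Handle Borrow(Stream stream, Create create) {
    std::lock_guard<std::mutex> lock(mu_);
    std::vector<Handle>& free_list = handles_[stream];
    if (free_list.empty()) return create();
    Handle h = free_list.back();
    free_list.pop_back();
    return h;
  }

  // Invoked by the RAII Handle wrapper (elided here) on destruction.
  void Return(Stream stream, Handle h) {
    std::lock_guard<std::mutex> lock(mu_);
    handles_[stream].push_back(h);
  }

 private:
  std::mutex mu_;
  std::map<Stream, std::vector<Handle>> handles_;
};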

View File

@ -1,122 +0,0 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIPSOLVER_KERNELS_H_
#define JAXLIB_HIPSOLVER_KERNELS_H_
#include "absl/status/statusor.h"
#include "jaxlib/handle_pool.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "rocm/include/hipsolver.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
using SolverHandlePool = HandlePool<hipsolverHandle_t, hipStream_t>;
template <>
absl::StatusOr<SolverHandlePool::Handle>
SolverHandlePool::Borrow(hipStream_t stream);
// Set of types known to Hipsolver.
enum class HipsolverType {
F32,
F64,
C64,
C128,
};
// potrf: Cholesky decomposition
struct PotrfDescriptor {
HipsolverType type;
hipsolverFillMode_t uplo;
std::int64_t batch, n;
int lwork;
};
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// getrf: LU decomposition
struct GetrfDescriptor {
HipsolverType type;
int batch, m, n, lwork;
};
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// geqrf: QR decomposition
struct GeqrfDescriptor {
HipsolverType type;
int batch, m, n, lwork;
};
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// orgqr/ungqr: apply elementary Householder transformations
struct OrgqrDescriptor {
HipsolverType type;
int batch, m, n, k, lwork;
};
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
struct SyevdDescriptor {
HipsolverType type;
hipsolverFillMode_t uplo;
int batch, n;
int lwork;
};
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.
struct SyevjDescriptor {
HipsolverType type;
hipsolverFillMode_t uplo;
int batch, n;
int lwork;
};
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Singular value decomposition using QR algorithm: gesvd
struct GesvdDescriptor {
HipsolverType type;
int batch, m, n;
int lwork;
signed char jobu, jobvt;
};
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIPSOLVER_KERNELS_H_
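
These descriptors drive the per-batch pointer strides in the kernels above; for gesvd, `a` advances by m*n elements per matrix, `s` by min(m, n), `u` by m*m, and `vt` by n*n. A quick check with hypothetical shapes:

#include <algorithm>
#include <cstdio>

int main() {
  int m = 6, n = 4;  // hypothetical per-matrix shape
  std::printf("a += %d, s += %d, u += %d, vt += %d elements per batch entry\n",
              m * n, std::min(m, n), m * m, n * n);
}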