mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything. PiperOrigin-RevId: 483666051
This commit is contained in:
parent
621f06660d
commit
a852710a09
@ -184,25 +184,25 @@ def prepare_wheel(sources_path):
|
||||
copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir)
|
||||
|
||||
cuda_dir = os.path.join(jaxlib_dir, "cuda")
|
||||
if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
|
||||
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
|
||||
libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
|
||||
os.makedirs(libdevice_dir)
|
||||
copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc", dst_dir=libdevice_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cusolver.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cublas.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cuda_linalg.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_cuda_prng.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir)
|
||||
copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir)
|
||||
rocm_dir = os.path.join(jaxlib_dir, "rocm")
|
||||
if exists(f"__main__/jaxlib/rocm/_hipsolver.{pyext}"):
|
||||
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
|
||||
os.makedirs(rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hipsolver.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hipblas.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hip_linalg.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_hip_prng.{pyext}", dst_dir=rocm_dir)
|
||||
if exists(f"__main__/jaxlib/cuda/_cusparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/cuda/_cusparse.{pyext}", dst_dir=cuda_dir)
|
||||
if exists(f"__main__/jaxlib/rocm/_hipsparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/rocm/_hipsparse.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir)
|
||||
copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir)
|
||||
if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir)
|
||||
if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
|
||||
copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir)
|
||||
|
||||
|
||||
mlir_dir = os.path.join(jaxlib_dir, "mlir")
|
||||
|
@ -24,15 +24,31 @@ licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//:__subpackages__"])
|
||||
|
||||
cc_library(
|
||||
name = "cuda_vendor",
|
||||
hdrs = [
|
||||
"//jaxlib/gpu:vendor.h",
|
||||
],
|
||||
defines = ["JAX_GPU_CUDA=1"],
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cuda_gpu_kernel_helpers",
|
||||
srcs = ["cuda_gpu_kernel_helpers.cc"],
|
||||
hdrs = ["cuda_gpu_kernel_helpers.h"],
|
||||
srcs = [
|
||||
"//jaxlib/gpu:gpu_kernel_helpers.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//jaxlib/gpu:gpu_kernel_helpers.h",
|
||||
],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":cuda_vendor",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -47,10 +63,11 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "cublas_kernels",
|
||||
srcs = ["cublas_kernels.cc"],
|
||||
hdrs = ["cublas_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
@ -71,31 +88,32 @@ cc_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_cublas",
|
||||
srcs = ["cublas.cc"],
|
||||
name = "_blas",
|
||||
srcs = ["//jaxlib/gpu:blas.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_cublas",
|
||||
module_name = "_blas",
|
||||
deps = [
|
||||
":cublas_kernels",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cusolver_kernels",
|
||||
srcs = ["cusolver_kernels.cc"],
|
||||
hdrs = ["cusolver_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
@ -108,16 +126,17 @@ cc_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_cusolver",
|
||||
srcs = ["cusolver.cc"],
|
||||
name = "_solver",
|
||||
srcs = ["//jaxlib/gpu:solver.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_cusolver",
|
||||
module_name = "_solver",
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
":cusolver_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
@ -131,10 +150,11 @@ pybind_extension(
|
||||
|
||||
cc_library(
|
||||
name = "cusparse_kernels",
|
||||
srcs = ["cusparse_kernels.cc"],
|
||||
hdrs = ["cusparse_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:sparse_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:sparse_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
@ -148,16 +168,17 @@ cc_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_cusparse",
|
||||
srcs = ["cusparse.cc"],
|
||||
name = "_sparse",
|
||||
srcs = ["//jaxlib/gpu:sparse.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_cusparse",
|
||||
module_name = "_sparse",
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
":cusparse_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
@ -179,12 +200,13 @@ pybind_extension(
|
||||
cc_library(
|
||||
name = "cuda_lu_pivot_kernels",
|
||||
srcs = [
|
||||
"cuda_lu_pivot_kernels.cc",
|
||||
"//jaxlib/gpu:lu_pivot_kernels.cc",
|
||||
],
|
||||
hdrs = ["cuda_lu_pivot_kernels.h"],
|
||||
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_lu_pivot_kernels_impl",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
@ -194,11 +216,12 @@ cc_library(
|
||||
cuda_library(
|
||||
name = "cuda_lu_pivot_kernels_impl",
|
||||
srcs = [
|
||||
"cuda_lu_pivot_kernels.cu.cc",
|
||||
"//jaxlib/gpu:lu_pivot_kernels.cu.cc",
|
||||
],
|
||||
hdrs = ["cuda_lu_pivot_kernels.h"],
|
||||
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
@ -206,18 +229,19 @@ cuda_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_cuda_linalg",
|
||||
srcs = ["cuda_linalg.cc"],
|
||||
name = "_linalg",
|
||||
srcs = ["//jaxlib/gpu:linalg.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_cuda_linalg",
|
||||
module_name = "_linalg",
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_lu_pivot_kernels",
|
||||
":cuda_lu_pivot_kernels_impl",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
@ -228,12 +252,13 @@ pybind_extension(
|
||||
cc_library(
|
||||
name = "cuda_prng_kernels",
|
||||
srcs = [
|
||||
"cuda_prng_kernels.cc",
|
||||
"//jaxlib/gpu:prng_kernels.cc",
|
||||
],
|
||||
hdrs = ["cuda_prng_kernels.h"],
|
||||
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_prng_kernels_impl",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
@ -243,11 +268,12 @@ cc_library(
|
||||
cuda_library(
|
||||
name = "cuda_prng_kernels_impl",
|
||||
srcs = [
|
||||
"cuda_prng_kernels.cu.cc",
|
||||
"//jaxlib/gpu:prng_kernels.cu.cc",
|
||||
],
|
||||
hdrs = ["cuda_prng_kernels.h"],
|
||||
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
@ -255,14 +281,14 @@ cuda_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_cuda_prng",
|
||||
srcs = ["cuda_prng.cc"],
|
||||
name = "_prng",
|
||||
srcs = ["//jaxlib/gpu:prng.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_cuda_prng",
|
||||
module_name = "_prng",
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_prng_kernels",
|
||||
@ -275,12 +301,13 @@ pybind_extension(
|
||||
|
||||
cc_library(
|
||||
name = "cuda_gpu_kernels",
|
||||
srcs = ["cuda_gpu_kernels.cc"],
|
||||
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cublas_kernels",
|
||||
":cuda_lu_pivot_kernels",
|
||||
":cuda_prng_kernels",
|
||||
":cuda_vendor",
|
||||
":cusolver_kernels",
|
||||
":cusparse_kernels",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
@ -291,10 +318,10 @@ cc_library(
|
||||
py_library(
|
||||
name = "cuda_gpu_support",
|
||||
deps = [
|
||||
":_cublas",
|
||||
":_cuda_linalg",
|
||||
":_cuda_prng",
|
||||
":_cusolver",
|
||||
":_cusparse",
|
||||
":_blas",
|
||||
":_linalg",
|
||||
":_prng",
|
||||
":_solver",
|
||||
":_sparse",
|
||||
],
|
||||
)
|
||||
|
@ -1,84 +0,0 @@
|
||||
/* Copyright 2019 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "jaxlib/cuda/cublas_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// Converts a NumPy dtype to a Type.
|
||||
CublasType DtypeToCublasType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, CublasType>({
|
||||
{{'f', 4}, CublasType::F32},
|
||||
{{'f', 8}, CublasType::F64},
|
||||
{{'c', 8}, CublasType::C64},
|
||||
{{'c', 16}, CublasType::C128},
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
throw std::invalid_argument(
|
||||
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Returns the descriptor for a GetrfBatched operation.
|
||||
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
|
||||
int b, int n) {
|
||||
CublasType type = DtypeToCublasType(dtype);
|
||||
size_t size = b * sizeof(void*);
|
||||
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
|
||||
}
|
||||
|
||||
// Returns the descriptor for a GetrfBatched operation.
|
||||
std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
|
||||
int b, int m, int n) {
|
||||
CublasType type = DtypeToCublasType(dtype);
|
||||
size_t size = b * sizeof(void*);
|
||||
return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})};
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
|
||||
dict["cublas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_cublas, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
|
||||
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
@ -1,225 +0,0 @@
|
||||
/* Copyright 2019 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cublas_kernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
using BlasHandlePool = HandlePool<cublasHandle_t, cudaStream_t>;
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
|
||||
cudaStream_t stream) {
|
||||
BlasHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
cublasHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts a NumPy dtype to a CublasType.
|
||||
|
||||
int SizeOfCublasType(CublasType type) {
|
||||
switch (type) {
|
||||
case CublasType::F32:
|
||||
return sizeof(float);
|
||||
case CublasType::F64:
|
||||
return sizeof(double);
|
||||
case CublasType::C64:
|
||||
return sizeof(cuComplex);
|
||||
case CublasType::C128:
|
||||
return sizeof(cuDoubleComplex);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GetrfBatchedDescriptor& d = **s;
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.n * d.n,
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
|
||||
SizeOfCublasType(d.type) * d.n * d.n);
|
||||
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
|
||||
switch (d.type) {
|
||||
case CublasType::F32: {
|
||||
float** batch_ptrs = static_cast<float**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::F64: {
|
||||
double** batch_ptrs = static_cast<double**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::C64: {
|
||||
cuComplex** batch_ptrs = static_cast<cuComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::C128: {
|
||||
cuDoubleComplex** batch_ptrs = static_cast<cuDoubleComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// Batched QR decomposition: geqrfbatched
|
||||
|
||||
static absl::Status GeqrfBatched_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GeqrfBatchedDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GeqrfBatchedDescriptor& d = **s;
|
||||
auto h = BlasHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
|
||||
cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
std::vector<int> info(d.batch);
|
||||
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
|
||||
SizeOfCublasType(d.type) * d.m * d.n);
|
||||
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
|
||||
auto tau_ptrs_host =
|
||||
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
|
||||
SizeOfCublasType(d.type) * std::min(d.m, d.n));
|
||||
JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
|
||||
switch (d.type) {
|
||||
case CublasType::F32: {
|
||||
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
|
||||
float** tau_batch_ptrs = static_cast<float**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::F64: {
|
||||
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
|
||||
double** tau_batch_ptrs = static_cast<double**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::C64: {
|
||||
cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
|
||||
cuComplex** tau_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
case CublasType::C128: {
|
||||
cuDoubleComplex** a_batch_ptrs =
|
||||
static_cast<cuDoubleComplex**>(buffers[3]);
|
||||
cuDoubleComplex** tau_batch_ptrs =
|
||||
static_cast<cuDoubleComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto it =
|
||||
std::find_if(info.begin(), info.end(), [](int i) { return i != 0; });
|
||||
|
||||
if (it != info.end()) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("QR decomposition failed with status %d for batch "
|
||||
"element %d",
|
||||
*it, std::distance(info.begin(), it)));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -1,140 +0,0 @@
|
||||
/* Copyright 2019 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
std::string ErrorString(cudaError_t error) { return cudaGetErrorString(error); }
|
||||
|
||||
std::string ErrorString(cusparseStatus_t status) {
|
||||
return cusparseGetErrorString(status);
|
||||
}
|
||||
|
||||
std::string ErrorString(cusolverStatus_t status) {
|
||||
switch (status) {
|
||||
case CUSOLVER_STATUS_SUCCESS:
|
||||
return "cuSolver success.";
|
||||
case CUSOLVER_STATUS_NOT_INITIALIZED:
|
||||
return "cuSolver has not been initialized";
|
||||
case CUSOLVER_STATUS_ALLOC_FAILED:
|
||||
return "cuSolver allocation failed";
|
||||
case CUSOLVER_STATUS_INVALID_VALUE:
|
||||
return "cuSolver invalid value error";
|
||||
case CUSOLVER_STATUS_ARCH_MISMATCH:
|
||||
return "cuSolver architecture mismatch error";
|
||||
case CUSOLVER_STATUS_MAPPING_ERROR:
|
||||
return "cuSolver mapping error";
|
||||
case CUSOLVER_STATUS_EXECUTION_FAILED:
|
||||
return "cuSolver execution failed";
|
||||
case CUSOLVER_STATUS_INTERNAL_ERROR:
|
||||
return "cuSolver internal error";
|
||||
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
|
||||
return "cuSolver matrix type not supported error";
|
||||
case CUSOLVER_STATUS_NOT_SUPPORTED:
|
||||
return "cuSolver not supported error";
|
||||
case CUSOLVER_STATUS_ZERO_PIVOT:
|
||||
return "cuSolver zero pivot error";
|
||||
case CUSOLVER_STATUS_INVALID_LICENSE:
|
||||
return "cuSolver invalid license error";
|
||||
default:
|
||||
return absl::StrCat("Unknown cuSolver error: ", status);
|
||||
}
|
||||
}
|
||||
|
||||
std::string ErrorString(cublasStatus_t status) {
|
||||
switch (status) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return "cuBlas success";
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED:
|
||||
return "cuBlas has not been initialized";
|
||||
case CUBLAS_STATUS_ALLOC_FAILED:
|
||||
return "cuBlas allocation failure";
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return "cuBlas invalid value error";
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH:
|
||||
return "cuBlas architecture mismatch";
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "cuBlas mapping error";
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "cuBlas execution failed";
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "cuBlas internal error";
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "cuBlas not supported error";
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "cuBlas license error";
|
||||
default:
|
||||
return "Unknown cuBlas error";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string ErrorString(T status, const char* file, std::int64_t line,
|
||||
const char* expr) {
|
||||
return absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr,
|
||||
ErrorString(status));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
|
||||
const char* expr) {
|
||||
if (error != cudaSuccess)
|
||||
return absl::InternalError(ErrorString(error, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AsStatus(cusolverStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr) {
|
||||
if (status != CUSOLVER_STATUS_SUCCESS)
|
||||
return absl::InternalError(ErrorString(status, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AsStatus(cusparseStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr) {
|
||||
if (status != CUSPARSE_STATUS_SUCCESS)
|
||||
return absl::InternalError(ErrorString(status, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AsStatus(cublasStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr) {
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
return absl::InternalError(ErrorString(status, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<void* []>> MakeBatchPointers(
|
||||
cudaStream_t stream, void* buffer, void* dev_ptrs, int batch,
|
||||
int batch_elem_size) {
|
||||
char* ptr = static_cast<char*>(buffer);
|
||||
auto host_ptrs = absl::make_unique<void*[]>(batch);
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
host_ptrs[i] = ptr;
|
||||
ptr += batch_elem_size;
|
||||
}
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cudaMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
|
||||
cudaMemcpyHostToDevice, stream)));
|
||||
return std::move(host_ptrs);
|
||||
}
|
||||
} // namespace jax
|
@ -1,51 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
absl::Status CudaLuPivotsToPermutation_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque,
|
||||
std::size_t opaque_len) {
|
||||
auto s =
|
||||
UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
LaunchLuPivotsToPermutationKernel(stream, buffers, **s);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len,
|
||||
XlaCustomCallStatus* status) {
|
||||
auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
std::string_view message = s.message();
|
||||
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace jax
|
@ -1,77 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
|
||||
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
__device__ void ComputePermutation(const std::int32_t* pivots,
|
||||
std::int32_t* permutation_out,
|
||||
const std::int32_t pivot_size,
|
||||
const std::int32_t permutation_size) {
|
||||
for (int i = 0; i < permutation_size; ++i) {
|
||||
permutation_out[i] = i;
|
||||
}
|
||||
|
||||
// Compute the permutation from a sequence of transpositions encoded in the
|
||||
// pivot array by applying the transpositions in order on the identity
|
||||
// permutation.
|
||||
for (int i = 0; i < pivot_size; ++i) {
|
||||
if ((pivots[i] < 0) || (pivots[i] >= permutation_size)) {
|
||||
continue;
|
||||
}
|
||||
std::int32_t swap_temporary = permutation_out[i];
|
||||
permutation_out[i] = permutation_out[pivots[i]];
|
||||
permutation_out[pivots[i]] = swap_temporary;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void LuPivotsToPermutationKernel(
|
||||
const std::int32_t* pivots, std::int32_t* permutation_out,
|
||||
const std::int64_t batch_size, const std::int32_t pivot_size,
|
||||
const std::int32_t permutation_size) {
|
||||
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
idx < batch_size; idx += blockDim.x * gridDim.x) {
|
||||
// Fill in the output array with the identity permutation.
|
||||
ComputePermutation(pivots + idx * pivot_size,
|
||||
permutation_out + idx * permutation_size, pivot_size,
|
||||
permutation_size);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LaunchLuPivotsToPermutationKernel(
|
||||
cudaStream_t stream, void** buffers,
|
||||
LuPivotsToPermutationDescriptor descriptor) {
|
||||
const std::int32_t* pivots =
|
||||
reinterpret_cast<const std::int32_t*>(buffers[0]);
|
||||
std::int32_t* permutation_out = reinterpret_cast<std::int32_t*>(buffers[1]);
|
||||
|
||||
const int block_dim = 128;
|
||||
const std::int64_t grid_dim = std::min<std::int64_t>(
|
||||
1024, (descriptor.batch_size + block_dim - 1) / block_dim);
|
||||
|
||||
LuPivotsToPermutationKernel<<<grid_dim, block_dim,
|
||||
/*dynamic_shared_mem_bytes=*/0, stream>>>(
|
||||
pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size,
|
||||
descriptor.permutation_size);
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -1,43 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
|
||||
#define JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
struct LuPivotsToPermutationDescriptor {
|
||||
std::int64_t batch_size;
|
||||
std::int32_t pivot_size;
|
||||
std::int32_t permutation_size;
|
||||
};
|
||||
|
||||
void LaunchLuPivotsToPermutationKernel(
|
||||
cudaStream_t stream, void** buffers,
|
||||
LuPivotsToPermutationDescriptor descriptor);
|
||||
|
||||
void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len,
|
||||
XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
|
@ -1,618 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/cuda/cusparse_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
cusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, cusparseIndexType_t>({
|
||||
{{'u', 2}, CUSPARSE_INDEX_16U},
|
||||
{{'i', 4}, CUSPARSE_INDEX_32I},
|
||||
{{'i', 8}, CUSPARSE_INDEX_64I},
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
throw std::invalid_argument(
|
||||
absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type)));
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, cudaDataType>({
|
||||
{{'f', 2}, CUDA_R_16F}, {{'c', 4}, CUDA_C_16F}, {{'f', 4}, CUDA_R_32F},
|
||||
{{'c', 8}, CUDA_C_32F}, {{'f', 8}, CUDA_R_64F},
|
||||
{{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
|
||||
{{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
|
||||
{{'u', 4}, CUDA_R_32U},
|
||||
#if JAX_CUSPARSE_11300
|
||||
{{'V', 2}, CUDA_R_16BF},
|
||||
#endif
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
throw std::invalid_argument(
|
||||
absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type)));
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
// Returns the descriptor for a Sparse matrix.
|
||||
SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
|
||||
const py::dtype& index_dtype,
|
||||
int rows, int cols, int nnz,
|
||||
int batch_count,
|
||||
int batch_stride) {
|
||||
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
|
||||
cusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
|
||||
return SparseMatDescriptor{
|
||||
value_type, index_type, rows, cols, nnz, batch_count, batch_stride};
|
||||
}
|
||||
|
||||
// Returns the descriptor for a Dense matrix.
|
||||
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
|
||||
int rows, int cols, int batch_count,
|
||||
int batch_stride) {
|
||||
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
|
||||
return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
|
||||
}
|
||||
|
||||
// Returns the descriptor for a Dense vector.
|
||||
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
|
||||
int size) {
|
||||
cudaDataType value_type = DtypeToCudaDataType(data_dtype);
|
||||
return DenseVecDescriptor{value_type, size};
|
||||
}
|
||||
|
||||
#if JAX_CUSPARSE_11300
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
// Returns the descriptor for a Sparse matrix.
|
||||
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count*/1, /*batch_stride*/0);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
|
||||
// buffer_size does not reference these pointers, but does error on NULL.
|
||||
// TODO(jakevdp): check whether this is documented.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrFromDense: Convert dense matrix to CSR matrix
|
||||
|
||||
// Returns the descriptor for a CsrFromDense operation.
|
||||
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
absl::Status CsrFromDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrMatvec: Product of CSR matrix and dense vector.
|
||||
|
||||
// Returns the descriptor for a CsrMatvec operation.
|
||||
std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz, bool transpose) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
DenseVecDescriptor x =
|
||||
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
|
||||
DenseVecDescriptor y =
|
||||
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(y.type);
|
||||
CudaConst beta = CudaZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
|
||||
}
|
||||
|
||||
// CsrMatmat: Product of CSR matrix and dense matrix.
|
||||
|
||||
// Returns the descriptor for a CsrMatmat operation.
|
||||
std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int BCcols, int nnz, bool transpose) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
DenseMatDescriptor B =
|
||||
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
DenseMatDescriptor C =
|
||||
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, CUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(C.type);
|
||||
CudaConst beta = CudaZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
|
||||
// CooToDense: Convert COO matrix to dense matrix
|
||||
|
||||
// Returns the descriptor for a CooToDense operation.
|
||||
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
// CooFromDense: Convert dense matrix to COO matrix
|
||||
|
||||
// Returns the descriptor for a CooFromDense operation.
|
||||
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
|
||||
d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
// CooMatvec: Product of COO matrix and dense vector.
|
||||
|
||||
// Returns the descriptor for a CooMatvec operation.
|
||||
std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz, bool transpose) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
DenseVecDescriptor x =
|
||||
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
|
||||
DenseVecDescriptor y =
|
||||
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
|
||||
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(y.type);
|
||||
CudaConst beta = CudaZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
|
||||
}
|
||||
|
||||
// CooMatmat: Product of COO matrix and dense matrix.
|
||||
|
||||
// Returns the descriptor for a CooMatmat operation.
|
||||
std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int BCcols, int nnz, bool transpose, int batch_count,
|
||||
int lhs_batch_stride, int rhs_batch_stride) {
|
||||
// Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
|
||||
// Ci = A_i B_i, where `i` denotes the batch dimension.
|
||||
// All three matrices A, B, and C must have the same batch count.
|
||||
// Use batch stride to trigger individual mode, e.g.,
|
||||
// `rhs_batch_stride = 0` for C_i = A_i B.
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
batch_count, lhs_batch_stride);
|
||||
DenseMatDescriptor B =
|
||||
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
|
||||
batch_count, rhs_batch_stride);
|
||||
int C_rows = (transpose == true) ? cols : rows;
|
||||
// TODO(tianjianlu): enable the selection of batch stride.
|
||||
// The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
|
||||
// in cusparse library does not allow batch_stride = 0.
|
||||
// int C_batch_stride = (batch_count > 1)? C_rows * BCcols : 0;
|
||||
int C_batch_stride = C_rows * BCcols;
|
||||
DenseMatDescriptor C =
|
||||
BuildDenseMatDescriptor(compute_dtype, /*rows=*/C_rows, /*cols=*/BCcols,
|
||||
batch_count, C_batch_stride);
|
||||
cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
|
||||
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCooSetStridedBatch(
|
||||
mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseDnMatSetStridedBatch(
|
||||
mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseDnMatSetStridedBatch(
|
||||
mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
|
||||
size_t buffer_size;
|
||||
CudaConst alpha = CudaOne(C.type);
|
||||
CudaConst beta = CudaZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
|
||||
#endif // if JAX_CUSPARSE_11300
|
||||
|
||||
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
|
||||
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
size_t size;
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
|
||||
/*du=*/nullptr, /*B=*/nullptr, ldb, &size)));
|
||||
return size;
|
||||
}
|
||||
|
||||
size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
|
||||
return Gtsv2BufferSize(cusparseSgtsv2_bufferSizeExt, m, n, ldb);
|
||||
}
|
||||
|
||||
size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
|
||||
return Gtsv2BufferSize(cusparseDgtsv2_bufferSizeExt, m, n, ldb);
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
#if JAX_CUSPARSE_11300
|
||||
dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
|
||||
dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
|
||||
dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
|
||||
dict["cusparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
|
||||
dict["cusparse_coo_todense"] = EncapsulateFunction(CooToDense);
|
||||
dict["cusparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
|
||||
dict["cusparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
|
||||
dict["cusparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
|
||||
#endif
|
||||
dict["cusparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
|
||||
dict["cusparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
|
||||
// TODO(tomhennigan): Add support for gtsv2 complex 32/64.
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_cusparse, m) {
|
||||
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11300);
|
||||
m.def("registrations", &Registrations);
|
||||
#if JAX_CUSPARSE_11300
|
||||
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
|
||||
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
|
||||
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
|
||||
m.def("build_csr_matmat_descriptor", &BuildCsrMatmatDescriptor);
|
||||
m.def("build_coo_todense_descriptor", &BuildCooToDenseDescriptor);
|
||||
m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
|
||||
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
|
||||
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
|
||||
#endif
|
||||
m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
|
||||
m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
|
||||
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
@ -1,618 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cusparse_kernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#if JAX_CUDA_11080
|
||||
#include "third_party/gpus/cuda/include/cuda_fp8.h"
|
||||
#endif
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
// cuSPARSE generic APIs are not supported on Windows until 11.0
|
||||
// cusparseIndexType_t is used in very limited scope so manually define will
|
||||
// workaround compiling issue without harm.
|
||||
#if defined(_WIN32) && (CUSPARSE_VERSION < 11000)
|
||||
typedef enum {
|
||||
CUSPARSE_INDEX_16U = 1,
|
||||
CUSPARSE_INDEX_32I = 2,
|
||||
CUSPARSE_INDEX_64I = 3
|
||||
} cusparseIndexType_t;
|
||||
#endif
|
||||
|
||||
namespace jax {
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
|
||||
cudaStream_t stream) {
|
||||
SparseHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
cusparseHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
CudaConst CudaZero(cudaDataType type) {
|
||||
CudaConst c;
|
||||
std::memset(&c, 0, sizeof(c));
|
||||
return c;
|
||||
}
|
||||
|
||||
CudaConst CudaOne(cudaDataType type) {
|
||||
CudaConst c;
|
||||
std::memset(&c, 0, sizeof(c));
|
||||
switch (type) {
|
||||
#if JAX_CUSPARSE_11300
|
||||
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
|
||||
case CUDA_R_4I:
|
||||
case CUDA_C_4I:
|
||||
#endif
|
||||
case CUDA_R_8I:
|
||||
case CUDA_C_8I:
|
||||
c.i8[0] = 1;
|
||||
break;
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_4U:
|
||||
case CUDA_C_4U:
|
||||
#endif
|
||||
case CUDA_R_8U:
|
||||
case CUDA_C_8U:
|
||||
c.u8[0] = 1;
|
||||
break;
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_16I:
|
||||
case CUDA_C_16I:
|
||||
c.i16[0] = 1;
|
||||
break;
|
||||
case CUDA_R_16U:
|
||||
case CUDA_C_16U:
|
||||
c.u16[0] = 1;
|
||||
break;
|
||||
#endif
|
||||
case CUDA_R_32I:
|
||||
case CUDA_C_32I:
|
||||
c.i32[0] = 1;
|
||||
break;
|
||||
case CUDA_R_32U:
|
||||
case CUDA_C_32U:
|
||||
c.u32[0] = 1;
|
||||
break;
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_64I:
|
||||
case CUDA_C_64I:
|
||||
c.i64[0] = 1;
|
||||
break;
|
||||
case CUDA_R_64U:
|
||||
case CUDA_C_64U:
|
||||
c.u64[0] = 1;
|
||||
break;
|
||||
#endif
|
||||
#if JAX_CUDA_11080
|
||||
case CUDA_R_8F_E4M3:
|
||||
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E4M3);
|
||||
break;
|
||||
case CUDA_R_8F_E5M2:
|
||||
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E5M2);
|
||||
break;
|
||||
#endif
|
||||
// TODO(jakevdp): 16F/16BF here might break on big endian platforms.
|
||||
case CUDA_R_16F:
|
||||
case CUDA_C_16F:
|
||||
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
|
||||
break;
|
||||
#if JAX_CUSPARSE_11300
|
||||
case CUDA_R_16BF:
|
||||
case CUDA_C_16BF:
|
||||
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
|
||||
break;
|
||||
#endif
|
||||
case CUDA_R_32F:
|
||||
case CUDA_C_32F:
|
||||
c.f32[0] = 1.0;
|
||||
break;
|
||||
case CUDA_R_64F:
|
||||
case CUDA_C_64F:
|
||||
c.f64[0] = 1.0;
|
||||
break;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
#if JAX_CUSPARSE_11300
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
static absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrFromDense: Convert dense matrix to CSR matrix
|
||||
|
||||
static absl::Status CsrFromDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type, d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrMatvec: Product of CSR matrix and dense vector.
|
||||
|
||||
static absl::Status CsrMatvec_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const CsrMatvecDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
void* csr_values = buffers[0];
|
||||
void* csr_col_ind = buffers[1];
|
||||
void* csr_row_offsets = buffers[2];
|
||||
void* xbuf = buffers[3];
|
||||
void* ybuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
CudaConst alpha = CudaOne(d.y.type);
|
||||
CudaConst beta = CudaZero(d.y.type);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
|
||||
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CsrMatmat: Product of CSR matrix and dense matrix.
|
||||
|
||||
static absl::Status CsrMatmat_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const CsrMatmatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
void* csr_values = buffers[0];
|
||||
void* csr_col_ind = buffers[1];
|
||||
void* csr_row_offsets = buffers[2];
|
||||
void* Bbuf = buffers[3];
|
||||
void* Cbuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
CudaConst alpha = CudaOne(d.C.type);
|
||||
CudaConst beta = CudaZero(d.C.type);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets,
|
||||
csr_col_ind, csr_values, d.A.index_type, d.A.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CooToDense: Convert COO matrix to dense matrix
|
||||
|
||||
static absl::Status CooToDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*cooRowInd=*/buffers[1],
|
||||
/*cooColInd=*/buffers[2],
|
||||
/*cooValues=*/buffers[0], d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CooFromDense: Convert dense matrix to COO matrix
|
||||
|
||||
static absl::Status CooFromDense_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
cusparseDnMatDescr_t mat_a = 0;
|
||||
cusparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*cooRowInd=*/buffers[2],
|
||||
/*cooColInd=*/buffers[3],
|
||||
/*cooValues=*/buffers[1], d.index_type,
|
||||
CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CooMatvec: Product of COO matrix and dense vector.
|
||||
|
||||
static absl::Status CooMatvec_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const CooMatvecDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
void* coo_values = buffers[0];
|
||||
void* coo_row_ind = buffers[1];
|
||||
void* coo_col_ind = buffers[2];
|
||||
void* xbuf = buffers[3];
|
||||
void* ybuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
CudaConst alpha = CudaOne(d.y.type);
|
||||
CudaConst beta = CudaZero(d.y.type);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnVecDescr_t vec_x = 0;
|
||||
cusparseDnVecDescr_t vec_y = 0;
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
|
||||
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// CooMatmat: Product of COO matrix and dense matrix.
|
||||
|
||||
static absl::Status CooMatmat_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const CooMatmatDescriptor& d = **s;
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
void* coo_values = buffers[0];
|
||||
void* coo_row_ind = buffers[1];
|
||||
void* coo_col_ind = buffers[2];
|
||||
void* Bbuf = buffers[3];
|
||||
void* Cbuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
CudaConst alpha = CudaOne(d.C.type);
|
||||
CudaConst beta = CudaZero(d.C.type);
|
||||
|
||||
cusparseSpMatDescr_t mat_a = 0;
|
||||
cusparseDnMatDescr_t mat_b = 0;
|
||||
cusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
|
||||
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
|
||||
/*batchStride=*/d.A.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
|
||||
/*batchStride=*/d.B.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
|
||||
&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
|
||||
/*batchStride=*/d.C.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
#endif // if JAX_CUSPARSE_11300
|
||||
|
||||
template <typename T, typename F>
|
||||
static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
|
||||
const char* opaque, std::size_t opaque_len) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
auto s = UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const Gtsv2Descriptor& descriptor = **s;
|
||||
int m = descriptor.m;
|
||||
int n = descriptor.n;
|
||||
int ldb = descriptor.ldb;
|
||||
|
||||
const T* dl = (const T*)(buffers[0]);
|
||||
const T* d = (const T*)(buffers[1]);
|
||||
const T* du = (const T*)(buffers[2]);
|
||||
const T* B = (T*)(buffers[3]);
|
||||
T* X = (T*)(buffers[4]);
|
||||
void* buffer = buffers[5];
|
||||
|
||||
// The solution X is written in place to B. We need to therefore copy the
|
||||
// contents of B into the output buffer X and pass that into the kernel as B.
|
||||
// Once copy insertion is supported for custom call aliasing, we could alias B
|
||||
// with X and avoid the copy, the code below is written defensively assuming B
|
||||
// and X might alias, but today we know they will not.
|
||||
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
|
||||
if (X != B) {
|
||||
size_t B_bytes = ldb * n * sizeof(T);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -1,160 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_CUSPARSE_KERNELS_H_
|
||||
#define JAXLIB_CUSPARSE_KERNELS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
|
||||
#define JAX_CUSPARSE_11300 (CUSPARSE_VERSION >= 11300)
|
||||
// CUDA-11.8 introduces FP8 E4M3/E5M2 types.
|
||||
#define JAX_CUDA_11080 (CUDA_VERSION >= 11080)
|
||||
|
||||
namespace jax {
|
||||
|
||||
using SparseHandlePool = HandlePool<cusparseHandle_t, cudaStream_t>;
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
|
||||
cudaStream_t stream);
|
||||
|
||||
union CudaConst {
|
||||
int8_t i8[2];
|
||||
int16_t i16[2];
|
||||
int32_t i32[2];
|
||||
int64_t i64[2];
|
||||
uint8_t u8[2];
|
||||
uint16_t u16[2];
|
||||
uint32_t u32[2];
|
||||
uint64_t u64[2];
|
||||
float f32[2];
|
||||
double f64[2];
|
||||
};
|
||||
|
||||
CudaConst CudaZero(cudaDataType type);
|
||||
CudaConst CudaOne(cudaDataType type);
|
||||
|
||||
struct SparseMatDescriptor {
|
||||
cudaDataType value_type;
|
||||
cusparseIndexType_t index_type;
|
||||
int rows, cols, nnz;
|
||||
int batch_count = 1;
|
||||
int batch_stride = 0;
|
||||
};
|
||||
|
||||
struct DenseMatDescriptor {
|
||||
cudaDataType type;
|
||||
int rows, cols;
|
||||
int batch_count = 1;
|
||||
int batch_stride = 0;
|
||||
};
|
||||
|
||||
struct DenseVecDescriptor {
|
||||
cudaDataType type;
|
||||
int size;
|
||||
};
|
||||
|
||||
#if JAX_CUSPARSE_11300
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CsrFromDense: Convert dense matrix to CSR matrix
|
||||
|
||||
void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CsrMatvec: Product of CSR matrix and dense vector.
|
||||
|
||||
struct CsrMatvecDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseVecDescriptor x, y;
|
||||
cusparseOperation_t op;
|
||||
};
|
||||
|
||||
void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CsrMatmat: Product of CSR matrix and dense matrix.
|
||||
|
||||
struct CsrMatmatDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseMatDescriptor B, C;
|
||||
cusparseOperation_t op_A;
|
||||
};
|
||||
|
||||
void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooToDense: Convert COO matrix to dense matrix
|
||||
|
||||
void CooToDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooFromDense: Convert dense matrix to COO matrix
|
||||
|
||||
void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooMatvec: Product of COO matrix and dense vector.
|
||||
|
||||
struct CooMatvecDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseVecDescriptor x, y;
|
||||
cusparseOperation_t op;
|
||||
};
|
||||
|
||||
void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooMatmat: Product of COO matrix and dense matrix.
|
||||
|
||||
struct CooMatmatDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseMatDescriptor B, C;
|
||||
cusparseOperation_t op_A;
|
||||
};
|
||||
|
||||
void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
#endif // if JAX_CUSPARSE_11300
|
||||
|
||||
struct Gtsv2Descriptor {
|
||||
int m, n, ldb;
|
||||
};
|
||||
|
||||
void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CUSPARSE_KERNELS_H_
|
43
jaxlib/gpu/BUILD
Normal file
43
jaxlib/gpu/BUILD
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Shared CUDA/ROCM GPU kernels.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//:__subpackages__"])
|
||||
|
||||
exports_files(srcs = [
|
||||
"blas.cc",
|
||||
"blas_kernels.cc",
|
||||
"blas_kernels.h",
|
||||
"gpu_kernel_helpers.cc",
|
||||
"gpu_kernel_helpers.h",
|
||||
"gpu_kernels.cc",
|
||||
"linalg.cc",
|
||||
"lu_pivot_kernels.cc",
|
||||
"lu_pivot_kernels.cu.cc",
|
||||
"lu_pivot_kernels.h",
|
||||
"prng.cc",
|
||||
"prng_kernels.cc",
|
||||
"prng_kernels.cu.cc",
|
||||
"prng_kernels.h",
|
||||
"solver.cc",
|
||||
"solver_kernels.cc",
|
||||
"solver_kernels.h",
|
||||
"sparse.cc",
|
||||
"sparse_kernels.cc",
|
||||
"sparse_kernels.h",
|
||||
"vendor.h",
|
||||
])
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
/* Copyright 2019 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -20,28 +20,27 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "jaxlib/gpu/blas_kernels.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/rocm/hipblas_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipblas.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// Converts a NumPy dtype to a Type.
|
||||
HipblasType DtypeToHipblasType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, HipblasType>({
|
||||
{{'f', 4}, HipblasType::F32},
|
||||
{{'f', 8}, HipblasType::F64},
|
||||
{{'c', 8}, HipblasType::C64},
|
||||
{{'c', 16}, HipblasType::C128},
|
||||
});
|
||||
BlasType DtypeToBlasType(const py::dtype& np_type) {
|
||||
static auto* types = new absl::flat_hash_map<std::pair<char, int>, BlasType>({
|
||||
{{'f', 4}, BlasType::F32},
|
||||
{{'f', 8}, BlasType::F64},
|
||||
{{'c', 8}, BlasType::C64},
|
||||
{{'c', 16}, BlasType::C128},
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
throw std::invalid_argument(
|
||||
@ -53,7 +52,7 @@ HipblasType DtypeToHipblasType(const py::dtype& np_type) {
|
||||
// Returns the descriptor for a GetrfBatched operation.
|
||||
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
|
||||
int b, int n) {
|
||||
HipblasType type = DtypeToHipblasType(dtype);
|
||||
BlasType type = DtypeToBlasType(dtype);
|
||||
size_t size = b * sizeof(void*);
|
||||
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
|
||||
}
|
||||
@ -61,23 +60,24 @@ std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
|
||||
// Returns the descriptor for a GetrfBatched operation.
|
||||
std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
|
||||
int b, int m, int n) {
|
||||
HipblasType type = DtypeToHipblasType(dtype);
|
||||
BlasType type = DtypeToBlasType(dtype);
|
||||
size_t size = b * sizeof(void*);
|
||||
return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})};
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
|
||||
dict["hipblas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
|
||||
dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
|
||||
dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_hipblas, m) {
|
||||
PYBIND11_MODULE(_blas, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
|
||||
m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
/* Copyright 2019 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,60 +13,61 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hipblas_kernels.h"
|
||||
#include "jaxlib/gpu/blas_kernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipblas.h"
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
using BlasHandlePool = HandlePool<hipblasHandle_t, hipStream_t>;
|
||||
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
|
||||
hipStream_t stream) {
|
||||
gpuStream_t stream) {
|
||||
BlasHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
hipblasHandle_t handle;
|
||||
gpublasHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCreate(&handle)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSetStream(handle, stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts a NumPy dtype to a CublasType.
|
||||
// Converts a NumPy dtype to a BlasType.
|
||||
|
||||
int SizeOfHipblasType(HipblasType type) {
|
||||
int SizeOfBlasType(BlasType type) {
|
||||
switch (type) {
|
||||
case HipblasType::F32:
|
||||
case BlasType::F32:
|
||||
return sizeof(float);
|
||||
case HipblasType::F64:
|
||||
case BlasType::F64:
|
||||
return sizeof(double);
|
||||
case HipblasType::C64:
|
||||
return sizeof(hipComplex);
|
||||
case HipblasType::C128:
|
||||
return sizeof(hipDoubleComplex);
|
||||
case BlasType::C64:
|
||||
return sizeof(gpublasComplex);
|
||||
case BlasType::C128:
|
||||
return sizeof(gpublasDoubleComplex);
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,7 +75,7 @@ int SizeOfHipblasType(HipblasType type) {
|
||||
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
|
||||
static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -83,43 +84,43 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.n * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
|
||||
SizeOfHipblasType(d.type) * d.n * d.n);
|
||||
SizeOfBlasType(d.type) * d.n * d.n);
|
||||
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
switch (d.type) {
|
||||
case HipblasType::F32: {
|
||||
case BlasType::F32: {
|
||||
float** batch_ptrs = static_cast<float**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSgetrfBatched(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::F64: {
|
||||
case BlasType::F64: {
|
||||
double** batch_ptrs = static_cast<double**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDgetrfBatched(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::C64: {
|
||||
hipblasComplex** batch_ptrs = static_cast<hipblasComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCgetrfBatched(
|
||||
case BlasType::C64: {
|
||||
gpublasComplex** batch_ptrs = static_cast<gpublasComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::C128: {
|
||||
hipblasDoubleComplex** batch_ptrs =
|
||||
static_cast<hipblasDoubleComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZgetrfBatched(
|
||||
case BlasType::C128: {
|
||||
gpublasDoubleComplex** batch_ptrs =
|
||||
static_cast<gpublasDoubleComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched(
|
||||
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
@ -127,7 +128,7 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -138,7 +139,7 @@ void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// Batched QR decomposition: geqrfbatched
|
||||
|
||||
static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
|
||||
static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GeqrfBatchedDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -147,56 +148,56 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[0] != buffers[1]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
|
||||
buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n,
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
std::vector<int> info(d.batch);
|
||||
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
|
||||
SizeOfHipblasType(d.type) * d.m * d.n);
|
||||
SizeOfBlasType(d.type) * d.m * d.n);
|
||||
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
|
||||
auto tau_ptrs_host =
|
||||
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
|
||||
SizeOfHipblasType(d.type) * std::min(d.m, d.n));
|
||||
SizeOfBlasType(d.type) * std::min(d.m, d.n));
|
||||
JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
|
||||
// TODO(phawkins): ideally we would not need to synchronize here, but to
|
||||
// avoid it we need a way to keep the host-side buffer alive until the copy
|
||||
// completes.
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
switch (d.type) {
|
||||
case HipblasType::F32: {
|
||||
case BlasType::F32: {
|
||||
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
|
||||
float** tau_batch_ptrs = static_cast<float**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipblasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::F64: {
|
||||
case BlasType::F64: {
|
||||
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
|
||||
double** tau_batch_ptrs = static_cast<double**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipblasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::C64: {
|
||||
hipblasComplex** a_batch_ptrs = static_cast<hipblasComplex**>(buffers[3]);
|
||||
hipblasComplex** tau_batch_ptrs =
|
||||
static_cast<hipblasComplex**>(buffers[4]);
|
||||
case BlasType::C64: {
|
||||
gpublasComplex** a_batch_ptrs = static_cast<gpublasComplex**>(buffers[3]);
|
||||
gpublasComplex** tau_batch_ptrs =
|
||||
static_cast<gpublasComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipblasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipblasType::C128: {
|
||||
hipblasDoubleComplex** a_batch_ptrs =
|
||||
static_cast<hipblasDoubleComplex**>(buffers[3]);
|
||||
hipblasDoubleComplex** tau_batch_ptrs =
|
||||
static_cast<hipblasDoubleComplex**>(buffers[4]);
|
||||
case BlasType::C128: {
|
||||
gpublasDoubleComplex** a_batch_ptrs =
|
||||
static_cast<gpublasDoubleComplex**>(buffers[3]);
|
||||
gpublasDoubleComplex** tau_batch_ptrs =
|
||||
static_cast<gpublasDoubleComplex**>(buffers[4]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipblasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
|
||||
tau_batch_ptrs, info.data(), d.batch)));
|
||||
break;
|
||||
}
|
||||
@ -214,7 +215,7 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -223,4 +224,5 @@ void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,20 +13,19 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_CUBLAS_KERNELS_H_
|
||||
#define JAXLIB_CUBLAS_KERNELS_H_
|
||||
#ifndef JAXLIB_GPU_BLAS_KERNELS_H_
|
||||
#define JAXLIB_GPU_BLAS_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
// Set of types known to Cusolver.
|
||||
enum class CublasType {
|
||||
enum class BlasType {
|
||||
F32,
|
||||
F64,
|
||||
C64,
|
||||
@ -36,25 +35,24 @@ enum class CublasType {
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
struct GetrfBatchedDescriptor {
|
||||
CublasType type;
|
||||
BlasType type;
|
||||
int batch, n;
|
||||
};
|
||||
|
||||
void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
|
||||
// Batched QR decomposition: geqrfbatched
|
||||
|
||||
struct GeqrfBatchedDescriptor {
|
||||
CublasType type;
|
||||
BlasType type;
|
||||
int batch, m, n;
|
||||
};
|
||||
|
||||
void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CUBLAS_KERNELS_H_
|
||||
#endif // JAXLIB_GPU_BLAS_KERNELS_H_
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
/* Copyright 2019 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,17 +13,83 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
|
||||
#include <stdexcept>
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
namespace {
|
||||
std::string ErrorString(hipError_t error) { return hipGetErrorString(error); }
|
||||
std::string ErrorString(gpuError_t error) { return gpuGetErrorString(error); }
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
std::string ErrorString(gpusparseStatus_t status) {
|
||||
return cusparseGetErrorString(status);
|
||||
}
|
||||
|
||||
std::string ErrorString(gpusolverStatus_t status) {
|
||||
switch (status) {
|
||||
case CUSOLVER_STATUS_SUCCESS:
|
||||
return "cuSolver success.";
|
||||
case CUSOLVER_STATUS_NOT_INITIALIZED:
|
||||
return "cuSolver has not been initialized";
|
||||
case CUSOLVER_STATUS_ALLOC_FAILED:
|
||||
return "cuSolver allocation failed";
|
||||
case CUSOLVER_STATUS_INVALID_VALUE:
|
||||
return "cuSolver invalid value error";
|
||||
case CUSOLVER_STATUS_ARCH_MISMATCH:
|
||||
return "cuSolver architecture mismatch error";
|
||||
case CUSOLVER_STATUS_MAPPING_ERROR:
|
||||
return "cuSolver mapping error";
|
||||
case CUSOLVER_STATUS_EXECUTION_FAILED:
|
||||
return "cuSolver execution failed";
|
||||
case CUSOLVER_STATUS_INTERNAL_ERROR:
|
||||
return "cuSolver internal error";
|
||||
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
|
||||
return "cuSolver matrix type not supported error";
|
||||
case CUSOLVER_STATUS_NOT_SUPPORTED:
|
||||
return "cuSolver not supported error";
|
||||
case CUSOLVER_STATUS_ZERO_PIVOT:
|
||||
return "cuSolver zero pivot error";
|
||||
case CUSOLVER_STATUS_INVALID_LICENSE:
|
||||
return "cuSolver invalid license error";
|
||||
default:
|
||||
return absl::StrCat("Unknown cuSolver error: ", status);
|
||||
}
|
||||
}
|
||||
|
||||
std::string ErrorString(gpublasStatus_t status) {
|
||||
switch (status) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return "cuBlas success";
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED:
|
||||
return "cuBlas has not been initialized";
|
||||
case CUBLAS_STATUS_ALLOC_FAILED:
|
||||
return "cuBlas allocation failure";
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return "cuBlas invalid value error";
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH:
|
||||
return "cuBlas architecture mismatch";
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "cuBlas mapping error";
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "cuBlas execution failed";
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "cuBlas internal error";
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "cuBlas not supported error";
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "cuBlas license error";
|
||||
default:
|
||||
return "Unknown cuBlas error";
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
std::string ErrorString(hipsparseStatus_t status) {
|
||||
// TODO(reza): check and see if we can use hipify
|
||||
@ -115,6 +181,8 @@ std::string ErrorString(hipblasStatus_t status) {
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
std::string ErrorString(T status, const char* file, std::int64_t line,
|
||||
const char* expr) {
|
||||
@ -123,37 +191,37 @@ std::string ErrorString(T status, const char* file, std::int64_t line,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line,
|
||||
absl::Status AsStatus(gpuError_t error, const char* file, std::int64_t line,
|
||||
const char* expr) {
|
||||
if (error != hipSuccess)
|
||||
if (error != gpuSuccess)
|
||||
return absl::InternalError(ErrorString(error, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AsStatus(hipsolverStatus_t status, const char* file,
|
||||
absl::Status AsStatus(gpusolverStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr) {
|
||||
if (status != HIPSOLVER_STATUS_SUCCESS)
|
||||
if (status != GPUSOLVER_STATUS_SUCCESS)
|
||||
return absl::InternalError(ErrorString(status, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AsStatus(hipsparseStatus_t status, const char* file,
|
||||
absl::Status AsStatus(gpusparseStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr) {
|
||||
if (status != HIPSPARSE_STATUS_SUCCESS)
|
||||
if (status != GPUSPARSE_STATUS_SUCCESS)
|
||||
return absl::InternalError(ErrorString(status, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status AsStatus(hipblasStatus_t status, const char* file,
|
||||
absl::Status AsStatus(gpublasStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr) {
|
||||
if (status != HIPBLAS_STATUS_SUCCESS)
|
||||
if (status != GPUBLAS_STATUS_SUCCESS)
|
||||
return absl::InternalError(ErrorString(status, file, line, expr));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<void* []>>
|
||||
MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
|
||||
int batch_elem_size) {
|
||||
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
|
||||
gpuStream_t stream, void* buffer, void* dev_ptrs, int batch,
|
||||
int batch_elem_size) {
|
||||
char* ptr = static_cast<char*>(buffer);
|
||||
auto host_ptrs = absl::make_unique<void*[]>(batch);
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
@ -161,8 +229,10 @@ MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
|
||||
ptr += batch_elem_size;
|
||||
}
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
|
||||
hipMemcpyHostToDevice, stream)));
|
||||
gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
|
||||
gpuMemcpyHostToDevice, stream)));
|
||||
return std::move(host_ptrs);
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,19 +13,17 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_
|
||||
#define JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_
|
||||
#ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_
|
||||
#define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
|
||||
#define JAX_AS_STATUS(expr) \
|
||||
jax::JAX_GPU_NAMESPACE::AsStatus(expr, __FILE__, __LINE__, #expr)
|
||||
|
||||
#define JAX_THROW_IF_ERROR(expr) \
|
||||
{ \
|
||||
@ -40,27 +38,29 @@ limitations under the License.
|
||||
}
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
// Used via JAX_AS_STATUS(expr) macro.
|
||||
absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
|
||||
absl::Status AsStatus(gpuError_t error, const char* file, std::int64_t line,
|
||||
const char* expr);
|
||||
absl::Status AsStatus(cusolverStatus_t status, const char* file,
|
||||
absl::Status AsStatus(gpusolverStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr);
|
||||
absl::Status AsStatus(cusparseStatus_t status, const char* file,
|
||||
absl::Status AsStatus(gpusparseStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr);
|
||||
absl::Status AsStatus(cublasStatus_t status, const char* file,
|
||||
absl::Status AsStatus(gpublasStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr);
|
||||
|
||||
// Builds an array of pointers to each array in a batch, in device memory.
|
||||
// Caution: the return value must be kept alive (e.g., via a stream
|
||||
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
|
||||
// completes.
|
||||
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(cudaStream_t stream,
|
||||
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(gpuStream_t stream,
|
||||
void* buffer,
|
||||
void* dev_ptrs,
|
||||
int batch,
|
||||
int batch_elem_size);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_
|
||||
#endif // JAXLIB_GPU_GPU_KERNEL_HELPERS_H_
|
@ -16,21 +16,23 @@ limitations under the License.
|
||||
// This file is not used by JAX itself, but exists to assist with running
|
||||
// JAX-generated HLO code from outside of JAX.
|
||||
|
||||
#include "jaxlib/cuda/cublas_kernels.h"
|
||||
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
|
||||
#include "jaxlib/cuda/cuda_prng_kernels.h"
|
||||
#include "jaxlib/cuda/cusolver_kernels.h"
|
||||
#include "jaxlib/cuda/cusparse_kernels.h"
|
||||
#include "jaxlib/gpu/blas_kernels.h"
|
||||
#include "jaxlib/gpu/lu_pivot_kernels.h"
|
||||
#include "jaxlib/gpu/prng_kernels.h"
|
||||
#include "jaxlib/gpu/solver_kernels.h"
|
||||
#include "jaxlib/gpu/sparse_kernels.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation",
|
||||
CudaLuPivotsToPermutation, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_threefry2x32", CudaThreeFry2x32,
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_lu_pivots_to_permutation",
|
||||
LuPivotsToPermutation, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32,
|
||||
"CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_potrf", Potrf, "CUDA");
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA");
|
||||
@ -66,4 +68,5 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64,
|
||||
"CUDA");
|
||||
|
||||
} // namespace
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/lu_pivot_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
std::string BuildCudaLuPivotsToPermutationDescriptor(
|
||||
std::string BuildLuPivotsToPermutationDescriptor(
|
||||
std::int64_t batch_size, std::int32_t pivot_size,
|
||||
std::int32_t permutation_size) {
|
||||
return PackDescriptorAsString(LuPivotsToPermutationDescriptor{
|
||||
@ -31,21 +31,22 @@ std::string BuildCudaLuPivotsToPermutationDescriptor(
|
||||
|
||||
pybind11::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
dict["cuda_lu_pivots_to_permutation"] =
|
||||
EncapsulateFunction(CudaLuPivotsToPermutation);
|
||||
dict[JAX_GPU_PREFIX "_lu_pivots_to_permutation"] =
|
||||
EncapsulateFunction(LuPivotsToPermutation);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_cuda_linalg, m) {
|
||||
PYBIND11_MODULE(_linalg, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("lu_pivots_to_permutation_descriptor",
|
||||
[](std::int64_t batch_size, std::int32_t pivot_size,
|
||||
std::int32_t permutation_size) {
|
||||
std::string result = BuildCudaLuPivotsToPermutationDescriptor(
|
||||
std::string result = BuildLuPivotsToPermutationDescriptor(
|
||||
batch_size, pivot_size, permutation_size);
|
||||
return pybind11::bytes(result);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,38 +13,41 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hip_lu_pivot_kernels.h"
|
||||
#include "jaxlib/gpu/lu_pivot_kernels.h"
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
absl::Status HipLuPivotsToPermutation_(hipStream_t stream, void** buffers,
|
||||
const char* opaque,
|
||||
std::size_t opaque_len) {
|
||||
absl::Status LuPivotsToPermutation_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque,
|
||||
std::size_t opaque_len) {
|
||||
auto s =
|
||||
UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
LaunchLuPivotsToPermutationKernel(stream, buffers, **s);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError()));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void HipLuPivotsToPermutation(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len,
|
||||
XlaCustomCallStatus* status) {
|
||||
auto s = HipLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
|
||||
void LuPivotsToPermutation(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len,
|
||||
XlaCustomCallStatus* status) {
|
||||
auto s = LuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
std::string_view message = s.message();
|
||||
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hip_lu_pivot_kernels.h"
|
||||
#include "jaxlib/gpu/lu_pivot_kernels.h"
|
||||
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
__device__ void ComputePermutation(const std::int32_t* pivots,
|
||||
@ -58,7 +61,7 @@ __global__ void LuPivotsToPermutationKernel(
|
||||
} // namespace
|
||||
|
||||
void LaunchLuPivotsToPermutationKernel(
|
||||
hipStream_t stream, void** buffers,
|
||||
gpuStream_t stream, void** buffers,
|
||||
LuPivotsToPermutationDescriptor descriptor) {
|
||||
const std::int32_t* pivots =
|
||||
reinterpret_cast<const std::int32_t*>(buffers[0]);
|
||||
@ -74,4 +77,5 @@ void LaunchLuPivotsToPermutationKernel(
|
||||
descriptor.permutation_size);
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,16 +13,17 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_HIP_LU_PIVOT_KERNELS_H_
|
||||
#define JAXLIB_HIP_LU_PIVOT_KERNELS_H_
|
||||
#ifndef JAXLIB_GPU_LU_PIVOT_KERNELS_H_
|
||||
#define JAXLIB_GPU_LU_PIVOT_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
struct LuPivotsToPermutationDescriptor {
|
||||
std::int64_t batch_size;
|
||||
@ -31,13 +32,14 @@ struct LuPivotsToPermutationDescriptor {
|
||||
};
|
||||
|
||||
void LaunchLuPivotsToPermutationKernel(
|
||||
hipStream_t stream, void** buffers,
|
||||
gpuStream_t stream, void** buffers,
|
||||
LuPivotsToPermutationDescriptor descriptor);
|
||||
|
||||
void HipLuPivotsToPermutation(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len,
|
||||
XlaCustomCallStatus* status);
|
||||
void LuPivotsToPermutation(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len,
|
||||
XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_HIP_LU_PIVOT_KERNELS_H_
|
||||
#endif // JAXLIB_GPU_LU_PIVOT_KERNELS_H_
|
@ -13,31 +13,32 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cuda_prng_kernels.h"
|
||||
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/prng_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) {
|
||||
std::string BuildThreeFry2x32Descriptor(std::int64_t n) {
|
||||
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
|
||||
}
|
||||
pybind11::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
dict["cuda_threefry2x32"] = EncapsulateFunction(CudaThreeFry2x32);
|
||||
dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_cuda_prng, m) {
|
||||
PYBIND11_MODULE(_prng, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("threefry2x32_descriptor", [](std::int64_t n) {
|
||||
std::string result = BuildCudaThreeFry2x32Descriptor(n);
|
||||
std::string result = BuildThreeFry2x32Descriptor(n);
|
||||
return pybind11::bytes(result);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,35 +13,37 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cuda_prng_kernels.h"
|
||||
#include "jaxlib/gpu/prng_kernels.h"
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
absl::Status CudaThreeFry2x32_(cudaStream_t stream, void** buffers,
|
||||
const char* opaque, std::size_t opaque_len) {
|
||||
absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, std::size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
LaunchThreeFry2x32Kernel(stream, buffers, **s);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len);
|
||||
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
std::string_view message = s.message();
|
||||
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/cuda/cuda_prng_kernels.h"
|
||||
#include "jaxlib/gpu/prng_kernels.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
__global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
|
||||
@ -96,7 +99,7 @@ __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0,
|
||||
|
||||
} // namespace
|
||||
|
||||
void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers,
|
||||
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
|
||||
ThreeFry2x32Descriptor descriptor) {
|
||||
std::array<const std::uint32_t*, 2> keys;
|
||||
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
|
||||
@ -115,4 +118,5 @@ void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers,
|
||||
out[1], descriptor.n);
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,27 +13,29 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_CUDA_PRNG_KERNELS_H_
|
||||
#define JAXLIB_CUDA_PRNG_KERNELS_H_
|
||||
#ifndef JAXLIB_GPU_PRNG_KERNELS_H_
|
||||
#define JAXLIB_GPU_PRNG_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
struct ThreeFry2x32Descriptor {
|
||||
std::int64_t n;
|
||||
};
|
||||
|
||||
void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers,
|
||||
void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers,
|
||||
ThreeFry2x32Descriptor descriptor);
|
||||
|
||||
void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CUDA_PRNG_KERNELS_H_
|
||||
#endif // JAXLIB_GPU_PRNG_KERNELS_H_
|
@ -21,28 +21,27 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||
#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/cuda/cusolver_kernels.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/solver_kernels.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
namespace py = pybind11;
|
||||
|
||||
// Converts a NumPy dtype to a Type.
|
||||
CusolverType DtypeToCusolverType(const py::dtype& np_type) {
|
||||
SolverType DtypeToSolverType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, CusolverType>({
|
||||
{{'f', 4}, CusolverType::F32},
|
||||
{{'f', 8}, CusolverType::F64},
|
||||
{{'c', 8}, CusolverType::C64},
|
||||
{{'c', 16}, CusolverType::C128},
|
||||
new absl::flat_hash_map<std::pair<char, int>, SolverType>({
|
||||
{{'f', 4}, SolverType::F32},
|
||||
{{'f', 8}, SolverType::F64},
|
||||
{{'c', 8}, SolverType::C64},
|
||||
{{'c', 16}, SolverType::C128},
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
@ -57,48 +56,87 @@ CusolverType DtypeToCusolverType(const py::dtype& np_type) {
|
||||
// Returns the workspace size and a descriptor for a potrf operation.
|
||||
std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
|
||||
bool lower, int b, int n) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
std::int64_t workspace_size;
|
||||
cublasFillMode_t uplo =
|
||||
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
gpusolverFillMode_t uplo =
|
||||
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
|
||||
if (b == 1) {
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnSpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnSpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(float);
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnDpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnDpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(double);
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnCpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(cuComplex);
|
||||
JAX_AS_STATUS(gpusolverDnCpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(gpuComplex);
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnZpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(cuDoubleComplex);
|
||||
JAX_AS_STATUS(gpusolverDnZpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(gpuDoubleComplex);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
#ifdef JAX_GPU_CUDA
|
||||
// We use the workspace buffer for our own scratch space.
|
||||
workspace_size = sizeof(void*) * b;
|
||||
#else
|
||||
// TODO(rocm): when cuda and hip had same API for batched potrf, remove this
|
||||
// batched potrf has different API compared to CUDA. In hip we still need to
|
||||
// create the workspace and additional space to copy the batch array
|
||||
// pointers
|
||||
switch (type) {
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size = (lwork * sizeof(float)) + (b * sizeof(float*));
|
||||
break;
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size = (lwork * sizeof(double)) + (b * sizeof(double*));
|
||||
break;
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverCpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size =
|
||||
(lwork * sizeof(hipComplex)) + (b * sizeof(hipComplex*));
|
||||
break;
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverZpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size = (lwork * sizeof(hipDoubleComplex)) +
|
||||
(b * sizeof(hipDoubleComplex*));
|
||||
break;
|
||||
}
|
||||
#endif // JAX_GPU_CUDA
|
||||
}
|
||||
return {workspace_size,
|
||||
PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})};
|
||||
@ -109,38 +147,38 @@ std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
|
||||
// Returns the workspace size and a descriptor for a getrf operation.
|
||||
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnSgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnDgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnCgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnZgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
|
||||
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})};
|
||||
}
|
||||
|
||||
// geqrf: QR decomposition
|
||||
@ -148,91 +186,91 @@ std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
|
||||
// Returns the workspace size and a descriptor for a geqrf operation.
|
||||
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnSgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnDgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnCgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnZgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})};
|
||||
}
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
// csrlsvqr: Linear system solve via Sparse QR
|
||||
|
||||
// Returns a descriptor for a csrlsvqr operation.
|
||||
py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA,
|
||||
int reorder, double tol) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
auto h = SpSolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol});
|
||||
}
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
// orgqr/ungqr: apply elementary Householder transformations
|
||||
|
||||
// Returns the workspace size and a descriptor for a geqrf operation.
|
||||
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n, int k) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnSorgqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnDorgqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnCungqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(cusolverDnZungqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})};
|
||||
@ -243,32 +281,32 @@ std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
|
||||
// Returns the workspace size and a descriptor for a syevd operation.
|
||||
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
|
||||
bool lower, int b, int n) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
|
||||
cublasFillMode_t uplo =
|
||||
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
|
||||
gpusolverFillMode_t uplo =
|
||||
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevd_bufferSize(
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevd_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
|
||||
&lwork)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevd_bufferSize(
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevd_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
|
||||
&lwork)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevd_bufferSize(
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevd_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
|
||||
&lwork)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevd_bufferSize(
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
|
||||
&lwork)));
|
||||
break;
|
||||
@ -282,60 +320,60 @@ std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
|
||||
// Returns the workspace size and a descriptor for a syevj_batched operation.
|
||||
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
|
||||
bool lower, int batch, int n) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
syevjInfo_t params;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(¶ms)));
|
||||
std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
|
||||
params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); });
|
||||
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
|
||||
cublasFillMode_t uplo =
|
||||
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
gpuSyevjInfo_t params;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms)));
|
||||
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
|
||||
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
|
||||
gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
|
||||
gpusolverFillMode_t uplo =
|
||||
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
|
||||
if (batch == 1) {
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevj_bufferSize(
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevj_bufferSize(
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevj_bufferSize(
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevj_bufferSize(
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevjBatched_bufferSize(
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevjBatched_bufferSize(
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevjBatched_bufferSize(
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevjBatched_bufferSize(
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
@ -350,29 +388,11 @@ std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
|
||||
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n, bool compute_uv,
|
||||
bool full_matrices) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork)));
|
||||
break;
|
||||
}
|
||||
signed char jobu, jobvt;
|
||||
if (compute_uv) {
|
||||
if (full_matrices) {
|
||||
@ -383,51 +403,71 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
|
||||
} else {
|
||||
jobu = jobvt = 'N';
|
||||
}
|
||||
switch (type) {
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd_bufferSize(
|
||||
handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd_bufferSize(
|
||||
handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd_bufferSize(
|
||||
handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd_bufferSize(
|
||||
handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork,
|
||||
PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
|
||||
}
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
// Singular value decomposition using Jacobi algorithm: gesvdj
|
||||
|
||||
// Returns the workspace size and a descriptor for a gesvdj operation.
|
||||
std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
|
||||
int batch, int m, int n,
|
||||
bool compute_uv, int econ) {
|
||||
CusolverType type = DtypeToCusolverType(dtype);
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
cusolverEigMode_t jobz =
|
||||
compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
|
||||
gpusolverEigMode_t jobz =
|
||||
compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
|
||||
gesvdjInfo_t params;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms)));
|
||||
std::unique_ptr<gesvdjInfo, void (*)(gesvdjInfo*)> params_cleanup(
|
||||
params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); });
|
||||
if (batch == 1) {
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize(
|
||||
handle.get(), jobz, econ, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
|
||||
/*ldv=*/n, &lwork, params)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize(
|
||||
handle.get(), jobz, econ, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
|
||||
/*ldv=*/n, &lwork, params)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize(
|
||||
handle.get(), jobz, econ, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
|
||||
/*ldv=*/n, &lwork, params)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize(
|
||||
handle.get(), jobz, econ, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
@ -437,28 +477,28 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
|
||||
}
|
||||
} else {
|
||||
switch (type) {
|
||||
case CusolverType::F32:
|
||||
case SolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize(
|
||||
handle.get(), jobz, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
|
||||
/*ldv=*/n, &lwork, params, batch)));
|
||||
break;
|
||||
case CusolverType::F64:
|
||||
case SolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched_bufferSize(
|
||||
handle.get(), jobz, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
|
||||
/*ldv=*/n, &lwork, params, batch)));
|
||||
break;
|
||||
case CusolverType::C64:
|
||||
case SolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched_bufferSize(
|
||||
handle.get(), jobz, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
/*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
|
||||
/*ldv=*/n, &lwork, params, batch)));
|
||||
break;
|
||||
case CusolverType::C128:
|
||||
case SolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched_bufferSize(
|
||||
handle.get(), jobz, m, n,
|
||||
/*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
|
||||
@ -471,32 +511,40 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
|
||||
GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})};
|
||||
}
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["cusolver_potrf"] = EncapsulateFunction(Potrf);
|
||||
dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
|
||||
dict["cusolver_geqrf"] = EncapsulateFunction(Geqrf);
|
||||
dict[JAX_GPU_PREFIX "solver_potrf"] = EncapsulateFunction(Potrf);
|
||||
dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf);
|
||||
dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf);
|
||||
dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr);
|
||||
dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd);
|
||||
dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj);
|
||||
dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd);
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr);
|
||||
dict["cusolver_orgqr"] = EncapsulateFunction(Orgqr);
|
||||
dict["cusolver_syevd"] = EncapsulateFunction(Syevd);
|
||||
dict["cusolver_syevj"] = EncapsulateFunction(Syevj);
|
||||
dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd);
|
||||
dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj);
|
||||
#endif // JAX_GPU_CUDA
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_cusolver, m) {
|
||||
PYBIND11_MODULE(_solver, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
|
||||
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
|
||||
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
|
||||
m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
|
||||
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
|
||||
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
|
||||
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
|
||||
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
|
||||
#ifdef JAX_GPU_CUDA
|
||||
m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
|
||||
m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor);
|
||||
#endif // JAX_GPU_CUDA
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
File diff suppressed because it is too large
Load Diff
@ -17,28 +17,36 @@ limitations under the License.
|
||||
#define JAXLIB_CUSOLVER_KERNELS_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
namespace jax {
|
||||
|
||||
using SolverHandlePool = HandlePool<cusolverDnHandle_t, cudaStream_t>;
|
||||
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, cudaStream_t>;
|
||||
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
|
||||
cudaStream_t stream);
|
||||
gpuStream_t stream);
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<SpSolverHandlePool::Handle> SpSolverHandlePool::Borrow(
|
||||
cudaStream_t stream);
|
||||
gpuStream_t stream);
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
// Set of types known to Cusolver.
|
||||
enum class CusolverType {
|
||||
enum class SolverType {
|
||||
F32,
|
||||
F64,
|
||||
C64,
|
||||
@ -48,105 +56,114 @@ enum class CusolverType {
|
||||
// potrf: Cholesky decomposition
|
||||
|
||||
struct PotrfDescriptor {
|
||||
CusolverType type;
|
||||
cublasFillMode_t uplo;
|
||||
SolverType type;
|
||||
gpusolverFillMode_t uplo;
|
||||
std::int64_t batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
void Potrf(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Potrf(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
// getrf: LU decomposition
|
||||
|
||||
struct GetrfDescriptor {
|
||||
CusolverType type;
|
||||
int batch, m, n;
|
||||
SolverType type;
|
||||
int batch, m, n, lwork;
|
||||
};
|
||||
|
||||
void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Getrf(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// geqrf: QR decomposition
|
||||
|
||||
struct GeqrfDescriptor {
|
||||
CusolverType type;
|
||||
SolverType type;
|
||||
int batch, m, n, lwork;
|
||||
};
|
||||
|
||||
void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Geqrf(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
// csrlsvpr: Linear system solve via Sparse QR
|
||||
|
||||
struct CsrlsvqrDescriptor {
|
||||
CusolverType type;
|
||||
SolverType type;
|
||||
int n, nnz, reorder;
|
||||
double tol;
|
||||
};
|
||||
|
||||
void Csrlsvqr(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
// orgqr/ungqr: apply elementary Householder transformations
|
||||
|
||||
struct OrgqrDescriptor {
|
||||
CusolverType type;
|
||||
SolverType type;
|
||||
int batch, m, n, k, lwork;
|
||||
};
|
||||
|
||||
void Orgqr(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Orgqr(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
|
||||
|
||||
struct SyevdDescriptor {
|
||||
CusolverType type;
|
||||
cublasFillMode_t uplo;
|
||||
SolverType type;
|
||||
gpusolverFillMode_t uplo;
|
||||
int batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Syevd(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
struct SyevjDescriptor {
|
||||
CusolverType type;
|
||||
cublasFillMode_t uplo;
|
||||
SolverType type;
|
||||
gpusolverFillMode_t uplo;
|
||||
int batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Syevj(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Singular value decomposition using QR algorithm: gesvd
|
||||
|
||||
struct GesvdDescriptor {
|
||||
CusolverType type;
|
||||
SolverType type;
|
||||
int batch, m, n;
|
||||
int lwork;
|
||||
signed char jobu, jobvt;
|
||||
};
|
||||
|
||||
void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Gesvd(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
|
||||
// Singular value decomposition using Jacobi algorithm: gesvdj
|
||||
|
||||
struct GesvdjDescriptor {
|
||||
CusolverType type;
|
||||
SolverType type;
|
||||
int batch, m, n;
|
||||
int lwork;
|
||||
cusolverEigMode_t jobz;
|
||||
gpusolverEigMode_t jobz;
|
||||
int econ;
|
||||
};
|
||||
|
||||
void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CUSOLVER_KERNELS_H_
|
@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "rocm/include/hipsparse.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
@ -26,10 +24,9 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "rocm/include/hip/hip_complex.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/rocm/hipsparse_kernels.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/sparse_kernels.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
@ -38,14 +35,15 @@ limitations under the License.
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace {
|
||||
|
||||
hipsparseIndexType_t DtypeToHipSparseIndexType(const py::dtype& np_type) {
|
||||
gpusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, hipsparseIndexType_t>({
|
||||
{{'u', 2}, HIPSPARSE_INDEX_16U},
|
||||
{{'i', 4}, HIPSPARSE_INDEX_32I},
|
||||
{{'i', 8}, HIPSPARSE_INDEX_64I},
|
||||
new absl::flat_hash_map<std::pair<char, int>, gpusparseIndexType_t>({
|
||||
{{'u', 2}, GPUSPARSE_INDEX_16U},
|
||||
{{'i', 4}, GPUSPARSE_INDEX_32I},
|
||||
{{'i', 8}, GPUSPARSE_INDEX_64I},
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
@ -55,16 +53,20 @@ hipsparseIndexType_t DtypeToHipSparseIndexType(const py::dtype& np_type) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// TODO(rocm): add more hip data types when supported
|
||||
hipDataType DtypeToHipDataType(const py::dtype& np_type) {
|
||||
gpuDataType DtypeToCudaDataType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, hipDataType>(
|
||||
{{{'f', 2}, HIP_R_16F},
|
||||
{{'c', 4}, HIP_C_16F},
|
||||
{{'f', 4}, HIP_R_32F},
|
||||
{{'c', 8}, HIP_C_32F},
|
||||
{{'f', 8}, HIP_R_64F},
|
||||
{{'c', 16}, HIP_C_64F}});
|
||||
new absl::flat_hash_map<std::pair<char, int>, gpuDataType>({
|
||||
{{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F},
|
||||
{{'c', 8}, GPU_C_32F}, {{'f', 8}, GPU_R_64F},
|
||||
{{'c', 16}, GPU_C_64F},
|
||||
#ifdef JAX_GPU_CUDA
|
||||
{{'i', 1}, CUDA_R_8I}, {{'u', 1}, CUDA_R_8U},
|
||||
{{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U},
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
{{'V', 2}, CUDA_R_16BF},
|
||||
#endif // JAX_GPU_HAVE_SPARSE
|
||||
#endif // JAX_GPU_CUDA
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
throw std::invalid_argument(
|
||||
@ -78,28 +80,28 @@ SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
|
||||
int rows, int cols, int nnz,
|
||||
int batch_count,
|
||||
int batch_stride) {
|
||||
hipDataType value_type = DtypeToHipDataType(data_dtype);
|
||||
hipsparseIndexType_t index_type = DtypeToHipSparseIndexType(index_dtype);
|
||||
return SparseMatDescriptor{
|
||||
value_type, index_type, rows, cols, nnz, batch_count, batch_stride};
|
||||
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
|
||||
gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
|
||||
return SparseMatDescriptor{value_type, index_type, rows, cols,
|
||||
nnz, batch_count, batch_stride};
|
||||
}
|
||||
|
||||
// Returns the descriptor for a Dense matrix.
|
||||
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
|
||||
int rows, int cols, int batch_count,
|
||||
int batch_stride) {
|
||||
hipDataType value_type = DtypeToHipDataType(data_dtype);
|
||||
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
|
||||
return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
|
||||
}
|
||||
|
||||
// Returns the descriptor for a Dense vector.
|
||||
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
|
||||
int size) {
|
||||
hipDataType value_type = DtypeToHipDataType(data_dtype);
|
||||
gpuDataType value_type = DtypeToCudaDataType(data_dtype);
|
||||
return DenseVecDescriptor{value_type, size};
|
||||
}
|
||||
|
||||
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
// Returns the descriptor for a Sparse matrix.
|
||||
@ -111,35 +113,35 @@ std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count*/1, /*batch_stride*/0);
|
||||
/*batch_count*/ 1, /*batch_stride*/ 0);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
|
||||
// buffer_size does not reference these pointers, but does error on NULL.
|
||||
// TODO(jakevdp): check whether this is documented.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
||||
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
absl::Status CsrToDense_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SparseMatDescriptor& d = **s;
|
||||
@ -147,28 +149,28 @@ absl::Status CsrToDense_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type, d.index_type,
|
||||
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type, d.index_type,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
gpusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -190,30 +192,30 @@ std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
|
||||
hipsparseDnMatDescr_t mat_a = 0;
|
||||
hipsparseSpMatDescr_t mat_b = 0;
|
||||
gpusparseDnMatDescr_t mat_a = 0;
|
||||
gpusparseSpMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
|
||||
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
||||
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
|
||||
absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
|
||||
absl::Status CsrFromDense_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -222,29 +224,29 @@ absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
hipsparseDnMatDescr_t mat_a = 0;
|
||||
hipsparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
gpusparseDnMatDescr_t mat_a = 0;
|
||||
gpusparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type, d.index_type,
|
||||
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type, d.index_type,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -271,32 +273,32 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
||||
DenseVecDescriptor y =
|
||||
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnVecDescr_t vec_x = 0;
|
||||
hipsparseDnVecDescr_t vec_y = 0;
|
||||
hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnVecDescr_t vec_x = 0;
|
||||
gpusparseDnVecDescr_t vec_y = 0;
|
||||
gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
||||
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
||||
size_t buffer_size;
|
||||
HipConst alpha = HipOne(y.type);
|
||||
HipConst beta = HipZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize(
|
||||
SparseConst alpha = ConstOne(y.type);
|
||||
SparseConst beta = ConstZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
HIPSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
|
||||
}
|
||||
@ -320,35 +322,35 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
DenseMatDescriptor C =
|
||||
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
gpusparseOperation_t op_A = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
||||
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
hipsparseDnMatDescr_t mat_c = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, GPUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
HipConst alpha = HipOne(C.type);
|
||||
HipConst beta = HipZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize(
|
||||
handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
SparseConst alpha = ConstOne(C.type);
|
||||
SparseConst beta = ConstZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
@ -366,26 +368,26 @@ std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
|
||||
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
||||
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
@ -403,25 +405,25 @@ std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
/*batch_count=*/1, /*batch_stride=*/0);
|
||||
|
||||
hipsparseDnMatDescr_t mat_a = 0;
|
||||
hipsparseSpMatDescr_t mat_b = 0;
|
||||
gpusparseDnMatDescr_t mat_a = 0;
|
||||
gpusparseSpMatDescr_t mat_b = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
|
||||
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
/*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
||||
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
size_t buffer_size;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
&buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
|
||||
|
||||
return {buffer_size, PackDescriptor(d)};
|
||||
}
|
||||
@ -444,32 +446,32 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
|
||||
DenseVecDescriptor y =
|
||||
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnVecDescr_t vec_x = 0;
|
||||
hipsparseDnVecDescr_t vec_y = 0;
|
||||
hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnVecDescr_t vec_x = 0;
|
||||
gpusparseDnVecDescr_t vec_y = 0;
|
||||
gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
||||
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
|
||||
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
|
||||
size_t buffer_size;
|
||||
HipConst alpha = HipOne(y.type);
|
||||
HipConst beta = HipZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize(
|
||||
SparseConst alpha = ConstOne(y.type);
|
||||
SparseConst beta = ConstZero(y.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
|
||||
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
|
||||
HIPSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
|
||||
}
|
||||
@ -490,63 +492,61 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
|
||||
batch_count, lhs_batch_stride);
|
||||
DenseMatDescriptor B =
|
||||
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
|
||||
batch_count, rhs_batch_stride);
|
||||
|
||||
SparseMatDescriptor A = BuildSparseMatDescriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz, batch_count, lhs_batch_stride);
|
||||
DenseMatDescriptor B = BuildDenseMatDescriptor(
|
||||
b_dtype, transpose ? rows : cols, BCcols, batch_count, rhs_batch_stride);
|
||||
int C_rows = (transpose == true) ? cols : rows;
|
||||
// TODO(tianjianlu): enable the selection of batch stride.
|
||||
// The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
|
||||
// The issue
|
||||
// (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
|
||||
// in cusparse library does not allow batch_stride = 0.
|
||||
// int C_batch_stride = (batch_count > 1)? C_rows * BCcols : 0;
|
||||
int C_batch_stride = C_rows * BCcols;
|
||||
DenseMatDescriptor C =
|
||||
BuildDenseMatDescriptor(compute_dtype, /*rows=*/C_rows, /*cols=*/BCcols,
|
||||
batch_count, C_batch_stride);
|
||||
hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
gpusparseOperation_t op_A = transpose ? GPUSPARSE_OPERATION_TRANSPOSE
|
||||
: GPUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
hipsparseDnMatDescr_t mat_c = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
// bufferSize does not reference these pointers, but does error on NULL.
|
||||
int val = 0;
|
||||
void* empty = &val;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
|
||||
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCooSetStridedBatch(
|
||||
mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
||||
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, A.value_type)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCooSetStridedBatch(
|
||||
mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseDnMatSetStridedBatch(
|
||||
mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
|
||||
empty, B.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
|
||||
mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseDnMatSetStridedBatch(
|
||||
mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
|
||||
empty, C.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch(
|
||||
mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
|
||||
size_t buffer_size;
|
||||
HipConst alpha = HipOne(C.type);
|
||||
HipConst beta = HipZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize(
|
||||
handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
SparseConst alpha = ConstOne(C.type);
|
||||
SparseConst beta = ConstZero(C.type);
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
|
||||
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
|
||||
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
|
||||
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
|
||||
|
||||
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
|
||||
}
|
||||
|
||||
#endif // if JAX_GPU_HAVE_SPARSE
|
||||
|
||||
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
|
||||
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
|
||||
@ -565,32 +565,37 @@ size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
|
||||
}
|
||||
|
||||
size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
|
||||
return Gtsv2BufferSize(hipsparseSgtsv2_bufferSizeExt, m, n, ldb);
|
||||
return Gtsv2BufferSize(gpusparseSgtsv2_bufferSizeExt, m, n, ldb);
|
||||
}
|
||||
|
||||
size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
|
||||
return Gtsv2BufferSize(hipsparseDgtsv2_bufferSizeExt, m, n, ldb);
|
||||
return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb);
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["hipsparse_csr_todense"] = EncapsulateFunction(CsrToDense);
|
||||
dict["hipsparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
|
||||
dict["hipsparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
|
||||
dict["hipsparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
|
||||
dict["hipsparse_coo_todense"] = EncapsulateFunction(CooToDense);
|
||||
dict["hipsparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
|
||||
dict["hipsparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
|
||||
dict["hipsparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
|
||||
dict["hipsparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
|
||||
dict["hipsparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
dict[JAX_GPU_PREFIX "sparse_csr_todense"] = EncapsulateFunction(CsrToDense);
|
||||
dict[JAX_GPU_PREFIX "sparse_csr_fromdense"] =
|
||||
EncapsulateFunction(CsrFromDense);
|
||||
dict[JAX_GPU_PREFIX "sparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
|
||||
dict[JAX_GPU_PREFIX "sparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
|
||||
dict[JAX_GPU_PREFIX "sparse_coo_todense"] = EncapsulateFunction(CooToDense);
|
||||
dict[JAX_GPU_PREFIX "sparse_coo_fromdense"] =
|
||||
EncapsulateFunction(CooFromDense);
|
||||
dict[JAX_GPU_PREFIX "sparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
|
||||
dict[JAX_GPU_PREFIX "sparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
|
||||
#endif
|
||||
dict[JAX_GPU_PREFIX "sparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
|
||||
dict[JAX_GPU_PREFIX "sparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
|
||||
// TODO(tomhennigan): Add support for gtsv2 complex 32/64.
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_hipsparse, m) {
|
||||
m.attr("hipsparse_supported") = py::bool_(true);
|
||||
PYBIND11_MODULE(_sparse, m) {
|
||||
m.attr("sparse_supported") = py::bool_(JAX_GPU_HAVE_SPARSE);
|
||||
m.def("registrations", &Registrations);
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
|
||||
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
|
||||
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
|
||||
@ -599,10 +604,12 @@ PYBIND11_MODULE(_hipsparse, m) {
|
||||
m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
|
||||
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
|
||||
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
|
||||
#endif
|
||||
m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
|
||||
m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
|
||||
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hipsparse_kernels.h"
|
||||
#include "jaxlib/gpu/sparse_kernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
@ -24,62 +24,128 @@ limitations under the License.
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "rocm/include/hip/hip_complex.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SparseHandlePool::Handle>
|
||||
SparseHandlePool::Borrow(hipStream_t stream) {
|
||||
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
|
||||
gpuStream_t stream) {
|
||||
SparseHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
hipsparseHandle_t handle;
|
||||
gpusparseHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreate(&handle)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSetStream(handle, stream)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
HipConst HipZero(hipDataType type) {
|
||||
HipConst c;
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
SparseConst ConstZero(gpuDataType type) {
|
||||
SparseConst c;
|
||||
std::memset(&c, 0, sizeof(c));
|
||||
return c;
|
||||
}
|
||||
|
||||
HipConst HipOne(hipDataType type) {
|
||||
HipConst c;
|
||||
SparseConst ConstOne(gpuDataType type) {
|
||||
SparseConst c;
|
||||
std::memset(&c, 0, sizeof(c));
|
||||
// TODO(rocm): add more data type if new rocm support
|
||||
switch (type) {
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
|
||||
case CUDA_R_4I:
|
||||
case CUDA_C_4I:
|
||||
#endif
|
||||
case CUDA_R_8I:
|
||||
case CUDA_C_8I:
|
||||
c.i8[0] = 1;
|
||||
break;
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
case CUDA_R_4U:
|
||||
case CUDA_C_4U:
|
||||
#endif
|
||||
case CUDA_R_8U:
|
||||
case CUDA_C_8U:
|
||||
c.u8[0] = 1;
|
||||
break;
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
case CUDA_R_16I:
|
||||
case CUDA_C_16I:
|
||||
c.i16[0] = 1;
|
||||
break;
|
||||
case CUDA_R_16U:
|
||||
case CUDA_C_16U:
|
||||
c.u16[0] = 1;
|
||||
break;
|
||||
#endif
|
||||
case CUDA_R_32I:
|
||||
case CUDA_C_32I:
|
||||
c.i32[0] = 1;
|
||||
break;
|
||||
case CUDA_R_32U:
|
||||
case CUDA_C_32U:
|
||||
c.u32[0] = 1;
|
||||
break;
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
case CUDA_R_64I:
|
||||
case CUDA_C_64I:
|
||||
c.i64[0] = 1;
|
||||
break;
|
||||
case CUDA_R_64U:
|
||||
case CUDA_C_64U:
|
||||
c.u64[0] = 1;
|
||||
break;
|
||||
#endif
|
||||
#if JAX_CUDA_11080
|
||||
case CUDA_R_8F_E4M3:
|
||||
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E4M3);
|
||||
break;
|
||||
case CUDA_R_8F_E5M2:
|
||||
c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E5M2);
|
||||
break;
|
||||
#endif
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
case CUDA_R_16BF:
|
||||
case CUDA_C_16BF:
|
||||
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
|
||||
break;
|
||||
#endif
|
||||
#endif // JAX_GPU_CUDA
|
||||
// TODO(rocm): add more data types if new rocm supports them.
|
||||
|
||||
// TODO(jakevdp): 16F/16BF here might break on big endian platforms.
|
||||
case HIP_R_16F:
|
||||
case HIP_C_16F:
|
||||
case GPU_R_16F:
|
||||
case GPU_C_16F:
|
||||
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
|
||||
break;
|
||||
case HIP_R_32F:
|
||||
case HIP_C_32F:
|
||||
case GPU_R_32F:
|
||||
case GPU_C_32F:
|
||||
c.f32[0] = 1.0;
|
||||
break;
|
||||
case HIP_R_64F:
|
||||
case HIP_C_64F:
|
||||
case GPU_R_64F:
|
||||
case GPU_C_64F:
|
||||
c.f64[0] = 1.0;
|
||||
break;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
static absl::Status CsrToDense_(hipStream_t stream, void** buffers,
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
static absl::Status CsrToDense_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -88,28 +154,28 @@ static absl::Status CsrToDense_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[2],
|
||||
/*csrColInd=*/buffers[1],
|
||||
/*csrValues=*/buffers[0], d.index_type, d.index_type,
|
||||
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
gpusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -120,7 +186,7 @@ void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// CsrFromDense: Convert dense matrix to CSR matrix
|
||||
|
||||
static absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
|
||||
static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -129,29 +195,29 @@ static absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
hipsparseDnMatDescr_t mat_a = 0;
|
||||
hipsparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
gpusparseDnMatDescr_t mat_a = 0;
|
||||
gpusparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*csrRowOffsets=*/buffers[3],
|
||||
/*csrColInd=*/buffers[2],
|
||||
/*csrValues=*/buffers[1], d.index_type, d.index_type,
|
||||
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -162,7 +228,7 @@ void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// CsrMatvec: Product of CSR matrix and dense vector.
|
||||
|
||||
static absl::Status CsrMatvec_(hipStream_t stream, void** buffers,
|
||||
static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -178,38 +244,37 @@ static absl::Status CsrMatvec_(hipStream_t stream, void** buffers,
|
||||
void* ybuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(rocm): check the following statement for rocm
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
HipConst alpha = HipOne(d.y.type);
|
||||
HipConst beta = HipZero(d.y.type);
|
||||
SparseConst alpha = ConstOne(d.y.type);
|
||||
SparseConst beta = ConstZero(d.y.type);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnVecDescr_t vec_x = 0;
|
||||
hipsparseDnVecDescr_t vec_y = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnVecDescr_t vec_x = 0;
|
||||
gpusparseDnVecDescr_t vec_y = 0;
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind,
|
||||
csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO,
|
||||
csr_values, d.A.index_type, d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO,
|
||||
d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -220,7 +285,7 @@ void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// CsrMatmat: Product of CSR matrix and dense matrix.
|
||||
|
||||
static absl::Status CsrMatmat_(hipStream_t stream, void** buffers,
|
||||
static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -240,34 +305,34 @@ static absl::Status CsrMatmat_(hipStream_t stream, void** buffers,
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
HipConst alpha = HipOne(d.C.type);
|
||||
HipConst beta = HipZero(d.C.type);
|
||||
SparseConst alpha = ConstOne(d.C.type);
|
||||
SparseConst beta = ConstZero(d.C.type);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
hipsparseDnMatDescr_t mat_c = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind,
|
||||
csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO,
|
||||
csr_values, d.A.index_type, d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO,
|
||||
d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -278,7 +343,7 @@ void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// CooToDense: Convert COO matrix to dense matrix
|
||||
|
||||
static absl::Status CooToDense_(hipStream_t stream, void** buffers,
|
||||
static absl::Status CooToDense_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -287,28 +352,28 @@ static absl::Status CooToDense_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
|
||||
gpusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
|
||||
/*cooRowInd=*/buffers[1],
|
||||
/*cooColInd=*/buffers[2],
|
||||
/*cooValues=*/buffers[0], d.index_type,
|
||||
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_b, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
gpusparseSparseToDense(handle.get(), mat_a, mat_b,
|
||||
GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooToDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooToDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -319,7 +384,7 @@ void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// CooFromDense: Convert dense matrix to COO matrix
|
||||
|
||||
static absl::Status CooFromDense_(hipStream_t stream, void** buffers,
|
||||
static absl::Status CooFromDense_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -328,29 +393,29 @@ static absl::Status CooFromDense_(hipStream_t stream, void** buffers,
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
hipsparseDnMatDescr_t mat_a = 0;
|
||||
hipsparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
gpusparseDnMatDescr_t mat_a = 0;
|
||||
gpusparseSpMatDescr_t mat_b = 0;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_a, d.rows, d.cols,
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
|
||||
gpusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
|
||||
/*cooRowInd=*/buffers[2],
|
||||
/*cooColInd=*/buffers[3],
|
||||
/*cooValues=*/buffers[1], d.index_type,
|
||||
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
|
||||
handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
|
||||
buffers[4])));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -361,7 +426,7 @@ void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// CooMatvec: Product of COO matrix and dense vector.
|
||||
|
||||
static absl::Status CooMatvec_(hipStream_t stream, void** buffers,
|
||||
static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -377,37 +442,36 @@ static absl::Status CooMatvec_(hipStream_t stream, void** buffers,
|
||||
void* ybuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(rocm): check the following statement for rocm
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
HipConst alpha = HipOne(d.y.type);
|
||||
HipConst beta = HipZero(d.y.type);
|
||||
SparseConst alpha = ConstOne(d.y.type);
|
||||
SparseConst beta = ConstZero(d.y.type);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnVecDescr_t vec_x = 0;
|
||||
hipsparseDnVecDescr_t vec_y = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnVecDescr_t vec_x = 0;
|
||||
gpusparseDnVecDescr_t vec_y = 0;
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
|
||||
d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
|
||||
JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
|
||||
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -418,7 +482,7 @@ void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
|
||||
// CooMatmat: Product of COO matrix and dense matrix.
|
||||
|
||||
static absl::Status CooMatmat_(hipStream_t stream, void** buffers,
|
||||
static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
@ -434,47 +498,46 @@ static absl::Status CooMatmat_(hipStream_t stream, void** buffers,
|
||||
void* Cbuf = buffers[4];
|
||||
void* buf = buffers[5];
|
||||
|
||||
// TODO(rocm): check the following statement for rocm
|
||||
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
|
||||
// are sufficient for basic matvec operations.
|
||||
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
|
||||
// or else the operation will segfault.
|
||||
HipConst alpha = HipOne(d.C.type);
|
||||
HipConst beta = HipZero(d.C.type);
|
||||
SparseConst alpha = ConstOne(d.C.type);
|
||||
SparseConst beta = ConstZero(d.C.type);
|
||||
|
||||
hipsparseSpMatDescr_t mat_a = 0;
|
||||
hipsparseDnMatDescr_t mat_b = 0;
|
||||
hipsparseDnMatDescr_t mat_c = 0;
|
||||
gpusparseSpMatDescr_t mat_a = 0;
|
||||
gpusparseDnMatDescr_t mat_b = 0;
|
||||
gpusparseDnMatDescr_t mat_c = 0;
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
|
||||
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
|
||||
d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
|
||||
/*batchStride=*/d.A.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
gpusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
|
||||
/*batchStride=*/d.A.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_b, d.B.rows, d.B.cols,
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.B.cols, Bbuf, d.B.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
|
||||
/*batchStride=*/d.B.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
|
||||
gpusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
|
||||
/*batchStride=*/d.B.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
|
||||
&mat_c, d.C.rows, d.C.cols,
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
|
||||
/*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
|
||||
/*batchStride=*/d.C.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
gpusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
|
||||
/*batchStride=*/d.C.batch_stride)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
|
||||
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
@ -482,9 +545,10 @@ void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
#endif // if JAX_GPU_HAVE_SPARSE
|
||||
|
||||
template <typename T, typename F>
|
||||
static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
|
||||
static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
|
||||
const char* opaque, std::size_t opaque_len) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
@ -513,7 +577,7 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
|
||||
if (X != B) {
|
||||
size_t B_bytes = ldb * n * sizeof(T);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipMemcpyAsync(X, B, B_bytes, hipMemcpyDeviceToDevice, stream)));
|
||||
gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
@ -521,22 +585,23 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = gtsv2<float>(hipsparseSgtsv2, stream, buffers, opaque, opaque_len);
|
||||
auto s = gtsv2<float>(gpusparseSgtsv2, stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = gtsv2<double>(hipsparseDgtsv2, stream, buffers, opaque, opaque_len);
|
||||
auto s = gtsv2<double>(gpusparseDgtsv2, stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_HIPSPARSE_KERNELS_H_
|
||||
#define JAXLIB_HIPSPARSE_KERNELS_H_
|
||||
#ifndef JAXLIB_GPU_SPARSE_KERNELS_H_
|
||||
#define JAXLIB_GPU_SPARSE_KERNELS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
@ -23,23 +23,21 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipsparse.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
|
||||
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
|
||||
|
||||
namespace jax {
|
||||
|
||||
using SparseHandlePool = HandlePool<hipsparseHandle_t, hipStream_t>;
|
||||
using SparseHandlePool = HandlePool<gpusparseHandle_t, gpuStream_t>;
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SparseHandlePool::Handle>
|
||||
SparseHandlePool::Borrow(hipStream_t stream);
|
||||
/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
|
||||
gpuStream_t stream);
|
||||
|
||||
union HipConst {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
union SparseConst {
|
||||
int8_t i8[2];
|
||||
int16_t i16[2];
|
||||
int32_t i32[2];
|
||||
@ -52,37 +50,38 @@ union HipConst {
|
||||
double f64[2];
|
||||
};
|
||||
|
||||
HipConst HipZero(hipDataType type);
|
||||
HipConst HipOne(hipDataType type);
|
||||
SparseConst ConstZero(gpuDataType type);
|
||||
SparseConst ConstOne(gpuDataType type);
|
||||
|
||||
struct SparseMatDescriptor {
|
||||
hipDataType value_type;
|
||||
hipsparseIndexType_t index_type;
|
||||
gpuDataType value_type;
|
||||
gpusparseIndexType_t index_type;
|
||||
int rows, cols, nnz;
|
||||
int batch_count = 1;
|
||||
int batch_stride = 0;
|
||||
};
|
||||
|
||||
struct DenseMatDescriptor {
|
||||
hipDataType type;
|
||||
gpuDataType type;
|
||||
int rows, cols;
|
||||
int batch_count = 1;
|
||||
int batch_stride = 0;
|
||||
};
|
||||
|
||||
struct DenseVecDescriptor {
|
||||
hipDataType type;
|
||||
gpuDataType type;
|
||||
int size;
|
||||
};
|
||||
|
||||
#if JAX_GPU_HAVE_SPARSE
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CsrFromDense: Convert dense matrix to CSR matrix
|
||||
|
||||
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CsrMatvec: Product of CSR matrix and dense vector.
|
||||
@ -90,10 +89,10 @@ void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
struct CsrMatvecDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseVecDescriptor x, y;
|
||||
hipsparseOperation_t op;
|
||||
gpusparseOperation_t op;
|
||||
};
|
||||
|
||||
void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CsrMatmat: Product of CSR matrix and dense matrix.
|
||||
@ -101,20 +100,20 @@ void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
struct CsrMatmatDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseMatDescriptor B, C;
|
||||
hipsparseOperation_t op_A;
|
||||
gpusparseOperation_t op_A;
|
||||
};
|
||||
|
||||
void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooToDense: Convert COO matrix to dense matrix
|
||||
|
||||
void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooToDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooFromDense: Convert dense matrix to COO matrix
|
||||
|
||||
void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooMatvec: Product of COO matrix and dense vector.
|
||||
@ -122,10 +121,10 @@ void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
|
||||
struct CooMatvecDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseVecDescriptor x, y;
|
||||
hipsparseOperation_t op;
|
||||
gpusparseOperation_t op;
|
||||
};
|
||||
|
||||
void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// CooMatmat: Product of COO matrix and dense matrix.
|
||||
@ -133,22 +132,24 @@ void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
|
||||
struct CooMatmatDescriptor {
|
||||
SparseMatDescriptor A;
|
||||
DenseMatDescriptor B, C;
|
||||
hipsparseOperation_t op_A;
|
||||
gpusparseOperation_t op_A;
|
||||
};
|
||||
|
||||
void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
#endif // JAX_GPU_HAVE_SPARSE
|
||||
|
||||
struct Gtsv2Descriptor {
|
||||
int m, n, ldb;
|
||||
};
|
||||
|
||||
void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
|
||||
void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_HIPSPARSE_KERNELS_H_
|
||||
#endif // JAXLIB_GPU_SPARSE_KERNELS_H_
|
442
jaxlib/gpu/vendor.h
Normal file
442
jaxlib/gpu/vendor.h
Normal file
@ -0,0 +1,442 @@
|
||||
/* Copyright 2022 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header is a shim that manages differences between CUDA and ROCM APIs.
|
||||
// Jaxlib GPU kernels can be compiled for either CUDA or ROCM by defining
|
||||
// JAX_GPU_CUDA or JAX_GPU_HIP respectively.
|
||||
|
||||
#ifndef JAXLIB_GPU_VENDOR_H_
|
||||
#define JAXLIB_GPU_VENDOR_H_
|
||||
|
||||
#if defined(JAX_GPU_CUDA)
|
||||
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h"
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||
#include "third_party/gpus/cuda/include/cusolverDn.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
|
||||
// Some sparse functionality is only available in CUSPARSE 11.3 or newer.
|
||||
#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
|
||||
|
||||
// CUDA-11.8 introduces FP8 E4M3/E5M2 types.
|
||||
#define JAX_GPU_HAVE_FP8 (CUDA_VERSION >= 11080)
|
||||
|
||||
#if JAX_GPU_HAVE_FP8
|
||||
#include "third_party/gpus/cuda/include/cuda_fp8.h"
|
||||
#endif
|
||||
|
||||
// cuSPARSE generic APIs are not supported on Windows until 11.0
|
||||
// cusparseIndexType_t is used in very limited scope so manually define will
|
||||
// workaround compiling issue without harm.
|
||||
#if defined(_WIN32) && (CUSPARSE_VERSION < 11000)
|
||||
typedef enum {
|
||||
CUSPARSE_INDEX_16U = 1,
|
||||
CUSPARSE_INDEX_32I = 2,
|
||||
CUSPARSE_INDEX_64I = 3
|
||||
} cusparseIndexType_t;
|
||||
#endif
|
||||
|
||||
#define JAX_GPU_NAMESPACE cuda
|
||||
#define JAX_GPU_PREFIX "cu"
|
||||
|
||||
typedef cuComplex gpuComplex;
|
||||
typedef cuDoubleComplex gpuDoubleComplex;
|
||||
|
||||
typedef cuComplex gpublasComplex;
|
||||
typedef cuDoubleComplex gpublasDoubleComplex;
|
||||
typedef cublasFillMode_t gpusolverFillMode_t;
|
||||
typedef cublasStatus_t gpublasStatus_t;
|
||||
typedef cublasHandle_t gpublasHandle_t;
|
||||
typedef cudaDataType gpuDataType;
|
||||
typedef cudaStream_t gpuStream_t;
|
||||
typedef cudaError_t gpuError_t;
|
||||
typedef cusolverDnHandle_t gpusolverDnHandle_t;
|
||||
typedef cusolverStatus_t gpusolverStatus_t;
|
||||
typedef cusolverEigMode_t gpusolverEigMode_t;
|
||||
typedef syevjInfo gpuSyevjInfo;
|
||||
typedef syevjInfo_t gpuSyevjInfo_t;
|
||||
typedef cusparseIndexType_t gpusparseIndexType_t;
|
||||
typedef cusparseHandle_t gpusparseHandle_t;
|
||||
typedef cusparseOperation_t gpusparseOperation_t;
|
||||
typedef cusparseStatus_t gpusparseStatus_t;
|
||||
typedef cusparseSpMatDescr_t gpusparseSpMatDescr_t;
|
||||
typedef cusparseDnMatDescr_t gpusparseDnMatDescr_t;
|
||||
typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
|
||||
#define GPU_C_16F CUDA_C_16F
|
||||
#define GPU_R_16F CUDA_R_16F
|
||||
#define GPU_C_32F CUDA_C_32F
|
||||
#define GPU_R_32F CUDA_R_32F
|
||||
#define GPU_C_64F CUDA_C_64F
|
||||
#define GPU_R_64F CUDA_R_64F
|
||||
|
||||
#define gpublasCreate cublasCreate
|
||||
#define gpublasSetStream cublasSetStream
|
||||
#define gpublasSgeqrfBatched cublasSgeqrfBatched
|
||||
#define gpublasDgeqrfBatched cublasDgeqrfBatched
|
||||
#define gpublasCgeqrfBatched cublasCgeqrfBatched
|
||||
#define gpublasZgeqrfBatched cublasZgeqrfBatched
|
||||
#define gpublasSgetrfBatched cublasSgetrfBatched
|
||||
#define gpublasDgetrfBatched cublasDgetrfBatched
|
||||
#define gpublasCgetrfBatched cublasCgetrfBatched
|
||||
#define gpublasZgetrfBatched cublasZgetrfBatched
|
||||
|
||||
#define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS
|
||||
|
||||
#define gpusolverDnCreate cusolverDnCreate
|
||||
#define gpusolverDnSetStream cusolverDnSetStream
|
||||
#define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo
|
||||
#define gpusolverDnDestroySyevjInfo cusolverDnDestroySyevjInfo
|
||||
#define gpusolverDnSpotrf cusolverDnSpotrf
|
||||
#define gpusolverDnDpotrf cusolverDnDpotrf
|
||||
#define gpusolverDnCpotrf cusolverDnCpotrf
|
||||
#define gpusolverDnZpotrf cusolverDnZpotrf
|
||||
#define gpusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
cusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
|
||||
#define gpusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
cusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
|
||||
#define gpusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
cusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
|
||||
#define gpusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
cusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, info, batch)
|
||||
#define gpusolverDnSpotrf_bufferSize cusolverDnSpotrf_bufferSize
|
||||
#define gpusolverDnDpotrf_bufferSize cusolverDnDpotrf_bufferSize
|
||||
#define gpusolverDnCpotrf_bufferSize cusolverDnCpotrf_bufferSize
|
||||
#define gpusolverDnZpotrf_bufferSize cusolverDnZpotrf_bufferSize
|
||||
#define gpusolverDnSgeqrf cusolverDnSgeqrf
|
||||
#define gpusolverDnDgeqrf cusolverDnDgeqrf
|
||||
#define gpusolverDnCgeqrf cusolverDnCgeqrf
|
||||
#define gpusolverDnZgeqrf cusolverDnZgeqrf
|
||||
#define gpusolverDnSgeqrf_bufferSize cusolverDnSgeqrf_bufferSize
|
||||
#define gpusolverDnDgeqrf_bufferSize cusolverDnDgeqrf_bufferSize
|
||||
#define gpusolverDnCgeqrf_bufferSize cusolverDnCgeqrf_bufferSize
|
||||
#define gpusolverDnZgeqrf_bufferSize cusolverDnZgeqrf_bufferSize
|
||||
#define gpusolverDnSgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
cusolverDnSgetrf(h, m, n, a, lda, work, ipiv, info)
|
||||
#define gpusolverDnDgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
cusolverDnDgetrf(h, m, n, a, lda, work, ipiv, info)
|
||||
#define gpusolverDnCgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
cusolverDnCgetrf(h, m, n, a, lda, work, ipiv, info)
|
||||
#define gpusolverDnZgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
cusolverDnZgetrf(h, m, n, a, lda, work, ipiv, info)
|
||||
#define gpusolverDnSgetrf_bufferSize cusolverDnSgetrf_bufferSize
|
||||
#define gpusolverDnDgetrf_bufferSize cusolverDnDgetrf_bufferSize
|
||||
#define gpusolverDnCgetrf_bufferSize cusolverDnCgetrf_bufferSize
|
||||
#define gpusolverDnZgetrf_bufferSize cusolverDnZgetrf_bufferSize
|
||||
#define gpusolverDnSorgqr cusolverDnSorgqr
|
||||
#define gpusolverDnDorgqr cusolverDnDorgqr
|
||||
#define gpusolverDnCungqr cusolverDnCungqr
|
||||
#define gpusolverDnZungqr cusolverDnZungqr
|
||||
#define gpusolverDnSorgqr_bufferSize cusolverDnSorgqr_bufferSize
|
||||
#define gpusolverDnDorgqr_bufferSize cusolverDnDorgqr_bufferSize
|
||||
#define gpusolverDnCungqr_bufferSize cusolverDnCungqr_bufferSize
|
||||
#define gpusolverDnZungqr_bufferSize cusolverDnZungqr_bufferSize
|
||||
#define gpusolverDnSsyevd cusolverDnSsyevd
|
||||
#define gpusolverDnDsyevd cusolverDnDsyevd
|
||||
#define gpusolverDnCheevd cusolverDnCheevd
|
||||
#define gpusolverDnZheevd cusolverDnZheevd
|
||||
#define gpusolverDnSsyevd_bufferSize cusolverDnSsyevd_bufferSize
|
||||
#define gpusolverDnDsyevd_bufferSize cusolverDnDsyevd_bufferSize
|
||||
#define gpusolverDnCheevd_bufferSize cusolverDnCheevd_bufferSize
|
||||
#define gpusolverDnZheevd_bufferSize cusolverDnZheevd_bufferSize
|
||||
#define gpusolverDnSsyevj cusolverDnSsyevj
|
||||
#define gpusolverDnDsyevj cusolverDnDsyevj
|
||||
#define gpusolverDnCheevj cusolverDnCheevj
|
||||
#define gpusolverDnZheevj cusolverDnZheevj
|
||||
#define gpusolverDnSsyevj_bufferSize cusolverDnSsyevj_bufferSize
|
||||
#define gpusolverDnDsyevj_bufferSize cusolverDnDsyevj_bufferSize
|
||||
#define gpusolverDnCheevj_bufferSize cusolverDnCheevj_bufferSize
|
||||
#define gpusolverDnZheevj_bufferSize cusolverDnZheevj_bufferSize
|
||||
#define gpusolverDnSsyevjBatched cusolverDnSsyevjBatched
|
||||
#define gpusolverDnDsyevjBatched cusolverDnDsyevjBatched
|
||||
#define gpusolverDnCheevjBatched cusolverDnCheevjBatched
|
||||
#define gpusolverDnZheevjBatched cusolverDnZheevjBatched
|
||||
#define gpusolverDnSsyevjBatched_bufferSize cusolverDnSsyevjBatched_bufferSize
|
||||
#define gpusolverDnDsyevjBatched_bufferSize cusolverDnDsyevjBatched_bufferSize
|
||||
#define gpusolverDnCheevjBatched_bufferSize cusolverDnCheevjBatched_bufferSize
|
||||
#define gpusolverDnZheevjBatched_bufferSize cusolverDnZheevjBatched_bufferSize
|
||||
#define gpusolverDnSgesvd cusolverDnSgesvd
|
||||
#define gpusolverDnDgesvd cusolverDnDgesvd
|
||||
#define gpusolverDnCgesvd cusolverDnCgesvd
|
||||
#define gpusolverDnZgesvd cusolverDnZgesvd
|
||||
#define gpusolverDnSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
cusolverDnSgesvd_bufferSize(h, m, n, lwork)
|
||||
#define gpusolverDnDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
cusolverDnDgesvd_bufferSize(h, m, n, lwork)
|
||||
#define gpusolverDnCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
cusolverDnCgesvd_bufferSize(h, m, n, lwork)
|
||||
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
cusolverDnZgesvd_bufferSize(h, m, n, lwork)
|
||||
|
||||
#define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER
|
||||
#define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER
|
||||
#define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR
|
||||
#define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS
|
||||
|
||||
#define gpusparseCooSetStridedBatch cusparseCooSetStridedBatch
|
||||
#define gpusparseCreate cusparseCreate
|
||||
#define gpusparseCreateCoo cusparseCreateCoo
|
||||
#define gpusparseCreateCsr cusparseCreateCsr
|
||||
#define gpusparseCreateDnMat cusparseCreateDnMat
|
||||
#define gpusparseCreateDnVec cusparseCreateDnVec
|
||||
#define gpusparseDenseToSparse_analysis cusparseDenseToSparse_analysis
|
||||
#define gpusparseDenseToSparse_bufferSize cusparseDenseToSparse_bufferSize
|
||||
#define gpusparseDenseToSparse_convert cusparseDenseToSparse_convert
|
||||
#define gpusparseDestroySpMat cusparseDestroySpMat
|
||||
#define gpusparseDestroyDnMat cusparseDestroyDnMat
|
||||
#define gpusparseDestroyDnVec cusparseDestroyDnVec
|
||||
#define gpusparseDnMatSetStridedBatch cusparseDnMatSetStridedBatch
|
||||
#define gpusparseSetStream cusparseSetStream
|
||||
#define gpusparseSparseToDense cusparseSparseToDense
|
||||
#define gpusparseSparseToDense_bufferSize cusparseSparseToDense_bufferSize
|
||||
#define gpusparseSpMM cusparseSpMM
|
||||
#define gpusparseSpMM_bufferSize cusparseSpMM_bufferSize
|
||||
#define gpusparseSpMV cusparseSpMV
|
||||
#define gpusparseSpMV_bufferSize cusparseSpMV_bufferSize
|
||||
#define gpusparseSgtsv2 cusparseSgtsv2
|
||||
#define gpusparseDgtsv2 cusparseDgtsv2
|
||||
#define gpusparseSgtsv2_bufferSizeExt cusparseSgtsv2_bufferSizeExt
|
||||
#define gpusparseDgtsv2_bufferSizeExt cusparseDgtsv2_bufferSizeExt
|
||||
|
||||
#define GPUSPARSE_INDEX_16U CUSPARSE_INDEX_16U
|
||||
#define GPUSPARSE_INDEX_32I CUSPARSE_INDEX_32I
|
||||
#define GPUSPARSE_INDEX_64I CUSPARSE_INDEX_64I
|
||||
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO
|
||||
#define GPUSPARSE_MV_ALG_DEFAULT CUSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
|
||||
#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
|
||||
#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
|
||||
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS
|
||||
|
||||
#define gpuGetLastError cudaGetLastError
|
||||
#define gpuGetErrorString cudaGetErrorString
|
||||
#define gpuMemcpyAsync cudaMemcpyAsync
|
||||
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
|
||||
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
|
||||
#define gpuStreamSynchronize cudaStreamSynchronize
|
||||
#define gpuSuccess cudaSuccess
|
||||
|
||||
#elif defined(JAX_GPU_HIP)
|
||||
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipblas.h"
|
||||
#include "rocm/include/hipsolver.h"
|
||||
#include "rocm/include/hipsparse.h"
|
||||
|
||||
#define JAX_GPU_NAMESPACE hip
|
||||
#define JAX_GPU_PREFIX "hip"
|
||||
|
||||
#define JAX_GPU_HAVE_SPARSE 1
|
||||
#define JAX_GPU_HAVE_FP8 0
|
||||
|
||||
typedef hipFloatComplex gpuComplex;
|
||||
typedef hipDoubleComplex gpuDoubleComplex;
|
||||
|
||||
typedef hipblasComplex gpublasComplex;
|
||||
typedef hipblasDoubleComplex gpublasDoubleComplex;
|
||||
typedef hipsolverHandle_t gpusolverDnHandle_t;
|
||||
typedef hipblasFillMode_t gpublasFillMode_t;
|
||||
typedef hipsolverFillMode_t gpusolverFillMode_t;
|
||||
typedef hipblasHandle_t gpublasHandle_t;
|
||||
typedef hipblasStatus_t gpublasStatus_t;
|
||||
typedef hipDataType gpuDataType;
|
||||
typedef hipStream_t gpuStream_t;
|
||||
typedef hipError_t gpuError_t;
|
||||
typedef void gpuSyevjInfo;
|
||||
typedef hipsolverSyevjInfo_t gpuSyevjInfo_t;
|
||||
typedef hipsolverEigMode_t gpusolverEigMode_t;
|
||||
typedef hipsolverStatus_t gpusolverStatus_t;
|
||||
typedef hipsparseIndexType_t gpusparseIndexType_t;
|
||||
typedef hipsparseHandle_t gpusparseHandle_t;
|
||||
typedef hipsparseOperation_t gpusparseOperation_t;
|
||||
typedef hipsparseStatus_t gpusparseStatus_t;
|
||||
typedef hipsparseSpMatDescr_t gpusparseSpMatDescr_t;
|
||||
typedef hipsparseDnMatDescr_t gpusparseDnMatDescr_t;
|
||||
typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
|
||||
|
||||
#define GPU_C_16F HIP_C_16F
|
||||
#define GPU_R_16F HIP_R_16F
|
||||
#define GPU_C_32F HIP_C_32F
|
||||
#define GPU_R_32F HIP_R_32F
|
||||
#define GPU_C_64F HIP_C_64F
|
||||
#define GPU_R_64F HIP_R_64F
|
||||
|
||||
#define gpublasCreate hipblasCreate
|
||||
#define gpublasSetStream hipblasSetStream
|
||||
#define gpublasSgeqrfBatched hipblasSgeqrfBatched
|
||||
#define gpublasDgeqrfBatched hipblasDgeqrfBatched
|
||||
#define gpublasCgeqrfBatched hipblasCgeqrfBatched
|
||||
#define gpublasZgeqrfBatched hipblasZgeqrfBatched
|
||||
#define gpublasSgetrfBatched hipblasSgetrfBatched
|
||||
#define gpublasDgetrfBatched hipblasDgetrfBatched
|
||||
#define gpublasCgetrfBatched hipblasCgetrfBatched
|
||||
#define gpublasZgetrfBatched hipblasZgetrfBatched
|
||||
|
||||
#define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
|
||||
#define gpusolverDnCreate hipsolverCreate
|
||||
#define gpusolverDnSetStream hipsolverSetStream
|
||||
#define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo
|
||||
#define gpusolverDnDestroySyevjInfo hipsolverDestroySyevjInfo
|
||||
#define gpusolverDnSpotrf hipsolverSpotrf
|
||||
#define gpusolverDnDpotrf hipsolverDpotrf
|
||||
#define gpusolverDnCpotrf hipsolverCpotrf
|
||||
#define gpusolverDnZpotrf hipsolverZpotrf
|
||||
#define gpusolverDnSpotrf_bufferSize hipsolverSpotrf_bufferSize
|
||||
#define gpusolverDnDpotrf_bufferSize hipsolverDpotrf_bufferSize
|
||||
#define gpusolverDnCpotrf_bufferSize hipsolverCpotrf_bufferSize
|
||||
#define gpusolverDnZpotrf_bufferSize hipsolverZpotrf_bufferSize
|
||||
#define gpusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
hipsolverSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
|
||||
#define gpusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
hipsolverDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
|
||||
#define gpusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
hipsolverCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
|
||||
#define gpusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \
|
||||
batch) \
|
||||
hipsolverZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch)
|
||||
#define gpusolverDnSgeqrf hipsolverSgeqrf
|
||||
#define gpusolverDnDgeqrf hipsolverDgeqrf
|
||||
#define gpusolverDnCgeqrf hipsolverCgeqrf
|
||||
#define gpusolverDnZgeqrf hipsolverZgeqrf
|
||||
#define gpusolverDnSgeqrf_bufferSize hipsolverSgeqrf_bufferSize
|
||||
#define gpusolverDnDgeqrf_bufferSize hipsolverDgeqrf_bufferSize
|
||||
#define gpusolverDnCgeqrf_bufferSize hipsolverCgeqrf_bufferSize
|
||||
#define gpusolverDnZgeqrf_bufferSize hipsolverZgeqrf_bufferSize
|
||||
#define gpusolverDnSgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
hipsolverSgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
|
||||
#define gpusolverDnDgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
hipsolverDgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
|
||||
#define gpusolverDnCgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
hipsolverCgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
|
||||
#define gpusolverDnZgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \
|
||||
hipsolverZgetrf(h, m, n, a, lda, work, lwork, ipiv, info)
|
||||
#define gpusolverDnSgetrf_bufferSize hipsolverSgetrf_bufferSize
|
||||
#define gpusolverDnDgetrf_bufferSize hipsolverDgetrf_bufferSize
|
||||
#define gpusolverDnCgetrf_bufferSize hipsolverCgetrf_bufferSize
|
||||
#define gpusolverDnZgetrf_bufferSize hipsolverZgetrf_bufferSize
|
||||
#define gpusolverDnSorgqr hipsolverSorgqr
|
||||
#define gpusolverDnDorgqr hipsolverDorgqr
|
||||
#define gpusolverDnCungqr hipsolverCungqr
|
||||
#define gpusolverDnZungqr hipsolverZungqr
|
||||
#define gpusolverDnSorgqr_bufferSize hipsolverSorgqr_bufferSize
|
||||
#define gpusolverDnDorgqr_bufferSize hipsolverDorgqr_bufferSize
|
||||
#define gpusolverDnCungqr_bufferSize hipsolverCungqr_bufferSize
|
||||
#define gpusolverDnZungqr_bufferSize hipsolverZungqr_bufferSize
|
||||
#define gpusolverDnSsyevd hipsolverSsyevd
|
||||
#define gpusolverDnDsyevd hipsolverDsyevd
|
||||
#define gpusolverDnCheevd hipsolverCheevd
|
||||
#define gpusolverDnZheevd hipsolverZheevd
|
||||
#define gpusolverDnSsyevd_bufferSize hipsolverSsyevd_bufferSize
|
||||
#define gpusolverDnDsyevd_bufferSize hipsolverDsyevd_bufferSize
|
||||
#define gpusolverDnCheevd_bufferSize hipsolverCheevd_bufferSize
|
||||
#define gpusolverDnZheevd_bufferSize hipsolverZheevd_bufferSize
|
||||
#define gpusolverDnSsyevj hipsolverSsyevj
|
||||
#define gpusolverDnDsyevj hipsolverDsyevj
|
||||
#define gpusolverDnCheevj hipsolverCheevj
|
||||
#define gpusolverDnZheevj hipsolverZheevj
|
||||
#define gpusolverDnSsyevj_bufferSize hipsolverSsyevj_bufferSize
|
||||
#define gpusolverDnDsyevj_bufferSize hipsolverDsyevj_bufferSize
|
||||
#define gpusolverDnCheevj_bufferSize hipsolverCheevj_bufferSize
|
||||
#define gpusolverDnZheevj_bufferSize hipsolverZheevj_bufferSize
|
||||
#define gpusolverDnSsyevjBatched hipsolverSsyevjBatched
|
||||
#define gpusolverDnDsyevjBatched hipsolverDsyevjBatched
|
||||
#define gpusolverDnCheevjBatched hipsolverCheevjBatched
|
||||
#define gpusolverDnZheevjBatched hipsolverZheevjBatched
|
||||
#define gpusolverDnSsyevjBatched_bufferSize hipsolverSsyevjBatched_bufferSize
|
||||
#define gpusolverDnDsyevjBatched_bufferSize hipsolverDsyevjBatched_bufferSize
|
||||
#define gpusolverDnCheevjBatched_bufferSize hipsolverCheevjBatched_bufferSize
|
||||
#define gpusolverDnZheevjBatched_bufferSize hipsolverZheevjBatched_bufferSize
|
||||
#define gpusolverDnSgesvd hipsolverSgesvd
|
||||
#define gpusolverDnDgesvd hipsolverDgesvd
|
||||
#define gpusolverDnCgesvd hipsolverCgesvd
|
||||
#define gpusolverDnZgesvd hipsolverZgesvd
|
||||
#define gpusolverDnSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
hipsolverSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
|
||||
#define gpusolverDnDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
hipsolverDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
|
||||
#define gpusolverDnCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
hipsolverCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
|
||||
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
|
||||
hipsolverZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
|
||||
|
||||
#define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER
|
||||
#define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER
|
||||
#define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR
|
||||
#define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS
|
||||
|
||||
#define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch
|
||||
#define gpusparseCreate hipsparseCreate
|
||||
#define gpusparseSetStream hipsparseSetStream
|
||||
#define gpusparseCreateCoo hipsparseCreateCoo
|
||||
#define gpusparseCreateCsr hipsparseCreateCsr
|
||||
#define gpusparseCreateDnMat hipsparseCreateDnMat
|
||||
#define gpusparseCreateDnVec hipsparseCreateDnVec
|
||||
#define gpusparseDenseToSparse_analysis hipsparseDenseToSparse_analysis
|
||||
#define gpusparseDenseToSparse_bufferSize hipsparseDenseToSparse_bufferSize
|
||||
#define gpusparseDenseToSparse_convert hipsparseDenseToSparse_convert
|
||||
#define gpusparseDestroySpMat hipsparseDestroySpMat
|
||||
#define gpusparseDestroyDnMat hipsparseDestroyDnMat
|
||||
#define gpusparseDestroyDnVec hipsparseDestroyDnVec
|
||||
#define gpusparseDnMatSetStridedBatch hipsparseDnMatSetStridedBatch
|
||||
#define gpusparseSparseToDense hipsparseSparseToDense
|
||||
#define gpusparseSparseToDense_bufferSize hipsparseSparseToDense_bufferSize
|
||||
#define gpusparseSpMM hipsparseSpMM
|
||||
#define gpusparseSpMM_bufferSize hipsparseSpMM_bufferSize
|
||||
#define gpusparseSpMV hipsparseSpMV
|
||||
#define gpusparseSpMV_bufferSize hipsparseSpMV_bufferSize
|
||||
#define gpusparseSgtsv2 hipsparseSgtsv2
|
||||
#define gpusparseDgtsv2 hipsparseDgtsv2
|
||||
#define gpusparseSgtsv2_bufferSizeExt hipsparseSgtsv2_bufferSizeExt
|
||||
#define gpusparseDgtsv2_bufferSizeExt hipsparseDgtsv2_bufferSizeExt
|
||||
|
||||
#define GPUSPARSE_INDEX_16U HIPSPARSE_INDEX_16U
|
||||
#define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I
|
||||
#define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I
|
||||
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_MV_ALG_DEFAULT HIPSPARSE_MV_ALG_DEFAULT
|
||||
#define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO
|
||||
#define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE
|
||||
#define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE
|
||||
#define GPUSPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW
|
||||
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
|
||||
#define GPUSPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT
|
||||
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
|
||||
|
||||
#define gpuGetLastError hipGetLastError
|
||||
#define gpuGetErrorString hipGetErrorString
|
||||
#define gpuMemcpyAsync hipMemcpyAsync
|
||||
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
||||
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
#define gpuStreamSynchronize hipStreamSynchronize
|
||||
#define gpuSuccess hipSuccess
|
||||
|
||||
#else // defined(GPU vendor)
|
||||
#error "Either JAX_GPU_CUDA or JAX_GPU_HIP must be defined"
|
||||
#endif // defined(GPU vendor)
|
||||
|
||||
#endif // JAXLIB_GPU_VENDOR_H_
|
@ -23,14 +23,14 @@ from .mhlo_helpers import custom_call
|
||||
from jaxlib import xla_client
|
||||
|
||||
try:
|
||||
from .cuda import _cuda_linalg
|
||||
from .cuda import _linalg as _cuda_linalg
|
||||
for _name, _value in _cuda_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cuda_linalg = None
|
||||
|
||||
try:
|
||||
from .rocm import _hip_linalg
|
||||
from .rocm import _linalg as _hip_linalg
|
||||
for _name, _value in _hip_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
@ -65,7 +65,7 @@ def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_
|
||||
operand_layouts=[pivots_layout],
|
||||
result_layouts=[permutations_layout])
|
||||
|
||||
cuda_lu_pivots_to_permutation = partial(
|
||||
_lu_pivots_to_permutation_mhlo, "cuda", _cuda_linalg)
|
||||
cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_mhlo, "cu",
|
||||
_cuda_linalg)
|
||||
hip_lu_pivots_to_permutation = partial(
|
||||
_lu_pivots_to_permutation_mhlo, "hip", _hip_linalg)
|
||||
|
@ -25,14 +25,14 @@ from jaxlib import xla_client
|
||||
from .mhlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _cuda_prng
|
||||
from .cuda import _prng as _cuda_prng
|
||||
for _name, _value in _cuda_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cuda_prng = None
|
||||
|
||||
try:
|
||||
from .rocm import _hip_prng
|
||||
from .rocm import _prng as _hip_prng
|
||||
for _name, _value in _hip_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
@ -64,5 +64,6 @@ def _threefry2x32_lowering(prng, platform, keys, data):
|
||||
operand_layouts=[layout] * 4,
|
||||
result_layouts=[layout] * 2)
|
||||
|
||||
cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cuda")
|
||||
|
||||
cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu")
|
||||
rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip")
|
||||
|
@ -27,14 +27,14 @@ from jaxlib import xla_client
|
||||
from .mhlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _cublas
|
||||
from .cuda import _blas as _cublas
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
_cublas = None
|
||||
|
||||
try:
|
||||
from .cuda import _cusolver
|
||||
from .cuda import _solver as _cusolver
|
||||
for _name, _value in _cusolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
@ -42,14 +42,14 @@ except ImportError:
|
||||
|
||||
|
||||
try:
|
||||
from .rocm import _hipblas
|
||||
from .rocm import _blas as _hipblas
|
||||
for _name, _value in _hipblas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
_hipblas = None
|
||||
|
||||
try:
|
||||
from .rocm import _hipsolver
|
||||
from .rocm import _solver as _hipsolver
|
||||
for _name, _value in _hipsolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
|
@ -26,7 +26,7 @@ from jaxlib import xla_client
|
||||
from .mhlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _cusparse
|
||||
from .cuda import _sparse as _cusparse
|
||||
except ImportError:
|
||||
_cusparse = None
|
||||
else:
|
||||
@ -34,7 +34,7 @@ else:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
|
||||
try:
|
||||
from .rocm import _hipsparse
|
||||
from .rocm import _sparse as _hipsparse
|
||||
except ImportError:
|
||||
_hipsparse = None
|
||||
else:
|
||||
@ -42,8 +42,8 @@ else:
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
|
||||
|
||||
cuda_is_supported : bool = _cusparse and _cusparse.cusparse_supported
|
||||
rocm_is_supported : bool = _hipsparse and _hipsparse.hipsparse_supported
|
||||
cuda_is_supported : bool = _cusparse and _cusparse.sparse_supported
|
||||
rocm_is_supported : bool = _hipsparse and _hipsparse.sparse_supported
|
||||
|
||||
|
||||
def _validate_csr_mhlo(data, indices, indptr, shape):
|
||||
|
@ -25,16 +25,27 @@ licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//:__subpackages__"])
|
||||
|
||||
cc_library(
|
||||
name = "hip_vendor",
|
||||
hdrs = [
|
||||
"//jaxlib/gpu:vendor.h",
|
||||
],
|
||||
defines = ["JAX_GPU_HIP=1"],
|
||||
deps = [
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hip_gpu_kernel_helpers",
|
||||
srcs = if_rocm_is_configured(["hip_gpu_kernel_helpers.cc"]),
|
||||
hdrs = if_rocm_is_configured(["hip_gpu_kernel_helpers.h"]),
|
||||
srcs = if_rocm_is_configured(["//jaxlib/gpu:gpu_kernel_helpers.cc"]),
|
||||
hdrs = if_rocm_is_configured(["//jaxlib/gpu:gpu_kernel_helpers.h"]),
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
deps = [
|
||||
":hip_vendor",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
@ -46,11 +57,12 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "hipblas_kernels",
|
||||
srcs = ["hipblas_kernels.cc"],
|
||||
hdrs = ["hipblas_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:blas_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:blas_kernels.h"],
|
||||
deps = [
|
||||
"//jaxlib:handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base",
|
||||
@ -68,15 +80,16 @@ cc_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_hipblas",
|
||||
srcs = ["hipblas.cc"],
|
||||
name = "_blas",
|
||||
srcs = ["//jaxlib/gpu:blas.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_hipblas",
|
||||
module_name = "_blas",
|
||||
deps = [
|
||||
":hip_vendor",
|
||||
":hipblas_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -89,11 +102,12 @@ pybind_extension(
|
||||
|
||||
cc_library(
|
||||
name = "hipsolver_kernels",
|
||||
srcs = ["hipsolver_kernels.cc"],
|
||||
hdrs = ["hipsolver_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
|
||||
deps = [
|
||||
"//jaxlib:handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
@ -105,16 +119,17 @@ cc_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_hipsolver",
|
||||
srcs = ["hipsolver.cc"],
|
||||
name = "_solver",
|
||||
srcs = ["//jaxlib/gpu:solver.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_hipsolver",
|
||||
module_name = "_solver",
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
":hipsolver_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -127,11 +142,12 @@ pybind_extension(
|
||||
|
||||
cc_library(
|
||||
name = "hipsparse_kernels",
|
||||
srcs = ["hipsparse_kernels.cc"],
|
||||
hdrs = ["hipsparse_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:sparse_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:sparse_kernels.h"],
|
||||
deps = [
|
||||
"//jaxlib:handle_pool",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
@ -143,16 +159,17 @@ cc_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_hipsparse",
|
||||
srcs = ["hipsparse.cc"],
|
||||
name = "_sparse",
|
||||
srcs = ["//jaxlib/gpu:sparse.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_hipsparse",
|
||||
module_name = "_sparse",
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
":hipsparse_kernels",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
@ -172,13 +189,12 @@ pybind_extension(
|
||||
|
||||
cc_library(
|
||||
name = "hip_lu_pivot_kernels",
|
||||
srcs = [
|
||||
"hip_lu_pivot_kernels.cc",
|
||||
],
|
||||
hdrs = ["hip_lu_pivot_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:lu_pivot_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_lu_pivot_kernels_impl",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
@ -187,12 +203,11 @@ cc_library(
|
||||
|
||||
rocm_library(
|
||||
name = "hip_lu_pivot_kernels_impl",
|
||||
srcs = [
|
||||
"hip_lu_pivot_kernels.hip.cc",
|
||||
],
|
||||
hdrs = ["hip_lu_pivot_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:lu_pivot_kernels.cu.cc"],
|
||||
hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
@ -200,18 +215,19 @@ rocm_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_hip_linalg",
|
||||
srcs = ["hip_linalg.cc"],
|
||||
name = "_linalg",
|
||||
srcs = ["//jaxlib/gpu:linalg.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_hip_linalg",
|
||||
module_name = "_linalg",
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_lu_pivot_kernels",
|
||||
":hip_lu_pivot_kernels_impl",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@pybind11",
|
||||
@ -220,13 +236,12 @@ pybind_extension(
|
||||
|
||||
cc_library(
|
||||
name = "hip_prng_kernels",
|
||||
srcs = [
|
||||
"hip_prng_kernels.cc",
|
||||
],
|
||||
hdrs = ["hip_prng_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:prng_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_prng_kernels_impl",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
@ -235,12 +250,11 @@ cc_library(
|
||||
|
||||
rocm_library(
|
||||
name = "hip_prng_kernels_impl",
|
||||
srcs = [
|
||||
"hip_prng_kernels.hip.cc",
|
||||
],
|
||||
hdrs = ["hip_prng_kernels.h"],
|
||||
srcs = ["//jaxlib/gpu:prng_kernels.cu.cc"],
|
||||
hdrs = ["//jaxlib/gpu:prng_kernels.h"],
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
|
||||
@ -248,17 +262,18 @@ rocm_library(
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_hip_prng",
|
||||
srcs = ["hip_prng.cc"],
|
||||
name = "_prng",
|
||||
srcs = ["//jaxlib/gpu:prng.cc"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_hip_prng",
|
||||
module_name = "_prng",
|
||||
deps = [
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_prng_kernels",
|
||||
":hip_vendor",
|
||||
"//jaxlib:kernel_pybind11_helpers",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@pybind11",
|
||||
@ -268,11 +283,10 @@ pybind_extension(
|
||||
py_library(
|
||||
name = "rocm_gpu_support",
|
||||
deps = [
|
||||
":_hip_linalg",
|
||||
":_hip_prng",
|
||||
":_hipblas",
|
||||
":_hipsolver",
|
||||
":_hipsparse",
|
||||
":_blas",
|
||||
":_linalg",
|
||||
":_prng",
|
||||
":_solver",
|
||||
":_sparse",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,66 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
|
||||
#define JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipblas.h"
|
||||
#include "rocm/include/hipsolver.h"
|
||||
#include "rocm/include/hipsparse.h"
|
||||
|
||||
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
|
||||
|
||||
#define JAX_THROW_IF_ERROR(expr) \
|
||||
{ \
|
||||
auto s___ = (expr); \
|
||||
if (!s___.ok()) \
|
||||
throw std::runtime_error(std::string(s___.message())); \
|
||||
}
|
||||
|
||||
#define JAX_RETURN_IF_ERROR(expr) \
|
||||
{ \
|
||||
auto s___ = (expr); \
|
||||
if (!s___.ok()) \
|
||||
return s___; \
|
||||
}
|
||||
|
||||
namespace jax {
|
||||
|
||||
// Used via JAX_AS_STATUS(expr) macro.
|
||||
absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line,
|
||||
const char* expr);
|
||||
absl::Status AsStatus(hipsolverStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr);
|
||||
absl::Status AsStatus(hipsparseStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr);
|
||||
absl::Status AsStatus(hipblasStatus_t status, const char* file,
|
||||
std::int64_t line, const char* expr);
|
||||
|
||||
// Builds an array of pointers to each array in a batch, in device memory.
|
||||
// Caution: the return value must be kept alive (e.g., via a stream
|
||||
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
|
||||
// completes.
|
||||
absl::StatusOr<std::unique_ptr<void*[]>>
|
||||
MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
|
||||
int batch_elem_size);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
|
@ -1,51 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/rocm/hip_lu_pivot_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
std::string
|
||||
BuildHipLuPivotsToPermutationDescriptor(std::int64_t batch_size,
|
||||
std::int32_t pivot_size,
|
||||
std::int32_t permutation_size) {
|
||||
return PackDescriptorAsString(LuPivotsToPermutationDescriptor{
|
||||
batch_size, pivot_size, permutation_size});
|
||||
}
|
||||
|
||||
pybind11::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
dict["hip_lu_pivots_to_permutation"] =
|
||||
EncapsulateFunction(HipLuPivotsToPermutation);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_hip_linalg, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("lu_pivots_to_permutation_descriptor",
|
||||
[](std::int64_t batch_size, std::int32_t pivot_size,
|
||||
std::int32_t permutation_size) {
|
||||
std::string result = BuildHipLuPivotsToPermutationDescriptor(
|
||||
batch_size, pivot_size, permutation_size);
|
||||
return pybind11::bytes(result);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
@ -1,43 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hip_prng_kernels.h"
|
||||
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
std::string BuildHipThreeFry2x32Descriptor(std::int64_t n) {
|
||||
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
|
||||
}
|
||||
pybind11::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
dict["hip_threefry2x32"] = EncapsulateFunction(HipThreeFry2x32);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_hip_prng, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("threefry2x32_descriptor", [](std::int64_t n) {
|
||||
std::string result = BuildHipThreeFry2x32Descriptor(n);
|
||||
return pybind11::bytes(result);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
@ -1,47 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hip_prng_kernels.h"
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
absl::Status HipThreeFry2x32_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, std::size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
LaunchThreeFry2x32Kernel(stream, buffers, **s);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = HipThreeFry2x32_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
std::string_view message = s.message();
|
||||
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -1,39 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_HIP_PRNG_KERNELS_H_
|
||||
#define JAXLIB_HIP_PRNG_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
struct ThreeFry2x32Descriptor {
|
||||
std::int64_t n;
|
||||
};
|
||||
|
||||
void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers,
|
||||
ThreeFry2x32Descriptor descriptor);
|
||||
|
||||
void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_HIP_PRNG_KERNELS_H_
|
@ -1,116 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hip_prng_kernels.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
__global__ void
|
||||
ThreeFry2x32Kernel(const std::uint32_t* key0, const std::uint32_t* key1,
|
||||
const std::uint32_t* data0, const std::uint32_t* data1,
|
||||
std::uint32_t* out0, std::uint32_t* out1, std::int64_t n) {
|
||||
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
|
||||
idx += blockDim.x * gridDim.x) {
|
||||
// Rotation distances specified by the Threefry2x32 algorithm.
|
||||
std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
|
||||
std::uint32_t x[2];
|
||||
std::uint32_t ks[3];
|
||||
|
||||
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
|
||||
ks[2] = 0x1BD11BDA;
|
||||
|
||||
ks[0] = key0[idx];
|
||||
x[0] = data0[idx];
|
||||
ks[2] = ks[2] ^ key0[idx];
|
||||
|
||||
ks[1] = key1[idx];
|
||||
x[1] = data1[idx];
|
||||
ks[2] = ks[2] ^ key1[idx];
|
||||
|
||||
auto rotate_left = [](std::uint32_t v, std::uint32_t distance) {
|
||||
return (v << distance) | (v >> (32 - distance));
|
||||
};
|
||||
|
||||
// Performs a single round of the Threefry2x32 algorithm, with a rotation
|
||||
// amount 'rotation'.
|
||||
auto round = [&](std::uint32_t* v, std::uint32_t rotation) {
|
||||
v[0] += v[1];
|
||||
v[1] = rotate_left(v[1], rotation);
|
||||
v[1] ^= v[0];
|
||||
};
|
||||
|
||||
// There are no known statistical flaws with 13 rounds of Threefry2x32.
|
||||
// We are conservative and use 20 rounds.
|
||||
x[0] = x[0] + ks[0];
|
||||
x[1] = x[1] + ks[1];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
round(x, rotations[i]);
|
||||
}
|
||||
|
||||
x[0] = x[0] + ks[1];
|
||||
x[1] = x[1] + ks[2] + 1u;
|
||||
for (int i = 4; i < 8; ++i) {
|
||||
round(x, rotations[i]);
|
||||
}
|
||||
|
||||
x[0] = x[0] + ks[2];
|
||||
x[1] = x[1] + ks[0] + 2u;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
round(x, rotations[i]);
|
||||
}
|
||||
|
||||
x[0] = x[0] + ks[0];
|
||||
x[1] = x[1] + ks[1] + 3u;
|
||||
for (int i = 4; i < 8; ++i) {
|
||||
round(x, rotations[i]);
|
||||
}
|
||||
|
||||
x[0] = x[0] + ks[1];
|
||||
x[1] = x[1] + ks[2] + 4u;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
round(x, rotations[i]);
|
||||
}
|
||||
|
||||
out0[idx] = x[0] + ks[2];
|
||||
out1[idx] = x[1] + ks[0] + 5u;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers,
|
||||
ThreeFry2x32Descriptor descriptor) {
|
||||
std::array<const std::uint32_t*, 2> keys;
|
||||
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
|
||||
keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
|
||||
std::array<const std::uint32_t*, 2> data;
|
||||
data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
|
||||
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
|
||||
std::array<std::uint32_t*, 2> out;
|
||||
out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
|
||||
out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
|
||||
const int block_dim = 128;
|
||||
const std::int64_t grid_dim =
|
||||
std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
|
||||
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
|
||||
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
|
||||
out[1], descriptor.n);
|
||||
}
|
||||
|
||||
} // namespace jax
|
@ -1,57 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_HIPBLAS_KERNELS_H_
|
||||
#define JAXLIB_HIPBLAS_KERNELS_H_
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipblas.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
// Set of types known to Hipsolver.
|
||||
enum class HipblasType {
|
||||
F32,
|
||||
F64,
|
||||
C64,
|
||||
C128,
|
||||
};
|
||||
|
||||
// Batched LU decomposition: getrfbatched
|
||||
|
||||
struct GetrfBatchedDescriptor {
|
||||
HipblasType type;
|
||||
int batch, n;
|
||||
};
|
||||
|
||||
void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Batched QR decomposition: geqrfbatched
|
||||
|
||||
struct GeqrfBatchedDescriptor {
|
||||
HipblasType type;
|
||||
int batch, m, n;
|
||||
};
|
||||
|
||||
void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_HIPBLAS_KERNELS_H_
|
@ -1,435 +0,0 @@
|
||||
/* Copyright 2019 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/rocm/hipsolver_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipsolver.h"
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
namespace py = pybind11;
|
||||
|
||||
// Converts a NumPy dtype to a Type.
|
||||
HipsolverType DtypeToHipsolverType(const py::dtype& np_type) {
|
||||
static auto* types =
|
||||
new absl::flat_hash_map<std::pair<char, int>, HipsolverType>({
|
||||
{{'f', 4}, HipsolverType::F32},
|
||||
{{'f', 8}, HipsolverType::F64},
|
||||
{{'c', 8}, HipsolverType::C64},
|
||||
{{'c', 16}, HipsolverType::C128},
|
||||
});
|
||||
auto it = types->find({np_type.kind(), np_type.itemsize()});
|
||||
if (it == types->end()) {
|
||||
throw std::invalid_argument(
|
||||
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// potrf: Cholesky decomposition
|
||||
|
||||
// Returns the workspace size and a descriptor for a potrf operation.
|
||||
std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
|
||||
bool lower, int b, int n) {
|
||||
HipsolverType type = DtypeToHipsolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
std::int64_t workspace_size;
|
||||
hipsolverFillMode_t uplo =
|
||||
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
|
||||
if (b == 1) {
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverSpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(float);
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverDpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(double);
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverCpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(hipComplex);
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverZpotrf_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork)));
|
||||
workspace_size = lwork * sizeof(hipDoubleComplex);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// TODO(rocm): when cuda and hip had same API for batched potrf, remove this
|
||||
// batched potrf has different API compared to CUDA. In hip we still need to create the workspace and additional space to copy the batch array pointers
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverSpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size = (lwork * sizeof(float)) + (b * sizeof(float*));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverDpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size = (lwork * sizeof(double)) + (b * sizeof(double*));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverCpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size = (lwork * sizeof(hipComplex)) + (b * sizeof(hipComplex*));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverZpotrfBatched_bufferSize(handle.get(), uplo, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/n, &lwork, b)));
|
||||
workspace_size = (lwork * sizeof(hipDoubleComplex)) + (b * sizeof(hipDoubleComplex*));
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
return {workspace_size,
|
||||
PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})};
|
||||
}
|
||||
|
||||
// getrf: LU decomposition
|
||||
|
||||
// Returns the workspace size and a descriptor for a getrf operation.
|
||||
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n) {
|
||||
HipsolverType type = DtypeToHipsolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverSgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverDgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverCgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverZgetrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})};
|
||||
}
|
||||
|
||||
// geqrf: QR decomposition
|
||||
|
||||
// Returns the workspace size and a descriptor for a geqrf operation.
|
||||
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n) {
|
||||
HipsolverType type = DtypeToHipsolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverSgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverDgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverCgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverZgeqrf_bufferSize(handle.get(), m, n,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})};
|
||||
}
|
||||
|
||||
// orgqr/ungqr: apply elementary Householder transformations
|
||||
|
||||
// Returns the workspace size and a descriptor for a geqrf operation.
|
||||
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n, int k) {
|
||||
HipsolverType type = DtypeToHipsolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverSorgqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverDorgqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverCungqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(
|
||||
JAX_AS_STATUS(hipsolverZungqr_bufferSize(handle.get(), m, n, k,
|
||||
/*A=*/nullptr,
|
||||
/*lda=*/m,
|
||||
/*tau=*/nullptr, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})};
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
|
||||
|
||||
// Returns the workspace size and a descriptor for a syevd operation.
|
||||
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
|
||||
bool lower, int b, int n) {
|
||||
HipsolverType type = DtypeToHipsolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
|
||||
hipsolverFillMode_t uplo =
|
||||
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
|
||||
/*lda=*/n, /*W=*/nullptr, &lwork)));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
|
||||
/*lda=*/n, /*W=*/nullptr, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverCheevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
|
||||
/*lda=*/n, /*W=*/nullptr, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
|
||||
&lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
// Returns the workspace size and a descriptor for a syevj_batched operation.
|
||||
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
|
||||
bool lower, int batch, int n) {
|
||||
HipsolverType type = DtypeToHipsolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
hipsolverSyevjInfo_t params;
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(¶ms)));
|
||||
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
|
||||
params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
|
||||
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
|
||||
hipsolverFillMode_t uplo =
|
||||
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
|
||||
if (batch == 1) {
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params)));
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})};
|
||||
}
|
||||
|
||||
// Singular value decomposition using QR algorithm: gesvd
|
||||
|
||||
// Returns the workspace size and a descriptor for a gesvd operation.
|
||||
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n, bool compute_uv,
|
||||
bool full_matrices) {
|
||||
HipsolverType type = DtypeToHipsolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
signed char jobu, jobvt;
|
||||
if (compute_uv) {
|
||||
if (full_matrices) {
|
||||
jobu = jobvt = 'A';
|
||||
} else {
|
||||
jobu = jobvt = 'S';
|
||||
}
|
||||
} else {
|
||||
jobu = jobvt = 'N';
|
||||
}
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
case HipsolverType::F64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C64:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverCgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
case HipsolverType::C128:
|
||||
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverZgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
|
||||
break;
|
||||
}
|
||||
return {lwork,
|
||||
PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["hipsolver_potrf"] = EncapsulateFunction(Potrf);
|
||||
dict["hipsolver_getrf"] = EncapsulateFunction(Getrf);
|
||||
dict["hipsolver_geqrf"] = EncapsulateFunction(Geqrf);
|
||||
dict["hipsolver_orgqr"] = EncapsulateFunction(Orgqr);
|
||||
dict["hipsolver_syevd"] = EncapsulateFunction(Syevd);
|
||||
dict["hipsolver_syevj"] = EncapsulateFunction(Syevj);
|
||||
dict["hipsolver_gesvd"] = EncapsulateFunction(Gesvd);
|
||||
// dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); not supported by
|
||||
// ROCm yet
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_hipsolver, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
|
||||
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
|
||||
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
|
||||
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
|
||||
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
|
||||
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
|
||||
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
|
||||
// m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); not supported by
|
||||
// ROCm yet
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
@ -1,721 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/rocm/hipsolver_kernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipsolver.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
template <>
|
||||
/*static*/ absl::StatusOr<SolverHandlePool::Handle>
|
||||
SolverHandlePool::Borrow(hipStream_t stream) {
|
||||
SolverHandlePool* pool = Instance();
|
||||
absl::MutexLock lock(&pool->mu_);
|
||||
hipsolverHandle_t handle;
|
||||
if (pool->handles_[stream].empty()) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreate(&handle)));
|
||||
} else {
|
||||
handle = pool->handles_[stream].back();
|
||||
pool->handles_[stream].pop_back();
|
||||
}
|
||||
if (stream) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSetStream(handle, stream)));
|
||||
}
|
||||
return Handle(pool, handle, stream);
|
||||
}
|
||||
|
||||
static int SizeOfHipsolverType(HipsolverType type) {
|
||||
switch (type) {
|
||||
case HipsolverType::F32:
|
||||
return sizeof(float);
|
||||
case HipsolverType::F64:
|
||||
return sizeof(double);
|
||||
case HipsolverType::C64:
|
||||
return sizeof(hipFloatComplex);
|
||||
case HipsolverType::C128:
|
||||
return sizeof(hipDoubleComplex);
|
||||
}
|
||||
}
|
||||
|
||||
// potrf: Cholesky decomposition
|
||||
|
||||
static absl::Status Potrf_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const PotrfDescriptor& d = **s;
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipMemcpyAsync(buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * d.batch * d.n * d.n,
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[2]);
|
||||
void* workspace = buffers[3];
|
||||
if (d.batch == 1) {
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSpotrf(handle.get(), d.uplo, d.n, a, d.n,
|
||||
static_cast<float*>(workspace), d.lwork, info)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDpotrf(handle.get(), d.uplo, d.n, a, d.n,
|
||||
static_cast<double*>(workspace), d.lwork, info)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrf(
|
||||
handle.get(), d.uplo, d.n, a, d.n,
|
||||
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrf(
|
||||
handle.get(), d.uplo, d.n, a, d.n,
|
||||
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto buffer_ptrs_host =
|
||||
MakeBatchPointers(stream, buffers[1], workspace, d.batch,
|
||||
SizeOfHipsolverType(d.type) * d.n * d.n);
|
||||
JAX_RETURN_IF_ERROR(buffer_ptrs_host.status());
|
||||
// Make sure that accesses to buffer_ptrs_host complete before we delete it.
|
||||
// TODO(phawkins): avoid synchronization here.
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<float**>(workspace), d.n,
|
||||
reinterpret_cast<float*>(static_cast<float**>(workspace) + d.batch),
|
||||
d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<double**>(workspace), d.n,
|
||||
reinterpret_cast<double*>(static_cast<double**>(workspace) + d.batch),
|
||||
d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<hipFloatComplex**>(workspace), d.n,
|
||||
reinterpret_cast<hipFloatComplex*>(static_cast<hipFloatComplex**>(workspace) +
|
||||
d.batch), d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<hipDoubleComplex**>(workspace), d.n,
|
||||
reinterpret_cast<hipDoubleComplex*>(static_cast<hipDoubleComplex**>(workspace) +
|
||||
d.batch), d.lwork, info, d.batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Potrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// getrf: LU decomposition
|
||||
|
||||
static absl::Status Getrf_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GetrfDescriptor& d = **s;
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* ipiv = static_cast<int*>(buffers[2]);
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
void* workspace = buffers[4];
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSgetrf(handle.get(), d.m, d.n, a, d.m,
|
||||
static_cast<float*>(workspace), d.lwork, ipiv, info)));
|
||||
a += d.m * d.n;
|
||||
ipiv += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDgetrf(handle.get(), d.m, d.n, a, d.m,
|
||||
static_cast<double*>(workspace), d.lwork, ipiv, info)));
|
||||
a += d.m * d.n;
|
||||
ipiv += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverCgetrf(handle.get(), d.m, d.n, a, d.m,
|
||||
static_cast<hipFloatComplex*>(workspace), d.lwork, ipiv, info)));
|
||||
a += d.m * d.n;
|
||||
ipiv += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgetrf(
|
||||
handle.get(), d.m, d.n, a, d.m,
|
||||
static_cast<hipDoubleComplex*>(workspace), d.lwork, ipiv, info)));
|
||||
a += d.m * d.n;
|
||||
ipiv += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Getrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// geqrf: QR decomposition
|
||||
|
||||
static absl::Status Geqrf_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GeqrfDescriptor& d = **s;
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
|
||||
void* workspace = buffers[4];
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
float* tau = static_cast<float*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSgeqrf(handle.get(), d.m, d.n, a, d.m, tau,
|
||||
static_cast<float*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
double* tau = static_cast<double*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDgeqrf(handle.get(), d.m, d.n, a, d.m, tau,
|
||||
static_cast<double*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgeqrf(
|
||||
handle.get(), d.m, d.n, a, d.m, tau,
|
||||
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgeqrf(
|
||||
handle.get(), d.m, d.n, a, d.m, tau,
|
||||
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += std::min(d.m, d.n);
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// orgqr/ungqr: apply elementary Householder transformations
|
||||
|
||||
static absl::Status Orgqr_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const OrgqrDescriptor& d = **s;
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[2] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[2], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
|
||||
void* workspace = buffers[4];
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[2]);
|
||||
float* tau = static_cast<float*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau,
|
||||
static_cast<float*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[2]);
|
||||
double* tau = static_cast<double*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau,
|
||||
static_cast<double*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[2]);
|
||||
hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCungqr(
|
||||
handle.get(), d.m, d.n, d.k, a, d.m, tau,
|
||||
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[2]);
|
||||
hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZungqr(
|
||||
handle.get(), d.m, d.n, d.k, a, d.m, tau,
|
||||
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
|
||||
a += d.m * d.n;
|
||||
tau += d.k;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
|
||||
|
||||
static absl::Status Syevd_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SyevdDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SyevdDescriptor& d = **s;
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
void* work = buffers[4];
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
float* w = static_cast<float*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<float*>(work), d.lwork, info)));
|
||||
a += d.n * d.n;
|
||||
w += d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
double* w = static_cast<double*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<double*>(work), d.lwork, info)));
|
||||
a += d.n * d.n;
|
||||
w += d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
float* w = static_cast<float*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<hipFloatComplex*>(work), d.lwork, info)));
|
||||
a += d.n * d.n;
|
||||
w += d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
double* w = static_cast<double*>(buffers[2]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<hipDoubleComplex*>(work), d.lwork, info)));
|
||||
a += d.n * d.n;
|
||||
w += d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Syevd_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<SyevjDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const SyevjDescriptor& d = **s;
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
if (buffers[1] != buffers[0]) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
hipsolverSyevjInfo_t params;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(¶ms)));
|
||||
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
|
||||
params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
|
||||
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
void* work = buffers[4];
|
||||
if (d.batch == 1) {
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
float* w = static_cast<float*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<float*>(work), d.lwork, info, params)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
double* w = static_cast<double*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<double*>(work), d.lwork, info, params)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
float* w = static_cast<float*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<hipFloatComplex*>(work), d.lwork, info, params)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
double* w = static_cast<double*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<hipDoubleComplex*>(work), d.lwork, info, params)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
float* w = static_cast<float*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<float*>(work), d.lwork, info, params, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
double* w = static_cast<double*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<double*>(work), d.lwork, info, params, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
float* w = static_cast<float*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched(
|
||||
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<hipFloatComplex*>(work), d.lwork, info, params, d.batch)));
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
double* w = static_cast<double*>(buffers[2]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
|
||||
static_cast<hipDoubleComplex*>(work),
|
||||
d.lwork, info, params, d.batch)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Syevj_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// Singular value decomposition using QR algorithm: gesvd
|
||||
|
||||
static absl::Status Gesvd_(hipStream_t stream, void** buffers,
|
||||
const char* opaque, size_t opaque_len) {
|
||||
auto s = UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
|
||||
JAX_RETURN_IF_ERROR(s.status());
|
||||
const GesvdDescriptor& d = **s;
|
||||
auto h = SolverHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
|
||||
buffers[1], buffers[0],
|
||||
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
|
||||
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
|
||||
hipMemcpyDeviceToDevice, stream)));
|
||||
int* info = static_cast<int*>(buffers[5]);
|
||||
void* work = buffers[6];
|
||||
switch (d.type) {
|
||||
case HipsolverType::F32: {
|
||||
float* a = static_cast<float*>(buffers[1]);
|
||||
float* s = static_cast<float*>(buffers[2]);
|
||||
float* u = static_cast<float*>(buffers[3]);
|
||||
float* vt = static_cast<float*>(buffers[4]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
hipsolverSgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
|
||||
u, d.m, vt, d.n, static_cast<float*>(work), d.lwork,
|
||||
/*rwork=*/nullptr, info)));
|
||||
a += d.m * d.n;
|
||||
s += std::min(d.m, d.n);
|
||||
u += d.m * d.m;
|
||||
vt += d.n * d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::F64: {
|
||||
double* a = static_cast<double*>(buffers[1]);
|
||||
double* s = static_cast<double*>(buffers[2]);
|
||||
double* u = static_cast<double*>(buffers[3]);
|
||||
double* vt = static_cast<double*>(buffers[4]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDgesvd(
|
||||
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
|
||||
static_cast<double*>(work), d.lwork,
|
||||
/*rwork=*/nullptr, info)));
|
||||
a += d.m * d.n;
|
||||
s += std::min(d.m, d.n);
|
||||
u += d.m * d.m;
|
||||
vt += d.n * d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C64: {
|
||||
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
|
||||
float* s = static_cast<float*>(buffers[2]);
|
||||
hipFloatComplex* u = static_cast<hipFloatComplex*>(buffers[3]);
|
||||
hipFloatComplex* vt = static_cast<hipFloatComplex*>(buffers[4]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgesvd(
|
||||
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
|
||||
static_cast<hipFloatComplex*>(work), d.lwork, /*rwork=*/nullptr, info)));
|
||||
a += d.m * d.n;
|
||||
s += std::min(d.m, d.n);
|
||||
u += d.m * d.m;
|
||||
vt += d.n * d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
double* s = static_cast<double*>(buffers[2]);
|
||||
hipDoubleComplex* u = static_cast<hipDoubleComplex*>(buffers[3]);
|
||||
hipDoubleComplex* vt = static_cast<hipDoubleComplex*>(buffers[4]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgesvd(
|
||||
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
|
||||
static_cast<hipDoubleComplex*>(work), d.lwork,
|
||||
/*rwork=*/nullptr, info)));
|
||||
a += d.m * d.n;
|
||||
s += std::min(d.m, d.n);
|
||||
u += d.m * d.m;
|
||||
vt += d.n * d.n;
|
||||
++info;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
|
||||
if (!s.ok()) {
|
||||
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
|
||||
s.message().length());
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(rocm): add Gesvdj_ apis when support from hipsolver is ready
|
||||
} // namespace jax
|
@ -1,122 +0,0 @@
|
||||
/* Copyright 2021 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_HIPSOLVER_KERNELS_H_
|
||||
#define JAXLIB_HIPSOLVER_KERNELS_H_
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "rocm/include/hip/hip_runtime_api.h"
|
||||
#include "rocm/include/hipblas.h"
|
||||
#include "rocm/include/hipsolver.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
|
||||
using SolverHandlePool = HandlePool<hipsolverHandle_t, hipStream_t>;
|
||||
|
||||
template <>
|
||||
absl::StatusOr<SolverHandlePool::Handle>
|
||||
SolverHandlePool::Borrow(hipStream_t stream);
|
||||
|
||||
// Set of types known to Hipsolver.
|
||||
enum class HipsolverType {
|
||||
F32,
|
||||
F64,
|
||||
C64,
|
||||
C128,
|
||||
};
|
||||
|
||||
// potrf: Cholesky decomposition
|
||||
|
||||
struct PotrfDescriptor {
|
||||
HipsolverType type;
|
||||
hipsolverFillMode_t uplo;
|
||||
std::int64_t batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
// getrf: LU decomposition
|
||||
|
||||
struct GetrfDescriptor {
|
||||
HipsolverType type;
|
||||
int batch, m, n, lwork;
|
||||
};
|
||||
|
||||
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// geqrf: QR decomposition
|
||||
|
||||
struct GeqrfDescriptor {
|
||||
HipsolverType type;
|
||||
int batch, m, n, lwork;
|
||||
};
|
||||
|
||||
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// orgqr/ungqr: apply elementary Householder transformations
|
||||
|
||||
struct OrgqrDescriptor {
|
||||
HipsolverType type;
|
||||
int batch, m, n, k, lwork;
|
||||
};
|
||||
|
||||
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
|
||||
|
||||
struct SyevdDescriptor {
|
||||
HipsolverType type;
|
||||
hipsolverFillMode_t uplo;
|
||||
int batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
struct SyevjDescriptor {
|
||||
HipsolverType type;
|
||||
hipsolverFillMode_t uplo;
|
||||
int batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Singular value decomposition using QR algorithm: gesvd
|
||||
|
||||
struct GesvdDescriptor {
|
||||
HipsolverType type;
|
||||
int batch, m, n;
|
||||
int lwork;
|
||||
signed char jobu, jobvt;
|
||||
};
|
||||
|
||||
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_HIPSOLVER_KERNELS_H_
|
Loading…
x
Reference in New Issue
Block a user