diff --git a/build/build_wheel.py b/build/build_wheel.py
index b6a290455..ac98482fd 100644
--- a/build/build_wheel.py
+++ b/build/build_wheel.py
@@ -184,25 +184,25 @@ def prepare_wheel(sources_path):
   copy_file(f"__main__/jaxlib/cpu/_ducc_fft.{pyext}", dst_dir=cpu_dir)
 
   cuda_dir = os.path.join(jaxlib_dir, "cuda")
-  if exists(f"__main__/jaxlib/cuda/_cusolver.{pyext}"):
+  if exists(f"__main__/jaxlib/cuda/_solver.{pyext}"):
     libdevice_dir = os.path.join(cuda_dir, "nvvm", "libdevice")
     os.makedirs(libdevice_dir)
     copy_file("local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc",
               dst_dir=libdevice_dir)
-    copy_file(f"__main__/jaxlib/cuda/_cusolver.{pyext}", dst_dir=cuda_dir)
-    copy_file(f"__main__/jaxlib/cuda/_cublas.{pyext}", dst_dir=cuda_dir)
-    copy_file(f"__main__/jaxlib/cuda/_cuda_linalg.{pyext}", dst_dir=cuda_dir)
-    copy_file(f"__main__/jaxlib/cuda/_cuda_prng.{pyext}", dst_dir=cuda_dir)
+    copy_file(f"__main__/jaxlib/cuda/_solver.{pyext}", dst_dir=cuda_dir)
+    copy_file(f"__main__/jaxlib/cuda/_blas.{pyext}", dst_dir=cuda_dir)
+    copy_file(f"__main__/jaxlib/cuda/_linalg.{pyext}", dst_dir=cuda_dir)
+    copy_file(f"__main__/jaxlib/cuda/_prng.{pyext}", dst_dir=cuda_dir)
   rocm_dir = os.path.join(jaxlib_dir, "rocm")
-  if exists(f"__main__/jaxlib/rocm/_hipsolver.{pyext}"):
+  if exists(f"__main__/jaxlib/rocm/_solver.{pyext}"):
     os.makedirs(rocm_dir)
-    copy_file(f"__main__/jaxlib/rocm/_hipsolver.{pyext}", dst_dir=rocm_dir)
-    copy_file(f"__main__/jaxlib/rocm/_hipblas.{pyext}", dst_dir=rocm_dir)
-    copy_file(f"__main__/jaxlib/rocm/_hip_linalg.{pyext}", dst_dir=rocm_dir)
-    copy_file(f"__main__/jaxlib/rocm/_hip_prng.{pyext}", dst_dir=rocm_dir)
-  if exists(f"__main__/jaxlib/cuda/_cusparse.{pyext}"):
-    copy_file(f"__main__/jaxlib/cuda/_cusparse.{pyext}", dst_dir=cuda_dir)
-  if exists(f"__main__/jaxlib/rocm/_hipsparse.{pyext}"):
-    copy_file(f"__main__/jaxlib/rocm/_hipsparse.{pyext}", dst_dir=rocm_dir)
+    copy_file(f"__main__/jaxlib/rocm/_solver.{pyext}", dst_dir=rocm_dir)
+    copy_file(f"__main__/jaxlib/rocm/_blas.{pyext}", dst_dir=rocm_dir)
+    copy_file(f"__main__/jaxlib/rocm/_linalg.{pyext}", dst_dir=rocm_dir)
+    copy_file(f"__main__/jaxlib/rocm/_prng.{pyext}", dst_dir=rocm_dir)
+  if exists(f"__main__/jaxlib/cuda/_sparse.{pyext}"):
+    copy_file(f"__main__/jaxlib/cuda/_sparse.{pyext}", dst_dir=cuda_dir)
+  if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"):
+    copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir)
 
   mlir_dir = os.path.join(jaxlib_dir, "mlir")
diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index 37a3f13f6..07bb4b1df 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -24,15 +24,31 @@ licenses(["notice"])
 
 package(default_visibility = ["//:__subpackages__"])
 
+cc_library(
+    name = "cuda_vendor",
+    hdrs = [
+        "//jaxlib/gpu:vendor.h",
+    ],
+    defines = ["JAX_GPU_CUDA=1"],
+    deps = [
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
+
 cc_library(
     name = "cuda_gpu_kernel_helpers",
-    srcs = ["cuda_gpu_kernel_helpers.cc"],
-    hdrs = ["cuda_gpu_kernel_helpers.h"],
+    srcs = [
+        "//jaxlib/gpu:gpu_kernel_helpers.cc",
+    ],
+    hdrs = [
+        "//jaxlib/gpu:gpu_kernel_helpers.h",
+    ],
     copts = [
         "-fexceptions",
     ],
     features = ["-use_header_modules"],
     deps = [
+        ":cuda_vendor",
         "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusolver_lib",
         "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cusparse_lib",
         "@com_google_absl//absl/memory",
@@ -47,10 +63,11 @@ cc_library(
 
 cc_library(
     name = "cublas_kernels",
-    srcs = ["cublas_kernels.cc"],
-    hdrs = ["cublas_kernels.h"],
+    srcs = ["//jaxlib/gpu:blas_kernels.cc"],
+    hdrs = ["//jaxlib/gpu:blas_kernels.h"],
     deps = [
         ":cuda_gpu_kernel_helpers",
+        ":cuda_vendor",
         "//jaxlib:handle_pool",
         "//jaxlib:kernel_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@@ -71,31 +88,32 @@ cc_library(
 )
 
 pybind_extension(
-    name = "_cublas",
-    srcs = ["cublas.cc"],
+    name = "_blas",
+    srcs = ["//jaxlib/gpu:blas.cc"],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
     features = ["-use_header_modules"],
-    module_name = "_cublas",
+    module_name = "_blas",
     deps = [
         ":cublas_kernels",
+        ":cuda_vendor",
         "//jaxlib:kernel_pybind11_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cublas_lib",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings:str_format",
-        "@local_config_cuda//cuda:cuda_headers",
         "@pybind11",
     ],
 )
 
 cc_library(
     name = "cusolver_kernels",
-    srcs = ["cusolver_kernels.cc"],
-    hdrs = ["cusolver_kernels.h"],
+    srcs = ["//jaxlib/gpu:solver_kernels.cc"],
+    hdrs = ["//jaxlib/gpu:solver_kernels.h"],
     deps = [
         ":cuda_gpu_kernel_helpers",
+        ":cuda_vendor",
         "//jaxlib:handle_pool",
         "//jaxlib:kernel_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@@ -108,16 +126,17 @@ cc_library(
 )
 
 pybind_extension(
-    name = "_cusolver",
-    srcs = ["cusolver.cc"],
+    name = "_solver",
+    srcs = ["//jaxlib/gpu:solver.cc"],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
     features = ["-use_header_modules"],
-    module_name = "_cusolver",
+    module_name = "_solver",
     deps = [
         ":cuda_gpu_kernel_helpers",
+        ":cuda_vendor",
         ":cusolver_kernels",
         "//jaxlib:kernel_pybind11_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
@@ -131,10 +150,11 @@ pybind_extension(
 
 cc_library(
     name = "cusparse_kernels",
-    srcs = ["cusparse_kernels.cc"],
-    hdrs = ["cusparse_kernels.h"],
+    srcs = ["//jaxlib/gpu:sparse_kernels.cc"],
+    hdrs = ["//jaxlib/gpu:sparse_kernels.h"],
     deps = [
         ":cuda_gpu_kernel_helpers",
+        ":cuda_vendor",
         "//jaxlib:handle_pool",
         "//jaxlib:kernel_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
@@ -148,16 +168,17 @@ cc_library(
 )
 
 pybind_extension(
-    name = "_cusparse",
-    srcs = ["cusparse.cc"],
+    name = "_sparse",
+    srcs = ["//jaxlib/gpu:sparse.cc"],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
     features = ["-use_header_modules"],
-    module_name = "_cusparse",
+    module_name = "_sparse",
    deps = [
         ":cuda_gpu_kernel_helpers",
+        ":cuda_vendor",
         ":cusparse_kernels",
         "//jaxlib:kernel_pybind11_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
@@ -179,12 +200,13 @@ pybind_extension(
 
 cc_library(
     name = "cuda_lu_pivot_kernels",
     srcs = [
-        "cuda_lu_pivot_kernels.cc",
+        "//jaxlib/gpu:lu_pivot_kernels.cc",
     ],
-    hdrs = ["cuda_lu_pivot_kernels.h"],
+    hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
     deps = [
         ":cuda_gpu_kernel_helpers",
         ":cuda_lu_pivot_kernels_impl",
+        ":cuda_vendor",
         "//jaxlib:kernel_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@local_config_cuda//cuda:cuda_headers",
@@ -194,11 +216,12 @@ cc_library(
 cuda_library(
     name = "cuda_lu_pivot_kernels_impl",
     srcs = [
-        "cuda_lu_pivot_kernels.cu.cc",
+        "//jaxlib/gpu:lu_pivot_kernels.cu.cc",
     ],
-    hdrs = ["cuda_lu_pivot_kernels.h"],
+    hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"],
     deps = [
         ":cuda_gpu_kernel_helpers",
+        ":cuda_vendor",
         "//jaxlib:kernel_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@local_config_cuda//cuda:cuda_headers",
@@ -206,18 +229,19 @@ cuda_library(
 )
 
 pybind_extension(
-    name = "_cuda_linalg",
-    srcs = ["cuda_linalg.cc"],
+    name = "_linalg",
+    srcs = ["//jaxlib/gpu:linalg.cc"],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
     features = ["-use_header_modules"],
-    module_name = "_cuda_linalg",
+    module_name = "_linalg",
     deps = [
         ":cuda_gpu_kernel_helpers",
         ":cuda_lu_pivot_kernels",
         ":cuda_lu_pivot_kernels_impl",
+        ":cuda_vendor",
         "//jaxlib:kernel_pybind11_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
         "@local_config_cuda//cuda:cuda_headers",
@@ -228,12 +252,13 @@ pybind_extension(
 
 cc_library(
     name = "cuda_prng_kernels",
     srcs = [
-        "cuda_prng_kernels.cc",
+        "//jaxlib/gpu:prng_kernels.cc",
     ],
-    hdrs = ["cuda_prng_kernels.h"],
+    hdrs = ["//jaxlib/gpu:prng_kernels.h"],
     deps = [
         ":cuda_gpu_kernel_helpers",
         ":cuda_prng_kernels_impl",
+        ":cuda_vendor",
         "//jaxlib:kernel_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@local_config_cuda//cuda:cuda_headers",
@@ -243,11 +268,12 @@ cc_library(
 cuda_library(
     name = "cuda_prng_kernels_impl",
     srcs = [
-        "cuda_prng_kernels.cu.cc",
+        "//jaxlib/gpu:prng_kernels.cu.cc",
     ],
-    hdrs = ["cuda_prng_kernels.h"],
+    hdrs = ["//jaxlib/gpu:prng_kernels.h"],
     deps = [
         ":cuda_gpu_kernel_helpers",
+        ":cuda_vendor",
         "//jaxlib:kernel_helpers",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
         "@local_config_cuda//cuda:cuda_headers",
@@ -255,14 +281,14 @@ cuda_library(
 )
 
 pybind_extension(
-    name = "_cuda_prng",
-    srcs = ["cuda_prng.cc"],
+    name = "_prng",
+    srcs = ["//jaxlib/gpu:prng.cc"],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
     ],
     features = ["-use_header_modules"],
-    module_name = "_cuda_prng",
+    module_name = "_prng",
     deps = [
         ":cuda_gpu_kernel_helpers",
         ":cuda_prng_kernels",
@@ -275,12 +301,13 @@ pybind_extension(
 
 cc_library(
     name = "cuda_gpu_kernels",
-    srcs = ["cuda_gpu_kernels.cc"],
+    srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
     visibility = ["//visibility:public"],
     deps = [
         ":cublas_kernels",
         ":cuda_lu_pivot_kernels",
         ":cuda_prng_kernels",
+        ":cuda_vendor",
         ":cusolver_kernels",
         ":cusparse_kernels",
         "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
@@ -291,10 +318,10 @@ cc_library(
 py_library(
     name = "cuda_gpu_support",
     deps = [
-        ":_cublas",
-        ":_cuda_linalg",
-        ":_cuda_prng",
-        ":_cusolver",
-        ":_cusparse",
+        ":_blas",
+        ":_linalg",
+        ":_prng",
+        ":_solver",
+        ":_sparse",
     ],
 )
diff --git a/jaxlib/cuda/cublas.cc b/jaxlib/cuda/cublas.cc
deleted file mode 100644
index 6679ddc12..000000000
--- a/jaxlib/cuda/cublas.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright 2019 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <algorithm>
-#include <stdexcept>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/str_format.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "jaxlib/cuda/cublas_kernels.h"
-#include "jaxlib/kernel_pybind11_helpers.h"
-#include "include/pybind11/numpy.h"
-#include "include/pybind11/pybind11.h"
-#include "include/pybind11/stl.h"
-
-namespace jax {
-namespace {
-
-namespace py = pybind11;
-
-// Converts a NumPy dtype to a CublasType.
-CublasType DtypeToCublasType(const py::dtype& np_type) {
-  static auto* types =
-      new absl::flat_hash_map<std::pair<char, int>, CublasType>({
-          {{'f', 4}, CublasType::F32},
-          {{'f', 8}, CublasType::F64},
-          {{'c', 8}, CublasType::C64},
-          {{'c', 16}, CublasType::C128},
-      });
-  auto it = types->find({np_type.kind(), np_type.itemsize()});
-  if (it == types->end()) {
-    throw std::invalid_argument(
-        absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
-  }
-  return it->second;
-}
-
-// Returns the descriptor for a GetrfBatched operation.
-std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
-                                                         int b, int n) {
-  CublasType type = DtypeToCublasType(dtype);
-  size_t size = b * sizeof(void*);
-  return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
-}
-
-// Returns the descriptor for a GeqrfBatched operation.
-std::pair<size_t, py::bytes> BuildGeqrfBatchedDescriptor(const py::dtype& dtype,
-                                                         int b, int m, int n) {
-  CublasType type = DtypeToCublasType(dtype);
-  size_t size = b * sizeof(void*);
-  return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})};
-}
-
-py::dict Registrations() {
-  py::dict dict;
-  dict["cublas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
-  dict["cublas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched);
-  return dict;
-}
-
-PYBIND11_MODULE(_cublas, m) {
-  m.def("registrations", &Registrations);
-  m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
-  m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor);
-}
-
-}  // namespace
-}  // namespace jax
diff --git a/jaxlib/cuda/cublas_kernels.cc b/jaxlib/cuda/cublas_kernels.cc
deleted file mode 100644
index d171ac45c..000000000
--- a/jaxlib/cuda/cublas_kernels.cc
+++ /dev/null
@@ -1,225 +0,0 @@
-/* Copyright 2019 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/cuda/cublas_kernels.h"
-
-#include <algorithm>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/base/casts.h"
-#include "absl/base/thread_annotations.h"
-#include "absl/status/status.h"
-#include "absl/strings/str_format.h"
-#include "absl/synchronization/mutex.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
-#include "jaxlib/handle_pool.h"
-#include "jaxlib/kernel_helpers.h"
-#include "tensorflow/compiler/xla/service/custom_call_status.h"
-
-namespace jax {
-using BlasHandlePool = HandlePool<cublasHandle_t, cudaStream_t>;
-
-template <>
-/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
-    cudaStream_t stream) {
-  BlasHandlePool* pool = Instance();
-  absl::MutexLock lock(&pool->mu_);
-  cublasHandle_t handle;
-  if (pool->handles_[stream].empty()) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCreate(&handle)));
-  } else {
-    handle = pool->handles_[stream].back();
-    pool->handles_[stream].pop_back();
-  }
-  if (stream) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSetStream(handle, stream)));
-  }
-  return Handle(pool, handle, stream);
-}
-
-namespace {
-
-// Returns the size in bytes of a value of the given CublasType.
-int SizeOfCublasType(CublasType type) {
-  switch (type) {
-    case CublasType::F32:
-      return sizeof(float);
-    case CublasType::F64:
-      return sizeof(double);
-    case CublasType::C64:
-      return sizeof(cuComplex);
-    case CublasType::C128:
-      return sizeof(cuDoubleComplex);
-  }
-}
-
-}  // namespace
-
-// Batched LU decomposition: getrfbatched
-
-static absl::Status GetrfBatched_(cudaStream_t stream, void** buffers,
-                                  const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const GetrfBatchedDescriptor& d = **s;
-  auto h = BlasHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  if (buffers[0] != buffers[1]) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
-        buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.n * d.n,
-        cudaMemcpyDeviceToDevice, stream)));
-  }
-
-  int* ipiv = static_cast<int*>(buffers[2]);
-  int* info = static_cast<int*>(buffers[3]);
-  auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
-                                       SizeOfCublasType(d.type) * d.n * d.n);
-  JAX_RETURN_IF_ERROR(a_ptrs_host.status());
-  // TODO(phawkins): ideally we would not need to synchronize here, but to
-  // avoid it we need a way to keep the host-side buffer alive until the copy
-  // completes.
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
-  switch (d.type) {
-    case CublasType::F32: {
-      float** batch_ptrs = static_cast<float**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasSgetrfBatched(
-          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
-      break;
-    }
-    case CublasType::F64: {
-      double** batch_ptrs = static_cast<double**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasDgetrfBatched(
-          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
-      break;
-    }
-    case CublasType::C64: {
-      cuComplex** batch_ptrs = static_cast<cuComplex**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasCgetrfBatched(
-          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
-      break;
-    }
-    case CublasType::C128: {
-      cuDoubleComplex** batch_ptrs = static_cast<cuDoubleComplex**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cublasZgetrfBatched(
-          handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
-      break;
-    }
-  }
-  return absl::OkStatus();
-}
-
-void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
-                  size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// Batched QR decomposition: geqrfbatched
-
-static absl::Status GeqrfBatched_(cudaStream_t stream, void** buffers,
-                                  const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<GeqrfBatchedDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const GeqrfBatchedDescriptor& d = **s;
-  auto h = BlasHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  if (buffers[0] != buffers[1]) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync(
-        buffers[1], buffers[0], SizeOfCublasType(d.type) * d.batch * d.m * d.n,
-        cudaMemcpyDeviceToDevice, stream)));
-  }
-
-  std::vector<int> info(d.batch);
-  auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
-                                       SizeOfCublasType(d.type) * d.m * d.n);
-  JAX_RETURN_IF_ERROR(a_ptrs_host.status());
-  auto tau_ptrs_host =
-      MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
-                        SizeOfCublasType(d.type) * std::min(d.m, d.n));
-  JAX_RETURN_IF_ERROR(tau_ptrs_host.status());
-  // TODO(phawkins): ideally we would not need to synchronize here, but to
-  // avoid it we need a way to keep the host-side buffer alive until the copy
-  // completes.
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream)));
-  switch (d.type) {
-    case CublasType::F32: {
-      float** a_batch_ptrs = static_cast<float**>(buffers[3]);
-      float** tau_batch_ptrs = static_cast<float**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-          cublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
-                              tau_batch_ptrs, info.data(), d.batch)));
-      break;
-    }
-    case CublasType::F64: {
-      double** a_batch_ptrs = static_cast<double**>(buffers[3]);
-      double** tau_batch_ptrs = static_cast<double**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-          cublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
-                              tau_batch_ptrs, info.data(), d.batch)));
-      break;
-    }
-    case CublasType::C64: {
-      cuComplex** a_batch_ptrs = static_cast<cuComplex**>(buffers[3]);
-      cuComplex** tau_batch_ptrs = static_cast<cuComplex**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-          cublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
-                              tau_batch_ptrs, info.data(), d.batch)));
-      break;
-    }
-    case CublasType::C128: {
-      cuDoubleComplex** a_batch_ptrs =
-          static_cast<cuDoubleComplex**>(buffers[3]);
-      cuDoubleComplex** tau_batch_ptrs =
-          static_cast<cuDoubleComplex**>(buffers[4]);
-      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-          cublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m,
-                              tau_batch_ptrs, info.data(), d.batch)));
-      break;
-    }
-  }
-  auto it =
-      std::find_if(info.begin(), info.end(), [](int i) { return i != 0; });
-
-  if (it != info.end()) {
-    return absl::InvalidArgumentError(
-        absl::StrFormat("QR decomposition failed with status %d for batch "
-                        "element %d",
-                        *it, std::distance(info.begin(), it)));
-  }
-
-  return absl::OkStatus();
-}
-
-void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque,
-                  size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-}  // namespace jax
diff --git a/jaxlib/cuda/cuda_gpu_kernel_helpers.cc b/jaxlib/cuda/cuda_gpu_kernel_helpers.cc
deleted file mode 100644
index 5887b6034..000000000
--- a/jaxlib/cuda/cuda_gpu_kernel_helpers.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/* Copyright 2019 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
-
-#include <string>
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
-
-namespace jax {
-namespace {
-std::string ErrorString(cudaError_t error) { return cudaGetErrorString(error); }
-
-std::string ErrorString(cusparseStatus_t status) {
-  return cusparseGetErrorString(status);
-}
-
-std::string ErrorString(cusolverStatus_t status) {
-  switch (status) {
-    case CUSOLVER_STATUS_SUCCESS:
-      return "cuSolver success.";
-    case CUSOLVER_STATUS_NOT_INITIALIZED:
-      return "cuSolver has not been initialized";
-    case CUSOLVER_STATUS_ALLOC_FAILED:
-      return "cuSolver allocation failed";
-    case CUSOLVER_STATUS_INVALID_VALUE:
-      return "cuSolver invalid value error";
-    case CUSOLVER_STATUS_ARCH_MISMATCH:
-      return "cuSolver architecture mismatch error";
-    case CUSOLVER_STATUS_MAPPING_ERROR:
-      return "cuSolver mapping error";
-    case CUSOLVER_STATUS_EXECUTION_FAILED:
-      return "cuSolver execution failed";
-    case CUSOLVER_STATUS_INTERNAL_ERROR:
-      return "cuSolver internal error";
-    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
-      return "cuSolver matrix type not supported error";
-    case CUSOLVER_STATUS_NOT_SUPPORTED:
-      return "cuSolver not supported error";
-    case CUSOLVER_STATUS_ZERO_PIVOT:
-      return "cuSolver zero pivot error";
-    case CUSOLVER_STATUS_INVALID_LICENSE:
-      return "cuSolver invalid license error";
-    default:
-      return absl::StrCat("Unknown cuSolver error: ", status);
-  }
-}
-
-std::string ErrorString(cublasStatus_t status) {
-  switch (status) {
-    case CUBLAS_STATUS_SUCCESS:
-      return "cuBlas success";
-    case CUBLAS_STATUS_NOT_INITIALIZED:
-      return "cuBlas has not been initialized";
-    case CUBLAS_STATUS_ALLOC_FAILED:
-      return "cuBlas allocation failure";
-    case CUBLAS_STATUS_INVALID_VALUE:
-      return "cuBlas invalid value error";
-    case CUBLAS_STATUS_ARCH_MISMATCH:
-      return "cuBlas architecture mismatch";
-    case CUBLAS_STATUS_MAPPING_ERROR:
-      return "cuBlas mapping error";
-    case CUBLAS_STATUS_EXECUTION_FAILED:
-      return "cuBlas execution failed";
-    case CUBLAS_STATUS_INTERNAL_ERROR:
-      return "cuBlas internal error";
-    case CUBLAS_STATUS_NOT_SUPPORTED:
-      return "cuBlas not supported error";
-    case CUBLAS_STATUS_LICENSE_ERROR:
-      return "cuBlas license error";
-    default:
-      return "Unknown cuBlas error";
-  }
-}
-
-template <typename T>
-std::string ErrorString(T status, const char* file, std::int64_t line,
-                        const char* expr) {
-  return absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr,
-                         ErrorString(status));
-}
-}  // namespace
-
-absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line,
-                      const char* expr) {
-  if (error != cudaSuccess)
-    return absl::InternalError(ErrorString(error, file, line, expr));
-  return absl::OkStatus();
-}
-
-absl::Status AsStatus(cusolverStatus_t status, const char* file,
-                      std::int64_t line, const char* expr) {
-  if (status != CUSOLVER_STATUS_SUCCESS)
-    return absl::InternalError(ErrorString(status, file, line, expr));
-  return absl::OkStatus();
-}
-
-absl::Status AsStatus(cusparseStatus_t status, const char* file,
-                      std::int64_t line, const char* expr) {
-  if (status != CUSPARSE_STATUS_SUCCESS)
-    return absl::InternalError(ErrorString(status, file, line, expr));
-  return absl::OkStatus();
-}
-
-absl::Status AsStatus(cublasStatus_t status, const char* file,
-                      std::int64_t line, const char* expr) {
-  if (status != CUBLAS_STATUS_SUCCESS)
-    return absl::InternalError(ErrorString(status, file, line, expr));
-  return absl::OkStatus();
-}
-
-absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
-    cudaStream_t stream, void* buffer, void* dev_ptrs, int batch,
-    int batch_elem_size) {
-  char* ptr = static_cast<char*>(buffer);
-  auto host_ptrs = absl::make_unique<void*[]>(batch);
-  for (int i = 0; i < batch; ++i) {
-    host_ptrs[i] = ptr;
-    ptr += batch_elem_size;
-  }
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      cudaMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
-                      cudaMemcpyHostToDevice, stream)));
-  return std::move(host_ptrs);
-}
-}  // namespace jax
diff --git a/jaxlib/cuda/cuda_lu_pivot_kernels.cc b/jaxlib/cuda/cuda_lu_pivot_kernels.cc
deleted file mode 100644
index 06c0dfd3d..000000000
--- a/jaxlib/cuda/cuda_lu_pivot_kernels.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
-
-#include <string_view>
-
-#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
-#include "jaxlib/kernel_helpers.h"
-#include "tensorflow/compiler/xla/service/custom_call_status.h"
-
-namespace jax {
-namespace {
-
-absl::Status CudaLuPivotsToPermutation_(cudaStream_t stream, void** buffers,
-                                        const char* opaque,
-                                        std::size_t opaque_len) {
-  auto s =
-      UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  LaunchLuPivotsToPermutationKernel(stream, buffers, **s);
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError()));
-  return absl::OkStatus();
-}
-
-}  // namespace
-
-void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
-                               const char* opaque, size_t opaque_len,
-                               XlaCustomCallStatus* status) {
-  auto s = CudaLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    std::string_view message = s.message();
-    XlaCustomCallStatusSetFailure(status, message.data(), message.length());
-  }
-}
-
-
-}  // namespace jax
diff --git a/jaxlib/cuda/cuda_lu_pivot_kernels.cu.cc b/jaxlib/cuda/cuda_lu_pivot_kernels.cu.cc
deleted file mode 100644
index a26fb55f3..000000000
--- a/jaxlib/cuda/cuda_lu_pivot_kernels.cu.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/cuda/cuda_lu_pivot_kernels.h"
-
-#include <algorithm>
-#include <cstdint>
-
-namespace jax {
-namespace {
-
-__device__ void ComputePermutation(const std::int32_t* pivots,
-                                   std::int32_t* permutation_out,
-                                   const std::int32_t pivot_size,
-                                   const std::int32_t permutation_size) {
-  for (int i = 0; i < permutation_size; ++i) {
-    permutation_out[i] = i;
-  }
-
-  // Compute the permutation from a sequence of transpositions encoded in the
-  // pivot array by applying the transpositions in order on the identity
-  // permutation.
-  for (int i = 0; i < pivot_size; ++i) {
-    if ((pivots[i] < 0) || (pivots[i] >= permutation_size)) {
-      continue;
-    }
-    std::int32_t swap_temporary = permutation_out[i];
-    permutation_out[i] = permutation_out[pivots[i]];
-    permutation_out[pivots[i]] = swap_temporary;
-  }
-}
-
-__global__ void LuPivotsToPermutationKernel(
-    const std::int32_t* pivots, std::int32_t* permutation_out,
-    const std::int64_t batch_size, const std::int32_t pivot_size,
-    const std::int32_t permutation_size) {
-  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-       idx < batch_size; idx += blockDim.x * gridDim.x) {
-    // Fill in the output array with the identity permutation.
-    ComputePermutation(pivots + idx * pivot_size,
-                       permutation_out + idx * permutation_size, pivot_size,
-                       permutation_size);
-  }
-}
-
-}  // namespace
-
-void LaunchLuPivotsToPermutationKernel(
-    cudaStream_t stream, void** buffers,
-    LuPivotsToPermutationDescriptor descriptor) {
-  const std::int32_t* pivots =
-      reinterpret_cast<const std::int32_t*>(buffers[0]);
-  std::int32_t* permutation_out = reinterpret_cast<std::int32_t*>(buffers[1]);
-
-  const int block_dim = 128;
-  const std::int64_t grid_dim = std::min<std::int64_t>(
-      1024, (descriptor.batch_size + block_dim - 1) / block_dim);
-
-  LuPivotsToPermutationKernel<<<grid_dim, block_dim, 0, stream>>>(
-      pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size,
-      descriptor.permutation_size);
-}
-
-}  // namespace jax
diff --git a/jaxlib/cuda/cuda_lu_pivot_kernels.h b/jaxlib/cuda/cuda_lu_pivot_kernels.h
deleted file mode 100644
index 673672dcb..000000000
--- a/jaxlib/cuda/cuda_lu_pivot_kernels.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
-#define JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
-
-#include <cstddef>
-#include <cstdint>
-
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "tensorflow/compiler/xla/service/custom_call_status.h"
-
-namespace jax {
-
-struct LuPivotsToPermutationDescriptor {
-  std::int64_t batch_size;
-  std::int32_t pivot_size;
-  std::int32_t permutation_size;
-};
-
-void LaunchLuPivotsToPermutationKernel(
-    cudaStream_t stream, void** buffers,
-    LuPivotsToPermutationDescriptor descriptor);
-
-void CudaLuPivotsToPermutation(cudaStream_t stream, void** buffers,
-                               const char* opaque, size_t opaque_len,
-                               XlaCustomCallStatus* status);
-
-}  // namespace jax
-
-#endif  // JAXLIB_CUDA_LU_PIVOT_KERNELS_H_
diff --git a/jaxlib/cuda/cusparse.cc b/jaxlib/cuda/cusparse.cc
deleted file mode 100644
index 0deb0b8b0..000000000
--- a/jaxlib/cuda/cusparse.cc
+++ /dev/null
@@ -1,618 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "third_party/gpus/cuda/include/cusparse.h"
-
-#include <cstdint>
-#include <memory>
-#include <stdexcept>
-#include <utility>
-#include <vector>
-
-#include "absl/base/casts.h"
-#include "absl/base/thread_annotations.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/str_format.h"
-#include "absl/synchronization/mutex.h"
-#include "third_party/gpus/cuda/include/cuComplex.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
-#include "jaxlib/cuda/cusparse_kernels.h"
-#include "jaxlib/kernel_pybind11_helpers.h"
-#include "include/pybind11/numpy.h"
-#include "include/pybind11/pybind11.h"
-#include "include/pybind11/stl.h"
-
-namespace py = pybind11;
-
-namespace jax {
-namespace {
-
-cusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
-  static auto* types =
-      new absl::flat_hash_map<std::pair<char, int>, cusparseIndexType_t>({
-          {{'u', 2}, CUSPARSE_INDEX_16U},
-          {{'i', 4}, CUSPARSE_INDEX_32I},
-          {{'i', 8}, CUSPARSE_INDEX_64I},
-      });
-  auto it = types->find({np_type.kind(), np_type.itemsize()});
-  if (it == types->end()) {
-    throw std::invalid_argument(
-        absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type)));
-  }
-  return it->second;
-}
-
-cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
-  static auto* types =
-      new absl::flat_hash_map<std::pair<char, int>, cudaDataType>({
-          {{'f', 2}, CUDA_R_16F}, {{'c', 4}, CUDA_C_16F},
-          {{'f', 4}, CUDA_R_32F}, {{'c', 8}, CUDA_C_32F},
-          {{'f', 8}, CUDA_R_64F}, {{'c', 16}, CUDA_C_64F},
-          {{'i', 1}, CUDA_R_8I},  {{'u', 1}, CUDA_R_8U},
-          {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U},
-#if JAX_CUSPARSE_11300
-          {{'V', 2}, CUDA_R_16BF},
-#endif
-      });
-  auto it = types->find({np_type.kind(), np_type.itemsize()});
-  if (it == types->end()) {
-    throw std::invalid_argument(
-        absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type)));
-  }
-  return it->second;
-}
-
-// Returns the descriptor for a Sparse matrix.
-SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
-                                             const py::dtype& index_dtype,
-                                             int rows, int cols, int nnz,
-                                             int batch_count,
-                                             int batch_stride) {
-  cudaDataType value_type = DtypeToCudaDataType(data_dtype);
-  cusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
-  return SparseMatDescriptor{
-      value_type, index_type, rows, cols, nnz, batch_count, batch_stride};
-}
-
-// Returns the descriptor for a Dense matrix.
-DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
-                                           int rows, int cols, int batch_count,
-                                           int batch_stride) {
-  cudaDataType value_type = DtypeToCudaDataType(data_dtype);
-  return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
-}
-
-// Returns the descriptor for a Dense vector.
-DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
-                                           int size) {
-  cudaDataType value_type = DtypeToCudaDataType(data_dtype);
-  return DenseVecDescriptor{value_type, size};
-}
-
-#if JAX_CUSPARSE_11300
-// CsrToDense: Convert CSR matrix to dense matrix
-
-// Returns the descriptor for a CsrToDense operation.
-std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
-    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  SparseMatDescriptor d =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               /*batch_count=*/1, /*batch_stride=*/0);
-
-  cusparseSpMatDescr_t mat_a = 0;
-  cusparseDnMatDescr_t mat_b = 0;
-
-  // buffer_size does not reference these pointers, but does error on NULL.
-  // TODO(jakevdp): check whether this is documented.
-  int val = 0;
-  void* empty = &val;
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
-      &mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
-      d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
-      &mat_b, d.rows, d.cols,
-      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
-  size_t buffer_size;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
-      handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
-      &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
-
-  return {buffer_size, PackDescriptor(d)};
-}
-
-absl::Status CsrToDense_(cudaStream_t stream, void** buffers,
-                         const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const SparseMatDescriptor& d = **s;
-  auto h = SparseHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-
-  cusparseSpMatDescr_t mat_a = 0;
-  cusparseDnMatDescr_t mat_b = 0;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
-                        /*csrRowOffsets=*/buffers[2],
-                        /*csrColInd=*/buffers[1],
-                        /*csrValues=*/buffers[0], d.index_type, d.index_type,
-                        CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
-      &mat_b, d.rows, d.cols,
-      /*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW)));
-
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      cusparseSparseToDense(handle.get(), mat_a, mat_b,
-                            CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
-
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
-  return absl::OkStatus();
-}
-
-void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque,
-                size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// CsrFromDense: Convert dense matrix to CSR matrix
-
-// Returns the descriptor for a CsrFromDense operation.
-std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
-    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  SparseMatDescriptor d =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               /*batch_count=*/1, /*batch_stride=*/0);
-
-  cusparseDnMatDescr_t mat_a = 0;
-  cusparseSpMatDescr_t mat_b = 0;
-
-  // bufferSize does not reference these pointers, but does error on NULL.
-  int val = 0;
-  void* empty = &val;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
-      &mat_a, d.rows, d.cols,
-      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
-      &mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
-      d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  size_t buffer_size;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
-      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
-      &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
-
-  return {buffer_size, PackDescriptor(d)};
-}
-
-absl::Status CsrFromDense_(cudaStream_t stream, void** buffers,
-                           const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const SparseMatDescriptor& d = **s;
-  auto h = SparseHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-
-  cusparseDnMatDescr_t mat_a = 0;
-  cusparseSpMatDescr_t mat_b = 0;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
-      &mat_a, d.rows, d.cols,
-      /*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
-                        /*csrRowOffsets=*/buffers[3],
-                        /*csrColInd=*/buffers[2],
-                        /*csrValues=*/buffers[1], d.index_type, d.index_type,
-                        CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis(
-      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
-      buffers[4])));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert(
-      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
-      buffers[4])));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
-  return absl::OkStatus();
-}
-
-void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque,
-                  size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// CsrMatvec: Product of CSR matrix and dense vector.
-
-// Returns the descriptor for a CsrMatvec operation.
-std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
-    const py::dtype& data_dtype, const py::dtype& x_dtype,
-    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int nnz, bool transpose) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  SparseMatDescriptor A =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               /*batch_count=*/1, /*batch_stride=*/0);
-  DenseVecDescriptor x =
-      BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
-  DenseVecDescriptor y =
-      BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
-
-  cusparseSpMatDescr_t mat_a = 0;
-  cusparseDnVecDescr_t vec_x = 0;
-  cusparseDnVecDescr_t vec_y = 0;
-  cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
-                                     : CUSPARSE_OPERATION_NON_TRANSPOSE;
-
-  // bufferSize does not reference these pointers, but does error on NULL.
-  int val = 0;
-  void* empty = &val;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
-      &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
-      A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
-  size_t buffer_size;
-  CudaConst alpha = CudaOne(y.type);
-  CudaConst beta = CudaZero(y.type);
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
-      handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
-      CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
-
-  return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
-}
-
-// CsrMatmat: Product of CSR matrix and dense matrix.
-
-// Returns the descriptor for a CsrMatmat operation.
-std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
-    const py::dtype& data_dtype, const py::dtype& b_dtype,
-    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int BCcols, int nnz, bool transpose) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  SparseMatDescriptor A =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               /*batch_count=*/1, /*batch_stride=*/0);
-  DenseMatDescriptor B =
-      BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
-                              /*batch_count=*/1, /*batch_stride=*/0);
-  DenseMatDescriptor C =
-      BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols,
-                              /*batch_count=*/1, /*batch_stride=*/0);
-  cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
-                                       : CUSPARSE_OPERATION_NON_TRANSPOSE;
-
-  cusparseSpMatDescr_t mat_a = 0;
-  cusparseDnMatDescr_t mat_b = 0;
-  cusparseDnMatDescr_t mat_c = 0;
-
-  // bufferSize does not reference these pointers, but does error on NULL.
-  int val = 0;
-  void* empty = &val;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateCsr(
-      &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
-      A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
-                                        empty, B.type, CUSPARSE_ORDER_ROW)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
-                                        empty, C.type, CUSPARSE_ORDER_ROW)));
-  size_t buffer_size;
-  CudaConst alpha = CudaOne(C.type);
-  CudaConst beta = CudaZero(C.type);
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
-      handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
-      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
-
-  return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
-}
-
-// CooToDense: Convert COO matrix to dense matrix
-
-// Returns the descriptor for a CooToDense operation.
-std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
-    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  SparseMatDescriptor d =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               /*batch_count=*/1, /*batch_stride=*/0);
-
-  cusparseSpMatDescr_t mat_a = 0;
-  cusparseDnMatDescr_t mat_b = 0;
-
-  // bufferSize does not reference these pointers, but does error on NULL.
-  int val = 0;
-  void* empty = &val;
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-      cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
-                        d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
-      &mat_b, d.rows, d.cols,
-      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
-  size_t buffer_size;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSparseToDense_bufferSize(
-      handle.get(), mat_a, mat_b, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
-      &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
-
-  return {buffer_size, PackDescriptor(d)};
-}
-
-// CooFromDense: Convert dense matrix to COO matrix
-
-// Returns the descriptor for a CooFromDense operation.
-std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
-    const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  SparseMatDescriptor d =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               /*batch_count=*/1, /*batch_stride=*/0);
-
-  cusparseDnMatDescr_t mat_a = 0;
-  cusparseSpMatDescr_t mat_b = 0;
-
-  // bufferSize does not reference these pointers, but does error on NULL.
-  int val = 0;
-  void* empty = &val;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
-      &mat_a, d.rows, d.cols,
-      /*ld=*/d.cols, empty, d.value_type, CUSPARSE_ORDER_ROW)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-      cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
-                        d.index_type, CUSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  size_t buffer_size;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_bufferSize(
-      handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
-      &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b)));
-
-  return {buffer_size, PackDescriptor(d)};
-}
-
-// CooMatvec: Product of COO matrix and dense vector.
-
-// Returns the descriptor for a CooMatvec operation.
-std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
-    const py::dtype& data_dtype, const py::dtype& x_dtype,
-    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int nnz, bool transpose) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  SparseMatDescriptor A =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               /*batch_count=*/1, /*batch_stride=*/0);
-  DenseVecDescriptor x =
-      BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
-  DenseVecDescriptor y =
-      BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
-
-  cusparseSpMatDescr_t mat_a = 0;
-  cusparseDnVecDescr_t vec_x = 0;
-  cusparseDnVecDescr_t vec_y = 0;
-  cusparseOperation_t op = transpose ? CUSPARSE_OPERATION_TRANSPOSE
-                                     : CUSPARSE_OPERATION_NON_TRANSPOSE;
-
-  // bufferSize does not reference these pointers, but does error on NULL.
-  int val = 0;
-  void* empty = &val;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-      cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
-                        A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, x.size, empty, x.type)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, y.size, empty, y.type)));
-  size_t buffer_size;
-  CudaConst alpha = CudaOne(y.type);
-  CudaConst beta = CudaZero(y.type);
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMV_bufferSize(
-      handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
-      CUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y)));
-
-  return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
-}
-
-// CooMatmat: Product of COO matrix and dense matrix.
-
-// Returns the descriptor for a CooMatmat operation.
-std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
-    const py::dtype& data_dtype, const py::dtype& b_dtype,
-    const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
-    int cols, int BCcols, int nnz, bool transpose, int batch_count,
-    int lhs_batch_stride, int rhs_batch_stride) {
-  // Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
-  // C_i = A_i B_i, where `i` denotes the batch dimension.
-  // All three matrices A, B, and C must have the same batch count.
-  // Use batch stride to trigger individual mode, e.g.,
-  // `rhs_batch_stride = 0` for C_i = A_i B.
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-
-  SparseMatDescriptor A =
-      BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz,
-                               batch_count, lhs_batch_stride);
-  DenseMatDescriptor B =
-      BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols,
-                              batch_count, rhs_batch_stride);
-  int C_rows = transpose ? cols : rows;
-  // TODO(tianjianlu): enable the selection of batch stride.
-  // The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
-  // in cusparse library does not allow batch_stride = 0.
-  // int C_batch_stride = (batch_count > 1)? C_rows * BCcols : 0;
-  int C_batch_stride = C_rows * BCcols;
-  DenseMatDescriptor C =
-      BuildDenseMatDescriptor(compute_dtype, /*rows=*/C_rows, /*cols=*/BCcols,
-                              batch_count, C_batch_stride);
-  cusparseOperation_t op_A = transpose ? CUSPARSE_OPERATION_TRANSPOSE
-                                       : CUSPARSE_OPERATION_NON_TRANSPOSE;
-
-  cusparseSpMatDescr_t mat_a = 0;
-  cusparseDnMatDescr_t mat_b = 0;
-  cusparseDnMatDescr_t mat_c = 0;
-
-
-  // bufferSize does not reference these pointers, but does error on NULL.
-  int val = 0;
-  void* empty = &val;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-      cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
-                        A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-      cusparseCooSetStridedBatch(
-          mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
-                                        empty, B.type, CUSPARSE_ORDER_ROW)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-      cusparseDnMatSetStridedBatch(
-          mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride)));
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
-                                        empty, C.type, CUSPARSE_ORDER_ROW)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-      cusparseDnMatSetStridedBatch(
-          mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride)));
-  size_t buffer_size;
-  CudaConst alpha = CudaOne(C.type);
-  CudaConst beta = CudaZero(C.type);
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseSpMM_bufferSize(
-      handle.get(), op_A, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
-      mat_b, &beta, mat_c, C.type, CUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
-
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b)));
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c)));
-
-  return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
-}
-
-#endif  // if JAX_CUSPARSE_11300
-
-py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
-  return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
-}
-
-template <typename F>
-size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
-  auto h = SparseHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  size_t size;
-  JAX_THROW_IF_ERROR(
-      JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
-                      /*du=*/nullptr, /*B=*/nullptr, ldb, &size)));
-  return size;
-}
-
-size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
-  return Gtsv2BufferSize(cusparseSgtsv2_bufferSizeExt, m, n, ldb);
-}
-
-size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
-  return Gtsv2BufferSize(cusparseDgtsv2_bufferSizeExt, m, n, ldb);
-}
-
-py::dict Registrations() {
-  py::dict dict;
-#if JAX_CUSPARSE_11300
-  dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
-  dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
-  dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
-  dict["cusparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
-  dict["cusparse_coo_todense"] = EncapsulateFunction(CooToDense);
-  dict["cusparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
-  dict["cusparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
-  dict["cusparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
-#endif
-  dict["cusparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
-  dict["cusparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
-  // TODO(tomhennigan): Add support for gtsv2 complex 32/64.
-  return dict;
-}
-
-PYBIND11_MODULE(_cusparse, m) {
-  m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11300);
-  m.def("registrations", &Registrations);
-#if JAX_CUSPARSE_11300
-  m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
-  m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
-  m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
-  m.def("build_csr_matmat_descriptor", &BuildCsrMatmatDescriptor);
-  m.def("build_coo_todense_descriptor", &BuildCooToDenseDescriptor);
-  m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
-  m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
-  m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
-#endif
-  m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
-  m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
-  m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
-}
-
-}  // namespace
-}  // namespace jax
diff --git a/jaxlib/cuda/cusparse_kernels.cc b/jaxlib/cuda/cusparse_kernels.cc
deleted file mode 100644
index 44e2877d3..000000000
--- a/jaxlib/cuda/cusparse_kernels.cc
+++ /dev/null
@@ -1,618 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/cuda/cusparse_kernels.h"
-
-#include <cstdint>
-#include <cstring>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/synchronization/mutex.h"
-#include "third_party/gpus/cuda/include/cuComplex.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#if JAX_CUDA_11080
-#include "third_party/gpus/cuda/include/cuda_fp8.h"
-#endif
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "third_party/gpus/cuda/include/cusparse.h"
-#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h"
-#include "jaxlib/handle_pool.h"
-#include "jaxlib/kernel_helpers.h"
-#include "tensorflow/compiler/xla/service/custom_call_status.h"
-
-// cuSPARSE generic APIs are not supported on Windows until 11.0.
-// cusparseIndexType_t is used in very limited scope, so defining it manually
-// works around the compilation issue without harm.
-#if defined(_WIN32) && (CUSPARSE_VERSION < 11000) -typedef enum { - CUSPARSE_INDEX_16U = 1, - CUSPARSE_INDEX_32I = 2, - CUSPARSE_INDEX_64I = 3 -} cusparseIndexType_t; -#endif - -namespace jax { - -template <> -/*static*/ absl::StatusOr SparseHandlePool::Borrow( - cudaStream_t stream) { - SparseHandlePool* pool = Instance(); - absl::MutexLock lock(&pool->mu_); - cusparseHandle_t handle; - if (pool->handles_[stream].empty()) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreate(&handle))); - } else { - handle = pool->handles_[stream].back(); - pool->handles_[stream].pop_back(); - } - if (stream) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSetStream(handle, stream))); - } - return Handle(pool, handle, stream); -} - -CudaConst CudaZero(cudaDataType type) { - CudaConst c; - std::memset(&c, 0, sizeof(c)); - return c; -} - -CudaConst CudaOne(cudaDataType type) { - CudaConst c; - std::memset(&c, 0, sizeof(c)); - switch (type) { -#if JAX_CUSPARSE_11300 - // TODO(jakevdp): 4I/4U here might break on big endian platforms. - case CUDA_R_4I: - case CUDA_C_4I: -#endif - case CUDA_R_8I: - case CUDA_C_8I: - c.i8[0] = 1; - break; -#if JAX_CUSPARSE_11300 - case CUDA_R_4U: - case CUDA_C_4U: -#endif - case CUDA_R_8U: - case CUDA_C_8U: - c.u8[0] = 1; - break; -#if JAX_CUSPARSE_11300 - case CUDA_R_16I: - case CUDA_C_16I: - c.i16[0] = 1; - break; - case CUDA_R_16U: - case CUDA_C_16U: - c.u16[0] = 1; - break; -#endif - case CUDA_R_32I: - case CUDA_C_32I: - c.i32[0] = 1; - break; - case CUDA_R_32U: - case CUDA_C_32U: - c.u32[0] = 1; - break; -#if JAX_CUSPARSE_11300 - case CUDA_R_64I: - case CUDA_C_64I: - c.i64[0] = 1; - break; - case CUDA_R_64U: - case CUDA_C_64U: - c.u64[0] = 1; - break; -#endif -#if JAX_CUDA_11080 - case CUDA_R_8F_E4M3: - c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E4M3); - break; - case CUDA_R_8F_E5M2: - c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E5M2); - break; -#endif - // TODO(jakevdp): 16F/16BF here might break on big endian platforms. 
- case CUDA_R_16F: - case CUDA_C_16F: - c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16 - break; -#if JAX_CUSPARSE_11300 - case CUDA_R_16BF: - case CUDA_C_16BF: - c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16 - break; -#endif - case CUDA_R_32F: - case CUDA_C_32F: - c.f32[0] = 1.0; - break; - case CUDA_R_64F: - case CUDA_C_64F: - c.f64[0] = 1.0; - break; - } - return c; -} - -#if JAX_CUSPARSE_11300 -// CsrToDense: Convert CSR matrix to dense matrix - -static absl::Status CsrToDense_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - cusparseSpMatDescr_t mat_a = 0; - cusparseDnMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[2], - /*csrColInd=*/buffers[1], - /*csrValues=*/buffers[0], d.index_type, d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseSparseToDense(handle.get(), mat_a, mat_b, - CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); - return absl::OkStatus(); -} - -void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// CsrFromDense: Convert dense matrix to CSR matrix - -static absl::Status CsrFromDense_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - cusparseDnMatDescr_t mat_a = 0; - cusparseSpMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[3], - /*csrColInd=*/buffers[2], - /*csrValues=*/buffers[1], d.index_type, d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis( - handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert( - handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b))); - return absl::OkStatus(); -} - -void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - 
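The `CudaZero`/`CudaOne` helpers above build a bit pattern for "0" or "1" in whatever scalar type the descriptor carries, using one union large enough for every supported type; the two-element arrays make complex types come out as 1+0i. A compilable sketch of the same idea with illustrative names (note `0x3C00` is the same half-precision "1.0" as the binary literal in the deleted code):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// One union covers all scalar widths; element [1] stays zero so complex
// variants read back as (1, 0).
union ScalarConst {
  std::int32_t i32[2];
  std::uint16_t u16[2];
  float f32[2];
  double f64[2];
};

enum class ScalarType { I32, F16, F32, F64 };

ScalarConst One(ScalarType t) {
  ScalarConst c;
  std::memset(&c, 0, sizeof(c));  // the Zero() case is just this memset
  switch (t) {
    case ScalarType::I32: c.i32[0] = 1; break;
    case ScalarType::F16: c.u16[0] = 0x3C00; break;  // 1.0 as IEEE half bits
    case ScalarType::F32: c.f32[0] = 1.0f; break;
    case ScalarType::F64: c.f64[0] = 1.0; break;
  }
  return c;
}

int main() { std::printf("%f\n", One(ScalarType::F32).f32[0]); }
```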
-// CsrMatvec: Product of CSR matrix and dense vector. - -static absl::Status CsrMatvec_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const CsrMatvecDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - void* csr_values = buffers[0]; - void* csr_col_ind = buffers[1]; - void* csr_row_offsets = buffers[2]; - void* xbuf = buffers[3]; - void* ybuf = buffers[4]; - void* buf = buffers[5]; - - // TODO(jakevdp): alpha and beta should be user-specifiable, but constants - // are sufficient for basic matvec operations. - // Note that, contrary to cusparse docs, alpha and beta must be host pointers - // or else the operation will segfault. - CudaConst alpha = CudaOne(d.y.type); - CudaConst beta = CudaZero(d.y.type); - - cusparseSpMatDescr_t mat_a = 0; - cusparseDnVecDescr_t vec_x = 0; - cusparseDnVecDescr_t vec_y = 0; - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, - csr_col_ind, csr_values, d.A.index_type, d.A.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type))); - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, - d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y))); - return absl::OkStatus(); -} - -void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// CsrMatmat: Product of CSR matrix and dense matrix. - -static absl::Status CsrMatmat_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const CsrMatmatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - void* csr_values = buffers[0]; - void* csr_col_ind = buffers[1]; - void* csr_row_offsets = buffers[2]; - void* Bbuf = buffers[3]; - void* Cbuf = buffers[4]; - void* buf = buffers[5]; - - // TODO(jakevdp): alpha and beta should be user-specifiable, but constants - // are sufficient for basic matvec operations. - // Note that, contrary to cusparse docs, alpha and beta must be host pointers - // or else the operation will segfault. 
- CudaConst alpha = CudaOne(d.C.type); - CudaConst beta = CudaZero(d.C.type); - - cusparseSpMatDescr_t mat_a = 0; - cusparseDnMatDescr_t mat_b = 0; - cusparseDnMatDescr_t mat_c = 0; - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseCreateCsr(&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, - csr_col_ind, csr_values, d.A.index_type, d.A.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_b, d.B.rows, d.B.cols, - /*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_c, d.C.rows, d.C.cols, - /*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM( - handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, - mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c))); - return absl::OkStatus(); -} - -void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// CooToDense: Convert COO matrix to dense matrix - -static absl::Status CooToDense_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - cusparseSpMatDescr_t mat_a = 0; - cusparseDnMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, - /*cooRowInd=*/buffers[1], - /*cooColInd=*/buffers[2], - /*cooValues=*/buffers[0], d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], d.value_type, CUSPARSE_ORDER_ROW))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseSparseToDense(handle.get(), mat_a, mat_b, - CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); - return absl::OkStatus(); -} - -void CooToDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// CooFromDense: Convert dense matrix to COO matrix - -static absl::Status CooFromDense_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - cusparseDnMatDescr_t mat_a = 0; - cusparseSpMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], d.value_type, CUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR( - 
JAX_AS_STATUS(cusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, - /*cooRowInd=*/buffers[2], - /*cooColInd=*/buffers[3], - /*cooValues=*/buffers[1], d.index_type, - CUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_analysis( - handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDenseToSparse_convert( - handle.get(), mat_a, mat_b, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_b))); - return absl::OkStatus(); -} - -void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// CooMatvec: Product of COO matrix and dense vector. - -static absl::Status CooMatvec_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const CooMatvecDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - void* coo_values = buffers[0]; - void* coo_row_ind = buffers[1]; - void* coo_col_ind = buffers[2]; - void* xbuf = buffers[3]; - void* ybuf = buffers[4]; - void* buf = buffers[5]; - - // TODO(jakevdp): alpha and beta should be user-specifiable, but constants - // are sufficient for basic matvec operations. - // Note that, contrary to cusparse docs, alpha and beta must be host pointers - // or else the operation will segfault. - CudaConst alpha = CudaOne(d.y.type); - CudaConst beta = CudaZero(d.y.type); - - cusparseSpMatDescr_t mat_a = 0; - cusparseDnVecDescr_t vec_x = 0; - cusparseDnVecDescr_t vec_y = 0; - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo( - &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values, - d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type))); - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, - d.y.type, CUSPARSE_MV_ALG_DEFAULT, buf))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_x))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnVec(vec_y))); - return absl::OkStatus(); -} - -void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// CooMatmat: Product of COO matrix and dense matrix. 
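The matvec/matmat wrappers above repeatedly note that `alpha` and `beta` must be host pointers even though the docs suggest otherwise: the local `CudaConst` values live on the host stack and are read through the pointer before launch. A tiny sketch of why that calling convention works; `FakeSpMV` is a stand-in signature, not the real `cusparseSpMV`:

```cpp
#include <cstdio>

// The library dereferences alpha/beta on the host, so addresses of local
// variables are exactly what it expects; a device pointer would fault.
void FakeSpMV(const float* alpha, const float* beta, const float* x, float* y,
              int n) {
  for (int i = 0; i < n; ++i) y[i] = *alpha * x[i] + *beta * y[i];
}

int main() {
  float alpha = 1.0f, beta = 0.0f;  // host locals, like CudaOne()/CudaZero()
  float x[3] = {1, 2, 3}, y[3] = {9, 9, 9};
  FakeSpMV(&alpha, &beta, x, y, 3);  // y := 1*x + 0*y
  std::printf("%g %g %g\n", y[0], y[1], y[2]);
}
```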
- -static absl::Status CooMatmat_(cudaStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const CooMatmatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - void* coo_values = buffers[0]; - void* coo_row_ind = buffers[1]; - void* coo_col_ind = buffers[2]; - void* Bbuf = buffers[3]; - void* Cbuf = buffers[4]; - void* buf = buffers[5]; - - // TODO(jakevdp): alpha and beta should be user-specifiable, but constants - // are sufficient for basic matvec operations. - // Note that, contrary to cusparse docs, alpha and beta must be host pointers - // or else the operation will segfault. - CudaConst alpha = CudaOne(d.C.type); - CudaConst beta = CudaZero(d.C.type); - - cusparseSpMatDescr_t mat_a = 0; - cusparseDnMatDescr_t mat_b = 0; - cusparseDnMatDescr_t mat_c = 0; - - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo( - &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values, - d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count, - /*batchStride=*/d.A.batch_stride))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_b, d.B.rows, d.B.cols, - /*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count, - /*batchStride=*/d.B.batch_stride))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat( - &mat_c, d.C.rows, d.C.cols, - /*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count, - /*batchStride=*/d.C.batch_stride))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM( - handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, - mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_b))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseDestroyDnMat(mat_c))); - return absl::OkStatus(); -} - -void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} -#endif // if JAX_CUSPARSE_11300 - -template -static absl::Status gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { - auto h = SparseHandlePool::Borrow(); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const Gtsv2Descriptor& descriptor = **s; - int m = descriptor.m; - int n = descriptor.n; - int ldb = descriptor.ldb; - - const T* dl = (const T*)(buffers[0]); - const T* d = (const T*)(buffers[1]); - const T* du = (const T*)(buffers[2]); - const T* B = (T*)(buffers[3]); - T* X = (T*)(buffers[4]); - void* buffer = buffers[5]; - - // The solution X is written in place to B. We need to therefore copy the - // contents of B into the output buffer X and pass that into the kernel as B. 
- // Once copy insertion is supported for custom call aliasing, we could alias B - // with X and avoid the copy, the code below is written defensively assuming B - // and X might alias, but today we know they will not. - // TODO(b/182906199): Update the comment here once copy insertion is WAI. - if (X != B) { - size_t B_bytes = ldb * n * sizeof(T); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream))); - } - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer))); - return absl::OkStatus(); -} - -void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(cusparseSgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(cusparseDgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace jax diff --git a/jaxlib/cuda/cusparse_kernels.h b/jaxlib/cuda/cusparse_kernels.h deleted file mode 100644 index 6c429cb48..000000000 --- a/jaxlib/cuda/cusparse_kernels.h +++ /dev/null @@ -1,160 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_CUSPARSE_KERNELS_H_ -#define JAXLIB_CUSPARSE_KERNELS_H_ - -#include -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "third_party/gpus/cuda/include/cuComplex.h" -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "third_party/gpus/cuda/include/cusparse.h" -#include "jaxlib/handle_pool.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" - -// Some functionality defined here is only available in CUSPARSE 11.3 or newer. -#define JAX_CUSPARSE_11300 (CUSPARSE_VERSION >= 11300) -// CUDA-11.8 introduces FP8 E4M3/E5M2 types. 
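The gtsv2 wrapper above handles the solver's in-place contract by copying `B` into the output buffer `X` and solving there, with an `if (X != B)` guard so the code stays correct if XLA ever aliases the two buffers. A CPU-only sketch of that defensive-copy shape, with a trivial stand-in for the tridiagonal solve:

```cpp
#include <cstdio>
#include <vector>

void SolveInPlace(std::vector<double>& bx) {
  for (double& v : bx) v *= 0.5;  // stand-in for cusparse<T>gtsv2
}

// Mirrors the deleted kernel: B is logically const; the solve happens in X.
void Gtsv2Like(const std::vector<double>& B, std::vector<double>& X) {
  if (X.data() != B.data()) X.assign(B.begin(), B.end());  // defensive copy
  SolveInPlace(X);  // B itself is never modified
}

int main() {
  std::vector<double> B{2.0, 4.0}, X(2);
  Gtsv2Like(B, X);
  std::printf("%f %f\n", X[0], X[1]);
}
```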
-#define JAX_CUDA_11080 (CUDA_VERSION >= 11080) - -namespace jax { - -using SparseHandlePool = HandlePool; - -template <> -/*static*/ absl::StatusOr SparseHandlePool::Borrow( - cudaStream_t stream); - -union CudaConst { - int8_t i8[2]; - int16_t i16[2]; - int32_t i32[2]; - int64_t i64[2]; - uint8_t u8[2]; - uint16_t u16[2]; - uint32_t u32[2]; - uint64_t u64[2]; - float f32[2]; - double f64[2]; -}; - -CudaConst CudaZero(cudaDataType type); -CudaConst CudaOne(cudaDataType type); - -struct SparseMatDescriptor { - cudaDataType value_type; - cusparseIndexType_t index_type; - int rows, cols, nnz; - int batch_count = 1; - int batch_stride = 0; -}; - -struct DenseMatDescriptor { - cudaDataType type; - int rows, cols; - int batch_count = 1; - int batch_stride = 0; -}; - -struct DenseVecDescriptor { - cudaDataType type; - int size; -}; - -#if JAX_CUSPARSE_11300 -// CsrToDense: Convert CSR matrix to dense matrix - -void CsrToDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrFromDense: Convert dense matrix to CSR matrix - -void CsrFromDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatvec: Product of CSR matrix and dense vector. - -struct CsrMatvecDescriptor { - SparseMatDescriptor A; - DenseVecDescriptor x, y; - cusparseOperation_t op; -}; - -void CsrMatvec(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatmat: Product of CSR matrix and dense matrix. - -struct CsrMatmatDescriptor { - SparseMatDescriptor A; - DenseMatDescriptor B, C; - cusparseOperation_t op_A; -}; - -void CsrMatmat(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooToDense: Convert COO matrix to dense matrix - -void CooToDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooFromDense: Convert dense matrix to COO matrix - -void CooFromDense(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatvec: Product of COO matrix and dense vector. - -struct CooMatvecDescriptor { - SparseMatDescriptor A; - DenseVecDescriptor x, y; - cusparseOperation_t op; -}; - -void CooMatvec(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatmat: Product of COO matrix and dense matrix. - -struct CooMatmatDescriptor { - SparseMatDescriptor A; - DenseMatDescriptor B, C; - cusparseOperation_t op_A; -}; - -void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); -#endif // if JAX_CUSPARSE_11300 - -struct Gtsv2Descriptor { - int m, n, ldb; -}; - -void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - -void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - -} // namespace jax - -#endif // JAXLIB_CUSPARSE_KERNELS_H_ diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD new file mode 100644 index 000000000..59ae37d6e --- /dev/null +++ b/jaxlib/gpu/BUILD @@ -0,0 +1,43 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Shared CUDA/ROCM GPU kernels. + +licenses(["notice"]) + +package(default_visibility = ["//:__subpackages__"]) + +exports_files(srcs = [ + "blas.cc", + "blas_kernels.cc", + "blas_kernels.h", + "gpu_kernel_helpers.cc", + "gpu_kernel_helpers.h", + "gpu_kernels.cc", + "linalg.cc", + "lu_pivot_kernels.cc", + "lu_pivot_kernels.cu.cc", + "lu_pivot_kernels.h", + "prng.cc", + "prng_kernels.cc", + "prng_kernels.cu.cc", + "prng_kernels.h", + "solver.cc", + "solver_kernels.cc", + "solver_kernels.h", + "sparse.cc", + "sparse_kernels.cc", + "sparse_kernels.h", + "vendor.h", +]) diff --git a/jaxlib/rocm/hipblas.cc b/jaxlib/gpu/blas.cc similarity index 74% rename from jaxlib/rocm/hipblas.cc rename to jaxlib/gpu/blas.cc index aa880f38c..5f868b37c 100644 --- a/jaxlib/rocm/hipblas.cc +++ b/jaxlib/gpu/blas.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The JAX Authors. +/* Copyright 2019 The JAX Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,28 +20,27 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" +#include "jaxlib/gpu/blas_kernels.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_pybind11_helpers.h" #include "include/pybind11/numpy.h" #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" -#include "jaxlib/rocm/hipblas_kernels.h" -#include "jaxlib/kernel_pybind11_helpers.h" -#include "rocm/include/hip/hip_runtime_api.h" -#include "rocm/include/hipblas.h" namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { namespace py = pybind11; // Converts a NumPy dtype to a Type. -HipblasType DtypeToHipblasType(const py::dtype& np_type) { - static auto* types = - new absl::flat_hash_map, HipblasType>({ - {{'f', 4}, HipblasType::F32}, - {{'f', 8}, HipblasType::F64}, - {{'c', 8}, HipblasType::C64}, - {{'c', 16}, HipblasType::C128}, - }); +BlasType DtypeToBlasType(const py::dtype& np_type) { + static auto* types = new absl::flat_hash_map, BlasType>({ + {{'f', 4}, BlasType::F32}, + {{'f', 8}, BlasType::F64}, + {{'c', 8}, BlasType::C64}, + {{'c', 16}, BlasType::C128}, + }); auto it = types->find({np_type.kind(), np_type.itemsize()}); if (it == types->end()) { throw std::invalid_argument( @@ -53,7 +52,7 @@ HipblasType DtypeToHipblasType(const py::dtype& np_type) { // Returns the descriptor for a GetrfBatched operation. std::pair BuildGetrfBatchedDescriptor(const py::dtype& dtype, int b, int n) { - HipblasType type = DtypeToHipblasType(dtype); + BlasType type = DtypeToBlasType(dtype); size_t size = b * sizeof(void*); return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})}; } @@ -61,23 +60,24 @@ std::pair BuildGetrfBatchedDescriptor(const py::dtype& dtype, // Returns the descriptor for a GetrfBatched operation. 
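`DtypeToBlasType` above keys a hash map on the NumPy `(kind, itemsize)` pair and throws on anything unsupported. A self-contained sketch of the same lookup, with `std::map` standing in for `absl::flat_hash_map` and a plain pair standing in for `py::dtype`:

```cpp
#include <map>
#include <stdexcept>
#include <utility>

enum class BlasType { F32, F64, C64, C128 };

BlasType DtypeToBlasTypeSketch(char kind, int itemsize) {
  static const std::map<std::pair<char, int>, BlasType> types = {
      {{'f', 4}, BlasType::F32},
      {{'f', 8}, BlasType::F64},
      {{'c', 8}, BlasType::C64},
      {{'c', 16}, BlasType::C128},
  };
  auto it = types.find({kind, itemsize});
  if (it == types.end()) throw std::invalid_argument("unsupported dtype");
  return it->second;
}

int main() { return DtypeToBlasTypeSketch('f', 4) == BlasType::F32 ? 0 : 1; }
```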
std::pair BuildGeqrfBatchedDescriptor(const py::dtype& dtype, int b, int m, int n) { - HipblasType type = DtypeToHipblasType(dtype); + BlasType type = DtypeToBlasType(dtype); size_t size = b * sizeof(void*); return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})}; } py::dict Registrations() { py::dict dict; - dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched); - dict["hipblas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); + dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); + dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); return dict; } -PYBIND11_MODULE(_hipblas, m) { +PYBIND11_MODULE(_blas, m) { m.def("registrations", &Registrations); m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); } } // namespace +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/hipblas_kernels.cc b/jaxlib/gpu/blas_kernels.cc similarity index 65% rename from jaxlib/rocm/hipblas_kernels.cc rename to jaxlib/gpu/blas_kernels.cc index 10ab69500..551c9112c 100644 --- a/jaxlib/rocm/hipblas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The JAX Authors. +/* Copyright 2019 The JAX Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,60 +13,61 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/rocm/hipblas_kernels.h" +#include "jaxlib/gpu/blas_kernels.h" #include #include #include #include -#include "rocm/include/hip/hip_runtime_api.h" -#include "rocm/include/hipblas.h" #include "absl/base/casts.h" #include "absl/base/thread_annotations.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" -#include "jaxlib/rocm/hip_gpu_kernel_helpers.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { -using BlasHandlePool = HandlePool; +using BlasHandlePool = HandlePool; template <> /*static*/ absl::StatusOr BlasHandlePool::Borrow( - hipStream_t stream) { + gpuStream_t stream) { BlasHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); - hipblasHandle_t handle; + gpublasHandle_t handle; if (pool->handles_[stream].empty()) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCreate(&handle))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCreate(&handle))); } else { handle = pool->handles_[stream].back(); pool->handles_[stream].pop_back(); } if (stream) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSetStream(handle, stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream))); } return Handle(pool, handle, stream); } +namespace JAX_GPU_NAMESPACE { + namespace { -// Converts a NumPy dtype to a CublasType. +// Converts a NumPy dtype to a BlasType. 
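The `BlasHandlePool::Borrow` specialization above caches library handles per stream under a mutex, creating one only when the free list is empty. A toy version of that pool; `Stream` and `Handle` are stand-ins for `gpuStream_t`/`gpublasHandle_t`, and the explicit `Return` replaces the RAII handle wrapper the real `HandlePool` uses:

```cpp
#include <cstdio>
#include <map>
#include <mutex>
#include <vector>

using Stream = int;
using Handle = int;

class HandlePool {
 public:
  Handle Borrow(Stream stream) {
    std::lock_guard<std::mutex> lock(mu_);
    auto& free_list = handles_[stream];
    if (free_list.empty()) return next_handle_++;  // gpublasCreate analogue
    Handle h = free_list.back();
    free_list.pop_back();
    return h;
  }
  void Return(Stream stream, Handle h) {
    std::lock_guard<std::mutex> lock(mu_);
    handles_[stream].push_back(h);
  }

 private:
  std::mutex mu_;
  std::map<Stream, std::vector<Handle>> handles_;
  Handle next_handle_ = 0;
};

int main() {
  HandlePool pool;
  Handle h = pool.Borrow(/*stream=*/1);
  pool.Return(1, h);
  std::printf("reused: %d\n", pool.Borrow(1) == h);  // prints 1
}
```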
-int SizeOfHipblasType(HipblasType type) { +int SizeOfBlasType(BlasType type) { switch (type) { - case HipblasType::F32: + case BlasType::F32: return sizeof(float); - case HipblasType::F64: + case BlasType::F64: return sizeof(double); - case HipblasType::C64: - return sizeof(hipComplex); - case HipblasType::C128: - return sizeof(hipDoubleComplex); + case BlasType::C64: + return sizeof(gpublasComplex); + case BlasType::C128: + return sizeof(gpublasDoubleComplex); } } @@ -74,7 +75,7 @@ int SizeOfHipblasType(HipblasType type) { // Batched LU decomposition: getrfbatched -static absl::Status GetrfBatched_(hipStream_t stream, void** buffers, +static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -83,43 +84,43 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync( - buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.n * d.n, - hipMemcpyDeviceToDevice, stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( + buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n, + gpuMemcpyDeviceToDevice, stream))); } int* ipiv = static_cast(buffers[2]); int* info = static_cast(buffers[3]); auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch, - SizeOfHipblasType(d.type) * d.n * d.n); + SizeOfBlasType(d.type) * d.n * d.n); JAX_RETURN_IF_ERROR(a_ptrs_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); switch (d.type) { - case HipblasType::F32: { + case BlasType::F32: { float** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSgetrfBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched( handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } - case HipblasType::F64: { + case BlasType::F64: { double** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDgetrfBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched( handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } - case HipblasType::C64: { - hipblasComplex** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCgetrfBatched( + case BlasType::C64: { + gpublasComplex** batch_ptrs = static_cast(buffers[4]); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched( handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } - case HipblasType::C128: { - hipblasDoubleComplex** batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZgetrfBatched( + case BlasType::C128: { + gpublasDoubleComplex** batch_ptrs = + static_cast(buffers[4]); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched( handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); break; } @@ -127,7 +128,7 @@ static absl::Status GetrfBatched_(hipStream_t stream, void** buffers, return absl::OkStatus(); } -void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque, +void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = 
GetrfBatched_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -138,7 +139,7 @@ void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque, // Batched QR decomposition: geqrfbatched -static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers, +static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -147,56 +148,56 @@ static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync( - buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.m * d.n, - hipMemcpyDeviceToDevice, stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( + buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n, + gpuMemcpyDeviceToDevice, stream))); } std::vector info(d.batch); auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch, - SizeOfHipblasType(d.type) * d.m * d.n); + SizeOfBlasType(d.type) * d.m * d.n); JAX_RETURN_IF_ERROR(a_ptrs_host.status()); auto tau_ptrs_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, - SizeOfHipblasType(d.type) * std::min(d.m, d.n)); + SizeOfBlasType(d.type) * std::min(d.m, d.n)); JAX_RETURN_IF_ERROR(tau_ptrs_host.status()); // TODO(phawkins): ideally we would not need to synchronize here, but to // avoid it we need a way to keep the host-side buffer alive until the copy // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); switch (d.type) { - case HipblasType::F32: { + case BlasType::F32: { float** a_batch_ptrs = static_cast(buffers[3]); float** tau_batch_ptrs = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipblasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, + gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, tau_batch_ptrs, info.data(), d.batch))); break; } - case HipblasType::F64: { + case BlasType::F64: { double** a_batch_ptrs = static_cast(buffers[3]); double** tau_batch_ptrs = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipblasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, + gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, tau_batch_ptrs, info.data(), d.batch))); break; } - case HipblasType::C64: { - hipblasComplex** a_batch_ptrs = static_cast(buffers[3]); - hipblasComplex** tau_batch_ptrs = - static_cast(buffers[4]); + case BlasType::C64: { + gpublasComplex** a_batch_ptrs = static_cast(buffers[3]); + gpublasComplex** tau_batch_ptrs = + static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipblasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, + gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, tau_batch_ptrs, info.data(), d.batch))); break; } - case HipblasType::C128: { - hipblasDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - hipblasDoubleComplex** tau_batch_ptrs = - static_cast(buffers[4]); + case BlasType::C128: { + gpublasDoubleComplex** a_batch_ptrs = + static_cast(buffers[3]); + gpublasDoubleComplex** tau_batch_ptrs = + static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipblasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, + gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, tau_batch_ptrs, info.data(), d.batch))); break; } @@ -214,7 +215,7 @@ 
static absl::Status GeqrfBatched_(hipStream_t stream, void** buffers, return absl::OkStatus(); } -void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque, +void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -223,4 +224,5 @@ void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque, } } +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cublas_kernels.h b/jaxlib/gpu/blas_kernels.h similarity index 70% rename from jaxlib/cuda/cublas_kernels.h rename to jaxlib/gpu/blas_kernels.h index 218f950d3..8fc4daba1 100644 --- a/jaxlib/cuda/cublas_kernels.h +++ b/jaxlib/gpu/blas_kernels.h @@ -13,20 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_CUBLAS_KERNELS_H_ -#define JAXLIB_CUBLAS_KERNELS_H_ +#ifndef JAXLIB_GPU_BLAS_KERNELS_H_ +#define JAXLIB_GPU_BLAS_KERNELS_H_ #include -#include "third_party/gpus/cuda/include/cublas_v2.h" -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "jaxlib/gpu/vendor.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { +namespace JAX_GPU_NAMESPACE { // Set of types known to Cusolver. -enum class CublasType { +enum class BlasType { F32, F64, C64, @@ -36,25 +35,24 @@ enum class CublasType { // Batched LU decomposition: getrfbatched struct GetrfBatchedDescriptor { - CublasType type; + BlasType type; int batch, n; }; -void GetrfBatched(cudaStream_t stream, void** buffers, const char* opaque, +void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); - // Batched QR decomposition: geqrfbatched struct GeqrfBatchedDescriptor { - CublasType type; + BlasType type; int batch, m, n; }; -void GeqrfBatched(cudaStream_t stream, void** buffers, const char* opaque, +void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); - +} // namespace JAX_GPU_NAMESPACE } // namespace jax -#endif // JAXLIB_CUBLAS_KERNELS_H_ +#endif // JAXLIB_GPU_BLAS_KERNELS_H_ diff --git a/jaxlib/rocm/hip_gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc similarity index 63% rename from jaxlib/rocm/hip_gpu_kernel_helpers.cc rename to jaxlib/gpu/gpu_kernel_helpers.cc index 722b9b584..1a11ac632 100644 --- a/jaxlib/rocm/hip_gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The JAX Authors. +/* Copyright 2019 The JAX Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,83 @@ See the License for the specific language governing permissions and limitations under the License. 
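Both batched paths above call `MakeBatchPointers` and then synchronize the stream; the sync exists, per the TODO comments, only to keep the host-side pointer table alive until the async copy lands. A host-only sketch of what the helper computes, with hypothetical naming:

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// One contiguous buffer packing `batch` equally sized matrices becomes the
// per-matrix pointer array that *getrfBatched-style APIs expect. The real
// helper then gpuMemcpyAsync's this table to the device, which is why the
// host allocation is returned to the caller instead of freed here.
std::unique_ptr<char*[]> MakeBatchPointersSketch(void* buffer, int batch,
                                                 int batch_elem_size) {
  char* ptr = static_cast<char*>(buffer);
  auto host_ptrs = std::make_unique<char*[]>(batch);
  for (int i = 0; i < batch; ++i) {
    host_ptrs[i] = ptr;
    ptr += batch_elem_size;
  }
  return host_ptrs;
}

int main() {
  std::vector<float> data(2 * 3 * 3);  // two packed 3x3 matrices
  auto ptrs = MakeBatchPointersSketch(data.data(), /*batch=*/2,
                                      /*batch_elem_size=*/3 * 3 * sizeof(float));
  std::printf("stride between matrices: %td bytes\n", ptrs[1] - ptrs[0]);
}
```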
==============================================================================*/ -#include "jaxlib/rocm/hip_gpu_kernel_helpers.h" - -#include +#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" namespace jax { +namespace JAX_GPU_NAMESPACE { + namespace { -std::string ErrorString(hipError_t error) { return hipGetErrorString(error); } +std::string ErrorString(gpuError_t error) { return gpuGetErrorString(error); } + +#ifdef JAX_GPU_CUDA + +std::string ErrorString(gpusparseStatus_t status) { + return cusparseGetErrorString(status); +} + +std::string ErrorString(gpusolverStatus_t status) { + switch (status) { + case CUSOLVER_STATUS_SUCCESS: + return "cuSolver success."; + case CUSOLVER_STATUS_NOT_INITIALIZED: + return "cuSolver has not been initialized"; + case CUSOLVER_STATUS_ALLOC_FAILED: + return "cuSolver allocation failed"; + case CUSOLVER_STATUS_INVALID_VALUE: + return "cuSolver invalid value error"; + case CUSOLVER_STATUS_ARCH_MISMATCH: + return "cuSolver architecture mismatch error"; + case CUSOLVER_STATUS_MAPPING_ERROR: + return "cuSolver mapping error"; + case CUSOLVER_STATUS_EXECUTION_FAILED: + return "cuSolver execution failed"; + case CUSOLVER_STATUS_INTERNAL_ERROR: + return "cuSolver internal error"; + case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "cuSolver matrix type not supported error"; + case CUSOLVER_STATUS_NOT_SUPPORTED: + return "cuSolver not supported error"; + case CUSOLVER_STATUS_ZERO_PIVOT: + return "cuSolver zero pivot error"; + case CUSOLVER_STATUS_INVALID_LICENSE: + return "cuSolver invalid license error"; + default: + return absl::StrCat("Unknown cuSolver error: ", status); + } +} + +std::string ErrorString(gpublasStatus_t status) { + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "cuBlas success"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "cuBlas has not been initialized"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "cuBlas allocation failure"; + case CUBLAS_STATUS_INVALID_VALUE: + return "cuBlas invalid value error"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "cuBlas architecture mismatch"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "cuBlas mapping error"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "cuBlas execution failed"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "cuBlas internal error"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "cuBlas not supported error"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "cuBlas license error"; + default: + return "Unknown cuBlas error"; + } +} + +#else std::string ErrorString(hipsparseStatus_t status) { // TODO(reza): check and see if we can use hipify @@ -115,6 +181,8 @@ std::string ErrorString(hipblasStatus_t status) { } } +#endif + template std::string ErrorString(T status, const char* file, std::int64_t line, const char* expr) { @@ -123,37 +191,37 @@ std::string ErrorString(T status, const char* file, std::int64_t line, } } // namespace -absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line, +absl::Status AsStatus(gpuError_t error, const char* file, std::int64_t line, const char* expr) { - if (error != hipSuccess) + if (error != gpuSuccess) return absl::InternalError(ErrorString(error, file, line, expr)); return absl::OkStatus(); } -absl::Status AsStatus(hipsolverStatus_t status, const char* file, +absl::Status AsStatus(gpusolverStatus_t status, const char* file, std::int64_t line, const char* expr) { - if (status != HIPSOLVER_STATUS_SUCCESS) + if (status != 
GPUSOLVER_STATUS_SUCCESS) return absl::InternalError(ErrorString(status, file, line, expr)); return absl::OkStatus(); } -absl::Status AsStatus(hipsparseStatus_t status, const char* file, +absl::Status AsStatus(gpusparseStatus_t status, const char* file, std::int64_t line, const char* expr) { - if (status != HIPSPARSE_STATUS_SUCCESS) + if (status != GPUSPARSE_STATUS_SUCCESS) return absl::InternalError(ErrorString(status, file, line, expr)); return absl::OkStatus(); } -absl::Status AsStatus(hipblasStatus_t status, const char* file, +absl::Status AsStatus(gpublasStatus_t status, const char* file, std::int64_t line, const char* expr) { - if (status != HIPBLAS_STATUS_SUCCESS) + if (status != GPUBLAS_STATUS_SUCCESS) return absl::InternalError(ErrorString(status, file, line, expr)); return absl::OkStatus(); } -absl::StatusOr> -MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch, - int batch_elem_size) { +absl::StatusOr> MakeBatchPointers( + gpuStream_t stream, void* buffer, void* dev_ptrs, int batch, + int batch_elem_size) { char* ptr = static_cast(buffer); auto host_ptrs = absl::make_unique(batch); for (int i = 0; i < batch; ++i) { @@ -161,8 +229,10 @@ MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch, ptr += batch_elem_size; } JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, - hipMemcpyHostToDevice, stream))); + gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, + gpuMemcpyHostToDevice, stream))); return std::move(host_ptrs); } + +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cuda_gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h similarity index 72% rename from jaxlib/cuda/cuda_gpu_kernel_helpers.h rename to jaxlib/gpu/gpu_kernel_helpers.h index 927d1ea7f..603b1ca3f 100644 --- a/jaxlib/cuda/cuda_gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -13,19 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_ -#define JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_ +#ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ +#define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "third_party/gpus/cuda/include/cublas_v2.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "third_party/gpus/cuda/include/cusolverDn.h" -#include "third_party/gpus/cuda/include/cusparse.h" +#include "jaxlib/gpu/vendor.h" -#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr) +#define JAX_AS_STATUS(expr) \ + jax::JAX_GPU_NAMESPACE::AsStatus(expr, __FILE__, __LINE__, #expr) #define JAX_THROW_IF_ERROR(expr) \ { \ @@ -40,27 +38,29 @@ limitations under the License. } namespace jax { +namespace JAX_GPU_NAMESPACE { // Used via JAX_AS_STATUS(expr) macro. 
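The `AsStatus` overloads above all funnel through the `JAX_AS_STATUS(expr)` macro, which stringizes the checked expression and forwards `__FILE__`/`__LINE__` so a failing GPU call reports exactly where it happened. A rough, compilable model of that machinery; `Status` and `GpuError` are stand-ins for `absl::Status` and `gpuError_t`:

```cpp
#include <cstdio>
#include <string>

struct Status {
  bool ok;
  std::string message;
};

enum GpuError { kGpuSuccess = 0, kGpuInvalidValue = 1 };

Status AsStatus(GpuError error, const char* file, int line, const char* expr) {
  if (error == kGpuSuccess) return {true, ""};
  return {false, std::string(file) + ":" + std::to_string(line) +
                     ": operation " + expr + " failed"};
}

// Stringize the expression and capture the call site, like JAX_AS_STATUS.
#define AS_STATUS(expr) AsStatus(expr, __FILE__, __LINE__, #expr)

int main() {
  Status s = AS_STATUS(kGpuInvalidValue);
  if (!s.ok) std::printf("%s\n", s.message.c_str());
}
```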
-absl::Status AsStatus(cudaError_t error, const char* file, std::int64_t line, +absl::Status AsStatus(gpuError_t error, const char* file, std::int64_t line, const char* expr); -absl::Status AsStatus(cusolverStatus_t status, const char* file, +absl::Status AsStatus(gpusolverStatus_t status, const char* file, std::int64_t line, const char* expr); -absl::Status AsStatus(cusparseStatus_t status, const char* file, +absl::Status AsStatus(gpusparseStatus_t status, const char* file, std::int64_t line, const char* expr); -absl::Status AsStatus(cublasStatus_t status, const char* file, +absl::Status AsStatus(gpublasStatus_t status, const char* file, std::int64_t line, const char* expr); // Builds an array of pointers to each array in a batch, in device memory. // Caution: the return value must be kept alive (e.g., via a stream // synchronization) until the copy enqueued by MakeBatchPointers on `stream` // completes. -absl::StatusOr> MakeBatchPointers(cudaStream_t stream, +absl::StatusOr> MakeBatchPointers(gpuStream_t stream, void* buffer, void* dev_ptrs, int batch, int batch_elem_size); +} // namespace JAX_GPU_NAMESPACE } // namespace jax -#endif // JAXLIB_CUDA_GPU_KERNEL_HELPERS_H_ +#endif // JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ diff --git a/jaxlib/cuda/cuda_gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc similarity index 85% rename from jaxlib/cuda/cuda_gpu_kernels.cc rename to jaxlib/gpu/gpu_kernels.cc index c3ef2ebcd..4d6fb9e33 100644 --- a/jaxlib/cuda/cuda_gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -16,21 +16,23 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. -#include "jaxlib/cuda/cublas_kernels.h" -#include "jaxlib/cuda/cuda_lu_pivot_kernels.h" -#include "jaxlib/cuda/cuda_prng_kernels.h" -#include "jaxlib/cuda/cusolver_kernels.h" -#include "jaxlib/cuda/cusparse_kernels.h" +#include "jaxlib/gpu/blas_kernels.h" +#include "jaxlib/gpu/lu_pivot_kernels.h" +#include "jaxlib/gpu/prng_kernels.h" +#include "jaxlib/gpu/solver_kernels.h" +#include "jaxlib/gpu/sparse_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_lu_pivots_to_permutation", - CudaLuPivotsToPermutation, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cuda_threefry2x32", CudaThreeFry2x32, +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_lu_pivots_to_permutation", + LuPivotsToPermutation, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_potrf", Potrf, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); @@ -66,4 +68,5 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64, "CUDA"); } // namespace +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cuda_linalg.cc b/jaxlib/gpu/linalg.cc similarity index 77% rename from jaxlib/cuda/cuda_linalg.cc rename to jaxlib/gpu/linalg.cc index b798c5bbe..a03a9c4de 100644 --- a/jaxlib/cuda/cuda_linalg.cc +++ b/jaxlib/gpu/linalg.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ - -#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" -#include "jaxlib/cuda/cuda_lu_pivot_kernels.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/lu_pivot_kernels.h" #include "jaxlib/kernel_pybind11_helpers.h" #include "include/pybind11/pybind11.h" namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { -std::string BuildCudaLuPivotsToPermutationDescriptor( +std::string BuildLuPivotsToPermutationDescriptor( std::int64_t batch_size, std::int32_t pivot_size, std::int32_t permutation_size) { return PackDescriptorAsString(LuPivotsToPermutationDescriptor{ @@ -31,21 +31,22 @@ std::string BuildCudaLuPivotsToPermutationDescriptor( pybind11::dict Registrations() { pybind11::dict dict; - dict["cuda_lu_pivots_to_permutation"] = - EncapsulateFunction(CudaLuPivotsToPermutation); + dict[JAX_GPU_PREFIX "_lu_pivots_to_permutation"] = + EncapsulateFunction(LuPivotsToPermutation); return dict; } -PYBIND11_MODULE(_cuda_linalg, m) { +PYBIND11_MODULE(_linalg, m) { m.def("registrations", &Registrations); m.def("lu_pivots_to_permutation_descriptor", [](std::int64_t batch_size, std::int32_t pivot_size, std::int32_t permutation_size) { - std::string result = BuildCudaLuPivotsToPermutationDescriptor( + std::string result = BuildLuPivotsToPermutationDescriptor( batch_size, pivot_size, permutation_size); return pybind11::bytes(result); }); } } // namespace +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/hip_lu_pivot_kernels.cc b/jaxlib/gpu/lu_pivot_kernels.cc similarity index 63% rename from jaxlib/rocm/hip_lu_pivot_kernels.cc rename to jaxlib/gpu/lu_pivot_kernels.cc index 00f642c14..407dd90d6 100644 --- a/jaxlib/rocm/hip_lu_pivot_kernels.cc +++ b/jaxlib/gpu/lu_pivot_kernels.cc @@ -13,38 +13,41 @@ See the License for the specific language governing permissions and limitations under the License. 
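The linalg module above ships its kernel parameters through `PackDescriptorAsString`, and the kernel side recovers them with `UnpackDescriptor` from the opaque byte string. A simplified stand-in for that round trip (the real helpers live in jaxlib/kernel_helpers.h and carry more error checking than this sketch):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>

struct LuPivotsToPermutationDescriptor {
  std::int64_t batch_size;
  std::int32_t pivot_size;
  std::int32_t permutation_size;
};

// POD descriptor -> opaque bytes handed to the custom call.
std::string PackDescriptor(const LuPivotsToPermutationDescriptor& d) {
  return std::string(reinterpret_cast<const char*>(&d), sizeof(d));
}

// Opaque bytes -> descriptor; memcpy sidesteps alignment pitfalls.
bool UnpackDescriptor(const char* opaque, std::size_t len,
                      LuPivotsToPermutationDescriptor* out) {
  if (len != sizeof(*out)) return false;
  std::memcpy(out, opaque, sizeof(*out));
  return true;
}

int main() {
  std::string opaque = PackDescriptor({8, 4, 4});
  LuPivotsToPermutationDescriptor d;
  if (UnpackDescriptor(opaque.data(), opaque.size(), &d))
    std::printf("batch=%lld pivots=%d\n",
                static_cast<long long>(d.batch_size), d.pivot_size);
}
```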
==============================================================================*/ -#include "jaxlib/rocm/hip_lu_pivot_kernels.h" +#include "jaxlib/gpu/lu_pivot_kernels.h" #include -#include "jaxlib/rocm/hip_gpu_kernel_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { -absl::Status HipLuPivotsToPermutation_(hipStream_t stream, void** buffers, - const char* opaque, - std::size_t opaque_len) { +absl::Status LuPivotsToPermutation_(gpuStream_t stream, void** buffers, + const char* opaque, + std::size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); LaunchLuPivotsToPermutationKernel(stream, buffers, **s); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError())); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); return absl::OkStatus(); } } // namespace -void HipLuPivotsToPermutation(hipStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status) { - auto s = HipLuPivotsToPermutation_(stream, buffers, opaque, opaque_len); +void LuPivotsToPermutation(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + auto s = LuPivotsToPermutation_(stream, buffers, opaque, opaque_len); if (!s.ok()) { std::string_view message = s.message(); XlaCustomCallStatusSetFailure(status, message.data(), message.length()); } } +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/hip_lu_pivot_kernels.hip.cc b/jaxlib/gpu/lu_pivot_kernels.cu.cc similarity index 94% rename from jaxlib/rocm/hip_lu_pivot_kernels.hip.cc rename to jaxlib/gpu/lu_pivot_kernels.cu.cc index cf171ee15..94afa78de 100644 --- a/jaxlib/rocm/hip_lu_pivot_kernels.hip.cc +++ b/jaxlib/gpu/lu_pivot_kernels.cu.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/rocm/hip_lu_pivot_kernels.h" +#include "jaxlib/gpu/lu_pivot_kernels.h" #include #include +#include "jaxlib/gpu/vendor.h" + namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { __device__ void ComputePermutation(const std::int32_t* pivots, @@ -58,7 +61,7 @@ __global__ void LuPivotsToPermutationKernel( } // namespace void LaunchLuPivotsToPermutationKernel( - hipStream_t stream, void** buffers, + gpuStream_t stream, void** buffers, LuPivotsToPermutationDescriptor descriptor) { const std::int32_t* pivots = reinterpret_cast(buffers[0]); @@ -74,4 +77,5 @@ void LaunchLuPivotsToPermutationKernel( descriptor.permutation_size); } +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/hip_lu_pivot_kernels.h b/jaxlib/gpu/lu_pivot_kernels.h similarity index 69% rename from jaxlib/rocm/hip_lu_pivot_kernels.h rename to jaxlib/gpu/lu_pivot_kernels.h index c75dd7523..6eae513ae 100644 --- a/jaxlib/rocm/hip_lu_pivot_kernels.h +++ b/jaxlib/gpu/lu_pivot_kernels.h @@ -13,16 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. 
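The `ComputePermutation` device routine in the renamed .cu.cc above expands LAPACK-style pivots into an explicit permutation: start from the identity and replay each row swap. A CPU reference of that logic, which should help when sanity-checking the kernel:

```cpp
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// pivots[i] says row i was swapped with row pivots[i] during LU; replaying
// the swaps against an iota vector yields the permutation.
std::vector<int> PivotsToPermutation(const std::vector<int>& pivots,
                                     int permutation_size) {
  std::vector<int> permutation(permutation_size);
  std::iota(permutation.begin(), permutation.end(), 0);
  for (int i = 0; i < static_cast<int>(pivots.size()); ++i) {
    std::swap(permutation[i], permutation[pivots[i]]);
  }
  return permutation;
}

int main() {
  for (int p : PivotsToPermutation({2, 2, 2}, 4)) std::printf("%d ", p);
  std::printf("\n");  // prints: 2 0 1 3
}
```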
==============================================================================*/ -#ifndef JAXLIB_HIP_LU_PIVOT_KERNELS_H_ -#define JAXLIB_HIP_LU_PIVOT_KERNELS_H_ +#ifndef JAXLIB_GPU_LU_PIVOT_KERNELS_H_ +#define JAXLIB_GPU_LU_PIVOT_KERNELS_H_ #include #include -#include "rocm/include/hip/hip_runtime_api.h" +#include "jaxlib/gpu/vendor.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { +namespace JAX_GPU_NAMESPACE { struct LuPivotsToPermutationDescriptor { std::int64_t batch_size; @@ -31,13 +32,14 @@ struct LuPivotsToPermutationDescriptor { }; void LaunchLuPivotsToPermutationKernel( - hipStream_t stream, void** buffers, + gpuStream_t stream, void** buffers, LuPivotsToPermutationDescriptor descriptor); -void HipLuPivotsToPermutation(hipStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status); +void LuPivotsToPermutation(gpuStream_t stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status); +} // namespace JAX_GPU_NAMESPACE } // namespace jax -#endif // JAXLIB_HIP_LU_PIVOT_KERNELS_H_ \ No newline at end of file +#endif // JAXLIB_GPU_LU_PIVOT_KERNELS_H_ diff --git a/jaxlib/cuda/cuda_prng.cc b/jaxlib/gpu/prng.cc similarity index 74% rename from jaxlib/cuda/cuda_prng.cc rename to jaxlib/gpu/prng.cc index fd9ca1246..b0d548244 100644 --- a/jaxlib/cuda/cuda_prng.cc +++ b/jaxlib/gpu/prng.cc @@ -13,31 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/cuda_prng_kernels.h" - -#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/kernel_pybind11_helpers.h" #include "include/pybind11/pybind11.h" namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { -std::string BuildCudaThreeFry2x32Descriptor(std::int64_t n) { +std::string BuildThreeFry2x32Descriptor(std::int64_t n) { return PackDescriptorAsString(ThreeFry2x32Descriptor{n}); } pybind11::dict Registrations() { pybind11::dict dict; - dict["cuda_threefry2x32"] = EncapsulateFunction(CudaThreeFry2x32); + dict[JAX_GPU_PREFIX "_threefry2x32"] = EncapsulateFunction(ThreeFry2x32); return dict; } -PYBIND11_MODULE(_cuda_prng, m) { +PYBIND11_MODULE(_prng, m) { m.def("registrations", &Registrations); m.def("threefry2x32_descriptor", [](std::int64_t n) { - std::string result = BuildCudaThreeFry2x32Descriptor(n); + std::string result = BuildThreeFry2x32Descriptor(n); return pybind11::bytes(result); }); } } // namespace +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cuda_prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc similarity index 68% rename from jaxlib/cuda/cuda_prng_kernels.cc rename to jaxlib/gpu/prng_kernels.cc index a004b2a20..6101f41de 100644 --- a/jaxlib/cuda/cuda_prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -13,35 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "jaxlib/cuda/cuda_prng_kernels.h" +#include "jaxlib/gpu/prng_kernels.h" #include -#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/kernel_helpers.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { -absl::Status CudaThreeFry2x32_(cudaStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { +absl::Status ThreeFry2x32_(gpuStream_t stream, void** buffers, + const char* opaque, std::size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); LaunchThreeFry2x32Kernel(stream, buffers, **s); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaGetLastError())); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); return absl::OkStatus(); } } // namespace -void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CudaThreeFry2x32_(stream, buffers, opaque, opaque_len); +void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + auto s = ThreeFry2x32_(stream, buffers, opaque, opaque_len); if (!s.ok()) { std::string_view message = s.message(); XlaCustomCallStatusSetFailure(status, message.data(), message.length()); } } +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cuda_prng_kernels.cu.cc b/jaxlib/gpu/prng_kernels.cu.cc similarity index 95% rename from jaxlib/cuda/cuda_prng_kernels.cu.cc rename to jaxlib/gpu/prng_kernels.cu.cc index 6eb87fe2f..37555bb78 100644 --- a/jaxlib/cuda/cuda_prng_kernels.cu.cc +++ b/jaxlib/gpu/prng_kernels.cu.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/cuda_prng_kernels.h" +#include "jaxlib/gpu/prng_kernels.h" #include #include +#include "jaxlib/gpu/vendor.h" + namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0, @@ -96,7 +99,7 @@ __global__ void ThreeFry2x32Kernel(const std::uint32_t* key0, } // namespace -void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers, +void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers, ThreeFry2x32Descriptor descriptor) { std::array keys; keys[0] = reinterpret_cast(buffers[0]); @@ -115,4 +118,5 @@ void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers, out[1], descriptor.n); } +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cuda_prng_kernels.h b/jaxlib/gpu/prng_kernels.h similarity index 68% rename from jaxlib/cuda/cuda_prng_kernels.h rename to jaxlib/gpu/prng_kernels.h index 035f1f05f..a9ea8563e 100644 --- a/jaxlib/cuda/cuda_prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -13,27 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. 
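The `ThreeFry2x32Kernel` body is likewise elided by the rename. The transformation it applies per output element is the standard Threefry-2x32 block cipher with 20 rounds; the rotation constants and key schedule below come from the public Random123 description, not from this diff, so treat this as a reference sketch.

```c++
#include <array>
#include <cstdint>

// Reference sketch of Threefry-2x32 (20 rounds). Shown only to illustrate
// what the kernel computes for each counter pair; not code from the diff.
std::array<std::uint32_t, 2> ThreeFry2x32Reference(
    std::array<std::uint32_t, 2> key, std::array<std::uint32_t, 2> ctr) {
  auto rotl = [](std::uint32_t v, std::uint32_t d) {
    return (v << d) | (v >> (32 - d));
  };
  // Rotation distances, used four at a time, alternating between the halves.
  static constexpr std::uint32_t kRot[8] = {13, 15, 26, 6, 17, 29, 16, 24};
  // Key schedule: the third word is the key words xor'd with a parity constant.
  std::uint32_t ks[3] = {key[0], key[1], key[0] ^ key[1] ^ 0x1BD11BDAu};
  std::uint32_t x0 = ctr[0] + ks[0];
  std::uint32_t x1 = ctr[1] + ks[1];
  for (int group = 0; group < 5; ++group) {  // 5 groups of 4 rounds = 20
    for (int r = 0; r < 4; ++r) {
      x0 += x1;
      x1 = rotl(x1, kRot[4 * (group % 2) + r]);
      x1 ^= x0;
    }
    x0 += ks[(group + 1) % 3];  // key injection after every four rounds
    x1 += ks[(group + 2) % 3] + static_cast<std::uint32_t>(group) + 1;
  }
  return {x0, x1};
}
```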
==============================================================================*/ -#ifndef JAXLIB_CUDA_PRNG_KERNELS_H_ -#define JAXLIB_CUDA_PRNG_KERNELS_H_ +#ifndef JAXLIB_GPU_PRNG_KERNELS_H_ +#define JAXLIB_GPU_PRNG_KERNELS_H_ #include <cstddef> #include <cstdint> -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "jaxlib/gpu/vendor.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { +namespace JAX_GPU_NAMESPACE { struct ThreeFry2x32Descriptor { std::int64_t n; }; -void LaunchThreeFry2x32Kernel(cudaStream_t stream, void** buffers, +void LaunchThreeFry2x32Kernel(gpuStream_t stream, void** buffers, ThreeFry2x32Descriptor descriptor); -void CudaThreeFry2x32(cudaStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); +void ThreeFry2x32(gpuStream_t stream, void** buffers, const char* opaque, + size_t opaque_len, XlaCustomCallStatus* status); +} // namespace JAX_GPU_NAMESPACE } // namespace jax -#endif // JAXLIB_CUDA_PRNG_KERNELS_H_ +#endif // JAXLIB_GPU_PRNG_KERNELS_H_ diff --git a/jaxlib/cuda/cusolver.cc b/jaxlib/gpu/solver.cc similarity index 52% rename from jaxlib/cuda/cusolver.cc rename to jaxlib/gpu/solver.cc index 7bb2a8ed5..ec2bd3e3d 100644 --- a/jaxlib/cuda/cusolver.cc +++ b/jaxlib/gpu/solver.cc @@ -21,28 +21,27 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "third_party/gpus/cuda/include/cusolverDn.h" -#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" -#include "jaxlib/cuda/cusolver_kernels.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/solver_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_pybind11_helpers.h" #include "include/pybind11/numpy.h" #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" namespace jax { +namespace JAX_GPU_NAMESPACE { namespace { namespace py = pybind11; // Converts a NumPy dtype to a Type. -CusolverType DtypeToCusolverType(const py::dtype& np_type) { +SolverType DtypeToSolverType(const py::dtype& np_type) { static auto* types = - new absl::flat_hash_map<std::pair<char, int>, CusolverType>({ - {{'f', 4}, CusolverType::F32}, - {{'f', 8}, CusolverType::F64}, - {{'c', 8}, CusolverType::C64}, - {{'c', 16}, CusolverType::C128}, + new absl::flat_hash_map<std::pair<char, int>, SolverType>({ + {{'f', 4}, SolverType::F32}, + {{'f', 8}, SolverType::F64}, + {{'c', 8}, SolverType::C64}, + {{'c', 16}, SolverType::C128}, }); auto it = types->find({np_type.kind(), np_type.itemsize()}); if (it == types->end()) { @@ -57,48 +56,87 @@ CusolverType DtypeToCusolverType(const py::dtype& np_type) { // Returns the workspace size and a descriptor for a potrf operation. std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype, bool lower, int b, int n) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; std::int64_t workspace_size; - cublasFillMode_t uplo = - lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + gpusolverFillMode_t uplo = + lower ?
GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; if (b == 1) { switch (type) { - case CusolverType::F32: + case SolverType::F32: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnSpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork))); + JAX_AS_STATUS(gpusolverDnSpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); workspace_size = lwork * sizeof(float); break; - case CusolverType::F64: + case SolverType::F64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnDpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork))); + JAX_AS_STATUS(gpusolverDnDpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); workspace_size = lwork * sizeof(double); break; - case CusolverType::C64: + case SolverType::C64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnCpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork))); - workspace_size = lwork * sizeof(cuComplex); + JAX_AS_STATUS(gpusolverDnCpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); + workspace_size = lwork * sizeof(gpuComplex); break; - case CusolverType::C128: + case SolverType::C128: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnZpotrf_bufferSize(handle.get(), uplo, n, - /*A=*/nullptr, - /*lda=*/n, &lwork))); - workspace_size = lwork * sizeof(cuDoubleComplex); + JAX_AS_STATUS(gpusolverDnZpotrf_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork))); + workspace_size = lwork * sizeof(gpuDoubleComplex); break; } } else { +#ifdef JAX_GPU_CUDA // We use the workspace buffer for our own scratch space. workspace_size = sizeof(void*) * b; +#else + // TODO(rocm): when cuda and hip had same API for batched potrf, remove this + // batched potrf has different API compared to CUDA. In hip we still need to + // create the workspace and additional space to copy the batch array + // pointers + switch (type) { + case SolverType::F32: + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + hipsolverSpotrfBatched_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork, b))); + workspace_size = (lwork * sizeof(float)) + (b * sizeof(float*)); + break; + case SolverType::F64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + hipsolverDpotrfBatched_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork, b))); + workspace_size = (lwork * sizeof(double)) + (b * sizeof(double*)); + break; + case SolverType::C64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + hipsolverCpotrfBatched_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork, b))); + workspace_size = + (lwork * sizeof(hipComplex)) + (b * sizeof(hipComplex*)); + break; + case SolverType::C128: + JAX_THROW_IF_ERROR(JAX_AS_STATUS( + hipsolverZpotrfBatched_bufferSize(handle.get(), uplo, n, + /*A=*/nullptr, + /*lda=*/n, &lwork, b))); + workspace_size = (lwork * sizeof(hipDoubleComplex)) + + (b * sizeof(hipDoubleComplex*)); + break; + } +#endif // JAX_GPU_CUDA } return {workspace_size, PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})}; @@ -109,38 +147,38 @@ std::pair BuildPotrfDescriptor(const py::dtype& dtype, // Returns the workspace size and a descriptor for a getrf operation. 
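All of these builders, including the getrf builder that follows, end in `PackDescriptor(...)`: the operation parameters are computed once on the host, serialized to bytes, and handed to XLA, which passes them back verbatim as the `opaque` argument of the custom call. Below is a sketch of that round trip, assuming a trivially copyable descriptor struct; jaxlib's real helpers in jaxlib/kernel_helpers.h add validation and return a pointer into `opaque`, which is why the kernel wrappers above dereference with `**s`.

```c++
#include <cstddef>
#include <cstring>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Sketch of the descriptor round trip, assuming T is trivially copyable.
template <typename T>
std::string PackDescriptorSketch(const T& descriptor) {
  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
}

template <typename T>
absl::StatusOr<T> UnpackDescriptorSketch(const char* opaque,
                                         std::size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    return absl::InvalidArgumentError("opaque has wrong length");
  }
  T descriptor;
  std::memcpy(&descriptor, opaque, sizeof(T));
  return descriptor;
}
```

Running the vendor `*_bufferSize` queries here, at descriptor-build time, is what lets XLA allocate the workspace as an ordinary buffer instead of the kernel allocating it on the fly.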
std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, int m, int n) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; switch (type) { - case CusolverType::F32: + case SolverType::F32: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnSgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; - case CusolverType::F64: + case SolverType::F64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnDgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; - case CusolverType::C64: + case SolverType::C64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnCgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; - case CusolverType::C128: + case SolverType::C128: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnZgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; } - return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})}; + return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})}; } // geqrf: QR decomposition @@ -148,91 +186,91 @@ std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, // Returns the workspace size and a descriptor for a geqrf operation. std::pair BuildGeqrfDescriptor(const py::dtype& dtype, int b, int m, int n) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; switch (type) { - case CusolverType::F32: + case SolverType::F32: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnSgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; - case CusolverType::F64: + case SolverType::F64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnDgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; - case CusolverType::C64: + case SolverType::C64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnCgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; - case CusolverType::C128: + case SolverType::C128: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnZgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); + JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n, + /*A=*/nullptr, + /*lda=*/m, &lwork))); break; } return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})}; } +#ifdef JAX_GPU_CUDA + // csrlsvqr: Linear system solve via Sparse QR // Returns a descriptor for a csrlsvqr operation. 
py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA, int reorder, double tol) { - CusolverType type = DtypeToCusolverType(dtype); - auto h = SpSolverHandlePool::Borrow(); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - + SolverType type = DtypeToSolverType(dtype); return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol}); } +#endif // JAX_GPU_CUDA + // orgqr/ungqr: apply elementary Householder transformations // Returns the workspace size and a descriptor for a geqrf operation. std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, int m, int n, int k) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; switch (type) { - case CusolverType::F32: + case SolverType::F32: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnSorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); + JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; - case CusolverType::F64: + case SolverType::F64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnDorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); + JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; - case CusolverType::C64: + case SolverType::C64: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnCungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); + JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; - case CusolverType::C128: + case SolverType::C128: JAX_THROW_IF_ERROR( - JAX_AS_STATUS(cusolverDnZungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); + JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k, + /*A=*/nullptr, + /*lda=*/m, + /*tau=*/nullptr, &lwork))); break; } return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})}; @@ -243,32 +281,32 @@ std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, // Returns the workspace size and a descriptor for a syevd operation. std::pair BuildSyevdDescriptor(const py::dtype& dtype, bool lower, int b, int n) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; - cublasFillMode_t uplo = - lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? 
GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; switch (type) { - case CusolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevd_bufferSize( + case SolverType::F32: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork))); break; - case CusolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevd_bufferSize( + case SolverType::F64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork))); break; - case CusolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevd_bufferSize( + case SolverType::C64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork))); break; - case CusolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevd_bufferSize( + case SolverType::C128: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork))); break; @@ -282,60 +320,60 @@ std::pair BuildSyevdDescriptor(const py::dtype& dtype, // Returns the workspace size and a descriptor for a syevj_batched operation. std::pair BuildSyevjDescriptor(const py::dtype& dtype, bool lower, int batch, int n) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; - syevjInfo_t params; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); }); - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; - cublasFillMode_t uplo = - lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + gpuSyevjInfo_t params; + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); + std::unique_ptr params_cleanup( + params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? 
GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; if (batch == 1) { switch (type) { - case CusolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevj_bufferSize( + case SolverType::F32: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params))); break; - case CusolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevj_bufferSize( + case SolverType::F64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params))); break; - case CusolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevj_bufferSize( + case SolverType::C64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params))); break; - case CusolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevj_bufferSize( + case SolverType::C128: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params))); break; } } else { switch (type) { - case CusolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevjBatched_bufferSize( + case SolverType::F32: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params, batch))); break; - case CusolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevjBatched_bufferSize( + case SolverType::F64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params, batch))); break; - case CusolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevjBatched_bufferSize( + case SolverType::C64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params, batch))); break; - case CusolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevjBatched_bufferSize( + case SolverType::C128: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevjBatched_bufferSize( handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, &lwork, params, batch))); break; @@ -350,29 +388,11 @@ std::pair BuildSyevjDescriptor(const py::dtype& dtype, std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, int m, int n, bool compute_uv, bool full_matrices) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; - switch (type) { - case CusolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork))); - break; - case CusolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork))); - break; - case CusolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork))); - break; - case CusolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork))); - break; - } signed char jobu, jobvt; if (compute_uv) { if (full_matrices) { @@ -383,51 +403,71 @@ std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, } else { jobu = jobvt = 'N'; } + switch (type) { + case 
SolverType::F32: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd_bufferSize( + handle.get(), jobu, jobvt, m, n, &lwork))); + break; + case SolverType::F64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd_bufferSize( + handle.get(), jobu, jobvt, m, n, &lwork))); + break; + case SolverType::C64: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd_bufferSize( + handle.get(), jobu, jobvt, m, n, &lwork))); + break; + case SolverType::C128: + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd_bufferSize( + handle.get(), jobu, jobvt, m, n, &lwork))); + break; + } return {lwork, PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})}; } +#ifdef JAX_GPU_CUDA + // Singular value decomposition using Jacobi algorithm: gesvdj // Returns the workspace size and a descriptor for a gesvdj operation. std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype, int batch, int m, int n, bool compute_uv, int econ) { - CusolverType type = DtypeToCusolverType(dtype); + SolverType type = DtypeToSolverType(dtype); auto h = SolverHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; - cusolverEigMode_t jobz = - compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + gpusolverEigMode_t jobz = + compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR; gesvdjInfo_t params; JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(&params))); std::unique_ptr<gesvdjInfo, void (*)(gesvdjInfo*)> params_cleanup( params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); if (batch == 1) { switch (type) { - case CusolverType::F32: + case SolverType::F32: JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize( handle.get(), jobz, econ, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params))); break; - case CusolverType::F64: + case SolverType::F64: JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize( handle.get(), jobz, econ, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params))); break; - case CusolverType::C64: + case SolverType::C64: JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize( handle.get(), jobz, econ, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params))); break; - case CusolverType::C128: + case SolverType::C128: JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize( handle.get(), jobz, econ, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, @@ -437,28 +477,28 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype, } } else { switch (type) { - case CusolverType::F32: + case SolverType::F32: JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params, batch))); break; - case CusolverType::F64: + case SolverType::F64: JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params, batch))); break; - case CusolverType::C64: + case SolverType::C64: JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params, batch))); break; - case CusolverType::C128: + case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched_bufferSize( handle.get(), jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, @@ -471,32 +511,40 @@ std::pair BuildGesvdjDescriptor(const py::dtype& dtype, GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})}; } +#endif // JAX_GPU_CUDA + py::dict Registrations() { py::dict dict; - dict["cusolver_potrf"] = EncapsulateFunction(Potrf); - dict["cusolver_getrf"] = EncapsulateFunction(Getrf); - dict["cusolver_geqrf"] = EncapsulateFunction(Geqrf); + dict[JAX_GPU_PREFIX "solver_potrf"] = EncapsulateFunction(Potrf); + dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf); + dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); + dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); + dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); + dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj); + dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd); + +#ifdef JAX_GPU_CUDA dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); - dict["cusolver_orgqr"] = EncapsulateFunction(Orgqr); - dict["cusolver_syevd"] = EncapsulateFunction(Syevd); - dict["cusolver_syevj"] = EncapsulateFunction(Syevj); - dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd); dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); +#endif // JAX_GPU_CUDA return dict; } -PYBIND11_MODULE(_cusolver, m) { +PYBIND11_MODULE(_solver, m) { m.def("registrations", &Registrations); m.def("build_potrf_descriptor", &BuildPotrfDescriptor); m.def("build_getrf_descriptor", &BuildGetrfDescriptor); m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); - m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor); m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); m.def("build_syevd_descriptor", &BuildSyevdDescriptor); m.def("build_syevj_descriptor", &BuildSyevjDescriptor); m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); +#ifdef JAX_GPU_CUDA + m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor); m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); +#endif // JAX_GPU_CUDA } } // namespace +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cusolver_kernels.cc b/jaxlib/gpu/solver_kernels.cc similarity index 62% rename from jaxlib/cuda/cusolver_kernels.cc rename to jaxlib/gpu/solver_kernels.cc index 7979eaceb..b92d38834 100644 --- a/jaxlib/cuda/cusolver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/cusolver_kernels.h" +#include "jaxlib/gpu/solver_kernels.h" #include #include @@ -24,38 +24,40 @@ limitations under the License. 
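The hunks that follow re-point `SolverHandlePool::Borrow` at the vendor-neutral types. The pool template it specializes lives in jaxlib/handle_pool.h and is not part of this diff; below is a condensed sketch of its behavior, with `Factory` standing in for the vendor create call (`gpusolverDnCreate` and friends), which is an assumption of this sketch rather than the real interface.

```c++
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"

// Condensed sketch of jaxlib's handle pool. Vendor handles are expensive to
// create, so they are cached per stream; the RAII Handle returns them to the
// pool on destruction.
template <typename HandleType, typename StreamType, typename Factory>
class HandlePoolSketch {
 public:
  class Handle {
   public:
    Handle(HandlePoolSketch* pool, HandleType handle, StreamType stream)
        : pool_(pool), handle_(handle), stream_(stream) {}
    Handle(Handle&& other)
        : pool_(other.pool_), handle_(other.handle_), stream_(other.stream_) {
      other.pool_ = nullptr;
    }
    ~Handle() {
      if (pool_) pool_->Return(handle_, stream_);
    }
    HandleType get() const { return handle_; }

   private:
    HandlePoolSketch* pool_;
    HandleType handle_;
    StreamType stream_;
  };

  static absl::StatusOr<Handle> Borrow(StreamType stream) {
    static auto* pool = new HandlePoolSketch;
    absl::MutexLock lock(&pool->mu_);
    auto& cached = pool->handles_[stream];
    HandleType handle;
    if (cached.empty()) {
      // Cold cache: create a fresh handle.
      absl::StatusOr<HandleType> created = Factory::Create(stream);
      if (!created.ok()) return created.status();
      handle = *created;
    } else {
      handle = cached.back();
      cached.pop_back();
    }
    return Handle(pool, handle, stream);
  }

 private:
  void Return(HandleType handle, StreamType stream) {
    absl::MutexLock lock(&mu_);
    handles_[stream].push_back(handle);
  }

  absl::Mutex mu_;
  absl::flat_hash_map<StreamType, std::vector<HandleType>> handles_;
};
```

This is why the Borrow specializations below only call `gpusolverDnCreate` when the per-stream cache is empty, and call `gpusolverDnSetStream` whenever a non-null stream is passed.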
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "third_party/gpus/cuda/include/cusolverDn.h" -#include "third_party/gpus/cuda/include/cusolverSp.h" -#include "jaxlib/cuda/cuda_gpu_kernel_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" +#ifdef JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cusolverSp.h" +#endif // JAX_GPU_CUDA + namespace jax { template <> /*static*/ absl::StatusOr SolverHandlePool::Borrow( - cudaStream_t stream) { + gpuStream_t stream) { SolverHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); - cusolverDnHandle_t handle; + gpusolverDnHandle_t handle; if (pool->handles_[stream].empty()) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreate(&handle))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreate(&handle))); } else { handle = pool->handles_[stream].back(); pool->handles_[stream].pop_back(); } if (stream) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSetStream(handle, stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream))); } return Handle(pool, handle, stream); } +#ifdef JAX_GPU_CUDA + template <> /*static*/ absl::StatusOr -SpSolverHandlePool::Borrow(cudaStream_t stream) { +SpSolverHandlePool::Borrow(gpuStream_t stream) { SpSolverHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); cusolverSpHandle_t handle; @@ -71,22 +73,26 @@ SpSolverHandlePool::Borrow(cudaStream_t stream) { return Handle(pool, handle, stream); } -static int SizeOfCusolverType(CusolverType type) { +#endif // JAX_GPU_CUDA + +namespace JAX_GPU_NAMESPACE { + +static int SizeOfSolverType(SolverType type) { switch (type) { - case CusolverType::F32: + case SolverType::F32: return sizeof(float); - case CusolverType::F64: + case SolverType::F64: return sizeof(double); - case CusolverType::C64: - return sizeof(cuComplex); - case CusolverType::C128: - return sizeof(cuDoubleComplex); + case SolverType::C64: + return sizeof(gpuComplex); + case SolverType::C128: + return sizeof(gpuDoubleComplex); } } // potrf: Cholesky decomposition -static absl::Status Potrf_(cudaStream_t stream, void** buffers, +static absl::Status Potrf_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -95,77 +101,84 @@ static absl::Status Potrf_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cudaMemcpyAsync(buffers[1], buffers[0], - SizeOfCusolverType(d.type) * d.batch * d.n * d.n, - cudaMemcpyDeviceToDevice, stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( + buffers[1], buffers[0], SizeOfSolverType(d.type) * d.batch * d.n * d.n, + gpuMemcpyDeviceToDevice, stream))); } int* info = static_cast(buffers[2]); void* workspace = buffers[3]; if (d.batch == 1) { switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), d.lwork, info))); + gpusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, + static_cast(workspace), d.lwork, info))); break; } - case CusolverType::F64: { + case 
SolverType::F64: { double* a = static_cast(buffers[1]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnDpotrf(handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), d.lwork, info))); + gpusolverDnDpotrf(handle.get(), d.uplo, d.n, a, d.n, + static_cast(workspace), d.lwork, info))); break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCpotrf( + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCpotrf( handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), d.lwork, info))); + static_cast(workspace), d.lwork, info))); break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZpotrf( + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZpotrf( handle.get(), d.uplo, d.n, a, d.n, - static_cast(workspace), d.lwork, info))); + static_cast(workspace), d.lwork, info))); break; } } } else { auto buffer_ptrs_host = MakeBatchPointers(stream, buffers[1], workspace, d.batch, - SizeOfCusolverType(d.type) * d.n * d.n); + SizeOfSolverType(d.type) * d.n * d.n); JAX_RETURN_IF_ERROR(buffer_ptrs_host.status()); // Make sure that accesses to buffer_ptrs_host complete before we delete it. // TODO(phawkins): avoid synchronization here. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaStreamSynchronize(stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); switch (d.type) { - case CusolverType::F32: { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSpotrfBatched( + case SolverType::F32: { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - - info, d.batch))); + reinterpret_cast(static_cast(workspace) + d.batch), + d.lwork, info, d.batch))); break; } - case CusolverType::F64: { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDpotrfBatched( + case SolverType::F64: { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDpotrfBatched( handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - info, d.batch))); + reinterpret_cast(static_cast(workspace) + + d.batch), + d.lwork, info, d.batch))); break; } - case CusolverType::C64: { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCpotrfBatched( - handle.get(), d.uplo, d.n, static_cast(workspace), d.n, - info, d.batch))); + case SolverType::C64: { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCpotrfBatched( + handle.get(), d.uplo, d.n, static_cast(workspace), + d.n, + reinterpret_cast(static_cast(workspace) + + d.batch), + d.lwork, info, d.batch))); break; } - case CusolverType::C128: { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZpotrfBatched( + case SolverType::C128: { + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZpotrfBatched( handle.get(), d.uplo, d.n, - static_cast(workspace), d.n, info, d.batch))); + static_cast(workspace), d.n, + reinterpret_cast( + static_cast(workspace) + d.batch), + d.lwork, info, d.batch))); break; } } @@ -173,7 +186,7 @@ static absl::Status Potrf_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Potrf(cudaStream_t stream, void** buffers, const char* opaque, +void Potrf(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Potrf_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -184,7 +197,7 @@ void Potrf(cudaStream_t stream, void** buffers, const char* opaque, // getrf: LU decomposition -static 
absl::Status Getrf_(cudaStream_t stream, void** buffers, +static absl::Status Getrf_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -193,59 +206,59 @@ static absl::Status Getrf_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( buffers[1], buffers[0], - SizeOfCusolverType(d.type) * static_cast(d.batch) * + SizeOfSolverType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream))); + gpuMemcpyDeviceToDevice, stream))); } int* ipiv = static_cast(buffers[2]); int* info = static_cast(buffers[3]); void* workspace = buffers[4]; switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnSgetrf(handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), ipiv, info))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgetrf( + handle.get(), d.m, d.n, a, d.m, static_cast(workspace), + d.lwork, ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; } break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnDgetrf(handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), ipiv, info))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgetrf( + handle.get(), d.m, d.n, a, d.m, static_cast(workspace), + d.lwork, ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; } break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnCgetrf(handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), ipiv, info))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgetrf( + handle.get(), d.m, d.n, a, d.m, static_cast(workspace), + d.lwork, ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; } break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgetrf( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgetrf( handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), ipiv, info))); + static_cast(workspace), d.lwork, ipiv, info))); a += d.m * d.n; ipiv += std::min(d.m, d.n); ++info; @@ -256,7 +269,7 @@ static absl::Status Getrf_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Getrf(cudaStream_t stream, void** buffers, const char* opaque, +void Getrf(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Getrf_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -267,7 +280,7 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque, // geqrf: QR decomposition -static absl::Status Geqrf_(cudaStream_t stream, void** buffers, +static absl::Status Geqrf_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -276,62 +289,62 @@ static absl::Status 
Geqrf_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( buffers[1], buffers[0], - SizeOfCusolverType(d.type) * static_cast(d.batch) * + SizeOfSolverType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream))); + gpuMemcpyDeviceToDevice, stream))); } int* info = static_cast(buffers[3]); void* workspace = buffers[4]; switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); float* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + gpusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; } break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); double* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + gpusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; } break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); - cuComplex* tau = static_cast(buffers[2]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); + gpuComplex* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgeqrf( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgeqrf( handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; } break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); - cuDoubleComplex* tau = static_cast(buffers[2]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); + gpuDoubleComplex* tau = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgeqrf( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgeqrf( handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += std::min(d.m, d.n); ++info; @@ -342,7 +355,7 @@ static absl::Status Geqrf_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, +void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Geqrf_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -351,9 +364,11 @@ void Geqrf(cudaStream_t stream, void** buffers, const char* opaque, } } +#ifdef JAX_GPU_CUDA + // csrlsvqr: Linear system solve via Sparse QR -static absl::Status Csrlsvqr_(cudaStream_t stream, void** buffers, +static absl::Status Csrlsvqr_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, int& singularity) { auto s = UnpackDescriptor(opaque, opaque_len); @@ -373,7 +388,7 @@ static absl::Status Csrlsvqr_(cudaStream_t stream, void** buffers, cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO))); switch 
(d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* csrValA = static_cast(buffers[0]); int* csrRowPtrA = static_cast(buffers[1]); int* csrColIndA = static_cast(buffers[2]); @@ -381,12 +396,12 @@ static absl::Status Csrlsvqr_(cudaStream_t stream, void** buffers, float* x = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpScsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, - b, (float)d.tol, d.reorder, x, &singularity))); + handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, + (float)d.tol, d.reorder, x, &singularity))); break; } - case CusolverType::F64: { + case SolverType::F64: { double* csrValA = static_cast(buffers[0]); int* csrRowPtrA = static_cast(buffers[1]); int* csrColIndA = static_cast(buffers[2]); @@ -394,34 +409,34 @@ static absl::Status Csrlsvqr_(cudaStream_t stream, void** buffers, double* x = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpDcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, - b, d.tol, d.reorder, x, &singularity))); + handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, + d.tol, d.reorder, x, &singularity))); break; } - case CusolverType::C64: { - cuComplex* csrValA = static_cast(buffers[0]); + case SolverType::C64: { + gpuComplex* csrValA = static_cast(buffers[0]); int* csrRowPtrA = static_cast(buffers[1]); int* csrColIndA = static_cast(buffers[2]); - cuComplex* b = static_cast(buffers[3]); - cuComplex* x = static_cast(buffers[4]); + gpuComplex* b = static_cast(buffers[3]); + gpuComplex* x = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, - b, (float)d.tol, d.reorder, x, &singularity))); + handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, + (float)d.tol, d.reorder, x, &singularity))); break; } - case CusolverType::C128: { - cuDoubleComplex* csrValA = static_cast(buffers[0]); + case SolverType::C128: { + gpuDoubleComplex* csrValA = static_cast(buffers[0]); int* csrRowPtrA = static_cast(buffers[1]); int* csrColIndA = static_cast(buffers[2]); - cuDoubleComplex* b = static_cast(buffers[3]); - cuDoubleComplex* x = static_cast(buffers[4]); + gpuDoubleComplex* b = static_cast(buffers[3]); + gpuDoubleComplex* x = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpZcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, - b, (float)d.tol, d.reorder, x, &singularity))); + handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, + (float)d.tol, d.reorder, x, &singularity))); break; } @@ -431,7 +446,7 @@ static absl::Status Csrlsvqr_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Csrlsvqr(cudaStream_t stream, void** buffers, const char* opaque, +void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { // Is >= 0 if A is singular. 
int singularity = -1; @@ -448,9 +463,11 @@ void Csrlsvqr(cudaStream_t stream, void** buffers, const char* opaque, } } +#endif // JAX_GPU_CUDA + // orgqr/ungqr: apply elementary Householder transformations -static absl::Status Orgqr_(cudaStream_t stream, void** buffers, +static absl::Status Orgqr_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -459,62 +476,62 @@ static absl::Status Orgqr_(cudaStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; if (buffers[2] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( buffers[2], buffers[0], - SizeOfCusolverType(d.type) * static_cast(d.batch) * + SizeOfSolverType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream))); + gpuMemcpyDeviceToDevice, stream))); } int* info = static_cast(buffers[3]); void* workspace = buffers[4]; switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[2]); float* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + gpusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; } break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[2]); double* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + gpusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; } break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[2]); - cuComplex* tau = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[2]); + gpuComplex* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCungqr( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCungqr( handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; } break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[2]); - cuDoubleComplex* tau = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[2]); + gpuDoubleComplex* tau = static_cast(buffers[1]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZungqr( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZungqr( handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); + static_cast(workspace), d.lwork, info))); a += d.m * d.n; tau += d.k; ++info; @@ -525,7 +542,7 @@ static absl::Status Orgqr_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, +void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Orgqr_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -536,7 +553,7 @@ void Orgqr(cudaStream_t stream, void** buffers, const char* opaque, // 
Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd -static absl::Status Syevd_(cudaStream_t stream, void** buffers, +static absl::Status Syevd_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -544,61 +561,61 @@ static absl::Status Syevd_(cudaStream_t stream, void** buffers, auto h = SolverHandlePool::Borrow(stream); JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( buffers[1], buffers[0], - SizeOfCusolverType(d.type) * static_cast(d.batch) * + SizeOfSolverType(d.type) * static_cast(d.batch) * static_cast(d.n) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream))); - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; + gpuMemcpyDeviceToDevice, stream))); + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; int* info = static_cast(buffers[3]); void* work = buffers[4]; switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); + gpusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; } break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); + gpusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; } break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); + gpusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; } break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); + static_cast(work), d.lwork, info))); a += d.n * d.n; w += d.n; ++info; @@ -609,7 +626,7 @@ static absl::Status Syevd_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Syevd(cudaStream_t stream, void** buffers, const char* opaque, +void Syevd(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Syevd_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -621,7 +638,7 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque, // Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj // Supports batches of matrices up to size 32. 
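`Syevj_` below holds its `gpuSyevjInfo_t` params in a `std::unique_ptr` with a custom deleter, so that every `JAX_RETURN_IF_ERROR` early exit still releases the vendor object. Here is the idiom in isolation; every name in this block is a stand-in for illustration, not code from the diff.

```c++
#include <memory>

// Stand-ins for the vendor's opaque info object and its create/destroy calls
// (gpusolverDnCreateSyevjInfo / gpusolverDnDestroySyevjInfo in the diff).
struct OpaqueInfo {
  int unused;
};
OpaqueInfo* CreateInfo() { return new OpaqueInfo{}; }
void DestroyInfo(OpaqueInfo* p) { delete p; }

void UseInfoSafely() {
  OpaqueInfo* params = CreateInfo();
  // From here on, any return path runs the deleter, so the vendor object
  // cannot leak even when a later status check fails.
  std::unique_ptr<OpaqueInfo, void (*)(OpaqueInfo*)> cleanup(
      params, [](OpaqueInfo* p) { DestroyInfo(p); });
  // ... use `params` with the batched syevj/heevj entry points ...
}
```

A captureless lambda converts to a plain function pointer, which keeps the `unique_ptr` the size of two pointers and avoids a `std::function` allocation.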
-absl::Status Syevj_(cudaStream_t stream, void** buffers, const char* opaque, +absl::Status Syevj_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -630,88 +647,88 @@ absl::Status Syevj_(cudaStream_t stream, void** buffers, const char* opaque, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( buffers[1], buffers[0], - SizeOfCusolverType(d.type) * static_cast(d.batch) * + SizeOfSolverType(d.type) * static_cast(d.batch) * static_cast(d.n) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream))); + gpuMemcpyDeviceToDevice, stream))); } - syevjInfo_t params; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); }); + gpuSyevjInfo_t params; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); + std::unique_ptr params_cleanup( + params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; int* info = static_cast(buffers[3]); void* work = buffers[4]; if (d.batch == 1) { switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj( handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast(work), d.lwork, info, params))); break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj( handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast(work), d.lwork, info, params))); break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); + static_cast(work), d.lwork, info, params))); break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZheevj( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); + static_cast(work), d.lwork, info, params))); break; } } } else { switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSsyevjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched( handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast(work), d.lwork, info, params, d.batch))); break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDsyevjBatched( + 
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched( handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast(work), d.lwork, info, params, d.batch))); break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCheevjBatched( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched( handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); + static_cast(work), d.lwork, info, params, d.batch))); break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); double* w = static_cast(buffers[2]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), - d.lwork, info, params, d.batch))); + gpusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w, + static_cast(work), + d.lwork, info, params, d.batch))); break; } } @@ -719,7 +736,7 @@ absl::Status Syevj_(cudaStream_t stream, void** buffers, const char* opaque, return absl::OkStatus(); } -void Syevj(cudaStream_t stream, void** buffers, const char* opaque, +void Syevj(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Syevj_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -730,7 +747,7 @@ void Syevj(cudaStream_t stream, void** buffers, const char* opaque, // Singular value decomposition using QR algorithm: gesvd -static absl::Status Gesvd_(cudaStream_t stream, void** buffers, +static absl::Status Gesvd_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -738,22 +755,22 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers, auto h = SolverHandlePool::Borrow(stream); JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( buffers[1], buffers[0], - SizeOfCusolverType(d.type) * static_cast(d.batch) * + SizeOfSolverType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream))); + gpuMemcpyDeviceToDevice, stream))); int* info = static_cast(buffers[5]); void* work = buffers[6]; int64_t k = d.jobu == 'A' ? 
d.m : d.n; switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); float* s = static_cast(buffers[2]); float* u = static_cast(buffers[3]); float* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd( handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, static_cast(work), d.lwork, /*rwork=*/nullptr, info))); @@ -765,13 +782,13 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers, } break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); double* s = static_cast(buffers[2]); double* u = static_cast(buffers[3]); double* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd( handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, static_cast(work), d.lwork, /*rwork=*/nullptr, info))); @@ -783,15 +800,15 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers, } break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); float* s = static_cast(buffers[2]); - cuComplex* u = static_cast(buffers[3]); - cuComplex* vt = static_cast(buffers[4]); + gpuComplex* u = static_cast(buffers[3]); + gpuComplex* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd( handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, /*rwork=*/nullptr, info))); + static_cast(work), d.lwork, /*rwork=*/nullptr, info))); a += d.m * d.n; s += std::min(d.m, d.n); u += d.m * k; @@ -800,15 +817,15 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers, } break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); double* s = static_cast(buffers[2]); - cuDoubleComplex* u = static_cast(buffers[3]); - cuDoubleComplex* vt = static_cast(buffers[4]); + gpuDoubleComplex* u = static_cast(buffers[3]); + gpuDoubleComplex* vt = static_cast(buffers[4]); for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvd( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd( handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, + static_cast(work), d.lwork, /*rwork=*/nullptr, info))); a += d.m * d.n; s += std::min(d.m, d.n); @@ -822,7 +839,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, +void Gesvd(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Gesvd_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -831,9 +848,11 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque, } } +#ifdef JAX_GPU_CUDA + // Singular value decomposition using Jacobi algorithm: gesvdj -static absl::Status Gesvdj_(cudaStream_t stream, void** buffers, +static absl::Status Gesvdj_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -841,11 +860,11 @@ static 
absl::Status Gesvdj_(cudaStream_t stream, void** buffers, auto h = SolverHandlePool::Borrow(stream); JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudaMemcpyAsync( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( buffers[1], buffers[0], - SizeOfCusolverType(d.type) * static_cast(d.batch) * + SizeOfSolverType(d.type) * static_cast(d.batch) * static_cast(d.m) * static_cast(d.n), - cudaMemcpyDeviceToDevice, stream))); + gpuMemcpyDeviceToDevice, stream))); int* info = static_cast(buffers[5]); void* work = buffers[6]; gesvdjInfo_t params; @@ -854,50 +873,50 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers, params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); if (d.batch == 1) { switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); float* s = static_cast(buffers[2]); float* u = static_cast(buffers[3]); float* v = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params))); + handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, + static_cast(work), d.lwork, info, params))); break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); double* s = static_cast(buffers[2]); double* u = static_cast(buffers[3]); double* v = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params))); + handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, + static_cast(work), d.lwork, info, params))); break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); float* s = static_cast(buffers[2]); - cuComplex* u = static_cast(buffers[3]); - cuComplex* v = static_cast(buffers[4]); + gpuComplex* u = static_cast(buffers[3]); + gpuComplex* v = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params))); + handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, + static_cast(work), d.lwork, info, params))); break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); double* s = static_cast(buffers[2]); - cuDoubleComplex* u = static_cast(buffers[3]); - cuDoubleComplex* v = static_cast(buffers[4]); + gpuDoubleComplex* u = static_cast(buffers[3]); + gpuDoubleComplex* v = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, - d.n, static_cast(work), d.lwork, info, params))); + handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, + static_cast(work), d.lwork, info, params))); break; } } } else { switch (d.type) { - case CusolverType::F32: { + case SolverType::F32: { float* a = static_cast(buffers[1]); float* s = static_cast(buffers[2]); float* u = static_cast(buffers[3]); @@ -907,7 +926,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers, static_cast(work), d.lwork, info, params, d.batch))); break; } - case CusolverType::F64: { + case SolverType::F64: { double* a = static_cast(buffers[1]); double* s = static_cast(buffers[2]); double* 
u = static_cast(buffers[3]); @@ -917,24 +936,24 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers, static_cast(work), d.lwork, info, params, d.batch))); break; } - case CusolverType::C64: { - cuComplex* a = static_cast(buffers[1]); + case SolverType::C64: { + gpuComplex* a = static_cast(buffers[1]); float* s = static_cast(buffers[2]); - cuComplex* u = static_cast(buffers[3]); - cuComplex* v = static_cast(buffers[4]); + gpuComplex* u = static_cast(buffers[3]); + gpuComplex* v = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched( handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); + static_cast(work), d.lwork, info, params, d.batch))); break; } - case CusolverType::C128: { - cuDoubleComplex* a = static_cast(buffers[1]); + case SolverType::C128: { + gpuDoubleComplex* a = static_cast(buffers[1]); double* s = static_cast(buffers[2]); - cuDoubleComplex* u = static_cast(buffers[3]); - cuDoubleComplex* v = static_cast(buffers[4]); + gpuDoubleComplex* u = static_cast(buffers[3]); + gpuDoubleComplex* v = static_cast(buffers[4]); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched( handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, + static_cast(work), d.lwork, info, params, d.batch))); break; } @@ -943,7 +962,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers, return absl::OkStatus(); } -void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, +void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = Gesvdj_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -952,4 +971,7 @@ void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque, } } +#endif // JAX_GPU_CUDA + +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/cuda/cusolver_kernels.h b/jaxlib/gpu/solver_kernels.h similarity index 67% rename from jaxlib/cuda/cusolver_kernels.h rename to jaxlib/gpu/solver_kernels.h index 3bc0a097e..b3498acb2 100644 --- a/jaxlib/cuda/cusolver_kernels.h +++ b/jaxlib/gpu/solver_kernels.h @@ -17,28 +17,36 @@ limitations under the License. #define JAXLIB_CUSOLVER_KERNELS_H_ #include "absl/status/statusor.h" -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "third_party/gpus/cuda/include/cusolverDn.h" -#include "third_party/gpus/cuda/include/cusolverSp.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/handle_pool.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" +#ifdef JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cusolverSp.h" +#endif // JAX_GPU_CUDA + namespace jax { -using SolverHandlePool = HandlePool; -using SpSolverHandlePool = HandlePool; +using SolverHandlePool = HandlePool; template <> absl::StatusOr SolverHandlePool::Borrow( - cudaStream_t stream); + gpuStream_t stream); + +#ifdef JAX_GPU_CUDA + +using SpSolverHandlePool = HandlePool; template <> absl::StatusOr SpSolverHandlePool::Borrow( - cudaStream_t stream); + gpuStream_t stream); + +#endif // JAX_GPU_CUDA + +namespace JAX_GPU_NAMESPACE { // Set of types known to Cusolver. 
-enum class CusolverType {
+enum class SolverType {
   F32,
   F64,
   C64,
@@ -48,105 +56,114 @@ enum class CusolverType {

 // potrf: Cholesky decomposition
 struct PotrfDescriptor {
-  CusolverType type;
-  cublasFillMode_t uplo;
+  SolverType type;
+  gpusolverFillMode_t uplo;
   std::int64_t batch, n;
   int lwork;
 };

-void Potrf(cudaStream_t stream, void** buffers, const char* opaque,
+void Potrf(gpuStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

 // getrf: LU decomposition
 struct GetrfDescriptor {
-  CusolverType type;
-  int batch, m, n;
+  SolverType type;
+  int batch, m, n, lwork;
 };

-void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
+void Getrf(gpuStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

 // geqrf: QR decomposition
 struct GeqrfDescriptor {
-  CusolverType type;
+  SolverType type;
   int batch, m, n, lwork;
 };

-void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
+void Geqrf(gpuStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

+#ifdef JAX_GPU_CUDA
+
 // csrlsvqr: Linear system solve via Sparse QR
 struct CsrlsvqrDescriptor {
-  CusolverType type;
+  SolverType type;
   int n, nnz, reorder;
   double tol;
 };

-void Csrlsvqr(cudaStream_t stream, void** buffers, const char* opaque,
+void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status);

+#endif  // JAX_GPU_CUDA
+
 // orgqr/ungqr: apply elementary Householder transformations
 struct OrgqrDescriptor {
-  CusolverType type;
+  SolverType type;
   int batch, m, n, k, lwork;
 };

-void Orgqr(cudaStream_t stream, void** buffers, const char* opaque,
+void Orgqr(gpuStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

 // Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
 struct SyevdDescriptor {
-  CusolverType type;
-  cublasFillMode_t uplo;
+  SolverType type;
+  gpusolverFillMode_t uplo;
   int batch, n;
   int lwork;
 };

-void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
+void Syevd(gpuStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

 // Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
 // Supports batches of matrices up to size 32.
 struct SyevjDescriptor {
-  CusolverType type;
-  cublasFillMode_t uplo;
+  SolverType type;
+  gpusolverFillMode_t uplo;
   int batch, n;
   int lwork;
 };

-void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
+void Syevj(gpuStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

 // Singular value decomposition using QR algorithm: gesvd
 struct GesvdDescriptor {
-  CusolverType type;
+  SolverType type;
   int batch, m, n;
   int lwork;
   signed char jobu, jobvt;
 };

-void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
+void Gesvd(gpuStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

+#ifdef JAX_GPU_CUDA
+
 // Singular value decomposition using Jacobi algorithm: gesvdj
 struct GesvdjDescriptor {
-  CusolverType type;
+  SolverType type;
   int batch, m, n;
   int lwork;
-  cusolverEigMode_t jobz;
+  gpusolverEigMode_t jobz;
   int econ;
 };

-void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
+void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque,
             size_t opaque_len, XlaCustomCallStatus* status);

+#endif  // JAX_GPU_CUDA
+
+}  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax

 #endif  // JAXLIB_CUSOLVER_KERNELS_H_
diff --git a/jaxlib/rocm/hipsparse.cc b/jaxlib/gpu/sparse.cc
similarity index 54%
rename from jaxlib/rocm/hipsparse.cc
rename to jaxlib/gpu/sparse.cc
index 57f3cbb07..1c6f5c1fe 100644
--- a/jaxlib/rocm/hipsparse.cc
+++ b/jaxlib/gpu/sparse.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "rocm/include/hipsparse.h"
-
 #include
 #include
 #include
@@ -26,10 +24,9 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/str_format.h"
 #include "absl/synchronization/mutex.h"
-#include "rocm/include/hip/hip_complex.h"
-#include "rocm/include/hip/hip_runtime_api.h"
-#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
-#include "jaxlib/rocm/hipsparse_kernels.h"
+#include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/sparse_kernels.h"
+#include "jaxlib/gpu/vendor.h"
 #include "jaxlib/kernel_pybind11_helpers.h"
 #include "include/pybind11/numpy.h"
 #include "include/pybind11/pybind11.h"
@@ -38,14 +35,15 @@
 namespace py = pybind11;

 namespace jax {
+namespace JAX_GPU_NAMESPACE {
 namespace {

-hipsparseIndexType_t DtypeToHipSparseIndexType(const py::dtype& np_type) {
+gpusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
   static auto* types =
-      new absl::flat_hash_map<std::pair<char, int>, hipsparseIndexType_t>({
-          {{'u', 2}, HIPSPARSE_INDEX_16U},
-          {{'i', 4}, HIPSPARSE_INDEX_32I},
-          {{'i', 8}, HIPSPARSE_INDEX_64I},
+      new absl::flat_hash_map<std::pair<char, int>, gpusparseIndexType_t>({
+          {{'u', 2}, GPUSPARSE_INDEX_16U},
+          {{'i', 4}, GPUSPARSE_INDEX_32I},
+          {{'i', 8}, GPUSPARSE_INDEX_64I},
      });
   auto it = types->find({np_type.kind(), np_type.itemsize()});
   if (it == types->end()) {
@@ -55,16 +53,20 @@ hipsparseIndexType_t DtypeToHipSparseIndexType(const py::dtype& np_type) {
   return it->second;
 }

-// TODO(rocm): add more hip data types when supported
-hipDataType DtypeToHipDataType(const py::dtype& np_type) {
+gpuDataType DtypeToCudaDataType(const py::dtype& np_type) {
   static auto* types =
-      new absl::flat_hash_map<std::pair<char, int>, hipDataType>(
-          {{{'f', 2}, HIP_R_16F},
-           {{'c', 4}, HIP_C_16F},
-           {{'f', 4}, HIP_R_32F},
-           {{'c', 8}, HIP_C_32F},
-           {{'f', 8}, HIP_R_64F},
-           {{'c', 16}, HIP_C_64F}});
+      new absl::flat_hash_map<std::pair<char, int>, gpuDataType>({
+          {{'f', 2}, GPU_R_16F},  {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F},
+          {{'c', 8}, GPU_C_32F},  {{'f', 8}, GPU_R_64F},
+          {{'c', 16}, GPU_C_64F},
+#ifdef JAX_GPU_CUDA
+          {{'i', 1}, CUDA_R_8I},  {{'u', 1}, CUDA_R_8U},
+          {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U},
+#if JAX_GPU_HAVE_SPARSE
+          {{'V', 2}, CUDA_R_16BF},
+#endif  // JAX_GPU_HAVE_SPARSE
+#endif  // JAX_GPU_CUDA
+      });
   auto it = types->find({np_type.kind(), np_type.itemsize()});
   if (it == types->end()) {
     throw std::invalid_argument(
@@ -78,28 +80,28 @@ SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
                                             int rows, int cols, int nnz,
                                             int batch_count,
                                             int batch_stride) {
-  hipDataType value_type = DtypeToHipDataType(data_dtype);
-  hipsparseIndexType_t index_type = DtypeToHipSparseIndexType(index_dtype);
-  return SparseMatDescriptor{
-      value_type, index_type, rows, cols, nnz, batch_count, batch_stride};
+  gpuDataType value_type = DtypeToCudaDataType(data_dtype);
+  gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype);
+  return SparseMatDescriptor{value_type,  index_type,  rows, cols,
+                             nnz, batch_count, batch_stride};
 }

 // Returns the descriptor for a Dense matrix.
 DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
                                            int rows, int cols, int batch_count,
                                            int batch_stride) {
-  hipDataType value_type = DtypeToHipDataType(data_dtype);
+  gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride};
 }

 // Returns the descriptor for a Dense vector.
 DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
                                            int size) {
-  hipDataType value_type = DtypeToHipDataType(data_dtype);
+  gpuDataType value_type = DtypeToCudaDataType(data_dtype);
   return DenseVecDescriptor{value_type, size};
 }
-
+#if JAX_GPU_HAVE_SPARSE
 // CsrToDense: Convert CSR matrix to dense matrix

 // Returns the descriptor for a Sparse matrix.
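
// [Editor's note] The gpusparse*/GPUSPARSE_*/GPU_R_* identifiers used in this
// file are not a third vendor library: they are compile-time aliases supplied
// by jaxlib/gpu/vendor.h, which this patch adds but which is not shown in the
// diff. The sketch below illustrates the pattern only; the exact alias list
// and include paths are assumptions, not a quote of vendor.h.

#ifdef JAX_GPU_CUDA  // building against CUDA / cusparse
#include "third_party/gpus/cuda/include/cusparse.h"
#define JAX_GPU_NAMESPACE cuda
#define JAX_GPU_PREFIX "cu"
typedef cudaDataType gpuDataType;
typedef cusparseHandle_t gpusparseHandle_t;
typedef cusparseIndexType_t gpusparseIndexType_t;
#define GPU_R_32F CUDA_R_32F
#define GPUSPARSE_INDEX_16U CUSPARSE_INDEX_16U
#define gpusparseCreateCsr cusparseCreateCsr
#else  // building against ROCm / hipsparse
#include "rocm/include/hipsparse.h"
#define JAX_GPU_NAMESPACE hip
#define JAX_GPU_PREFIX "hip"
typedef hipDataType gpuDataType;
typedef hipsparseHandle_t gpusparseHandle_t;
typedef hipsparseIndexType_t gpusparseIndexType_t;
#define GPU_R_32F HIP_R_32F
#define GPUSPARSE_INDEX_16U HIPSPARSE_INDEX_16U
#define gpusparseCreateCsr hipsparseCreateCsr
#endif

// With those aliases in place, one source file compiles for either backend,
// and JAX_GPU_PREFIX lets the registration dictionary keep backend-specific
// names ("cusparse_...", "hipsparse_...") from shared code.
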
@@ -111,35 +113,35 @@ std::pair BuildCsrToDenseDescriptor( auto& handle = *h; SparseMatDescriptor d = BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz, - /*batch_count*/1, /*batch_stride*/0); + /*batch_count*/ 1, /*batch_stride*/ 0); - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnMatDescr_t mat_b = 0; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnMatDescr_t mat_b = 0; // buffer_size does not reference these pointers, but does error on NULL. // TODO(jakevdp): check whether this is documented. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr( &mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type, - d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_b, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW))); + /*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW))); size_t buffer_size; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize( - handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize( + handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrToDense_(hipStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { +absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque, + size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); const SparseMatDescriptor& d = **s; @@ -147,28 +149,28 @@ absl::Status CsrToDense_(hipStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnMatDescr_t mat_b = 0; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnMatDescr_t mat_b = 0; JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[2], - /*csrColInd=*/buffers[1], - /*csrValues=*/buffers[0], d.index_type, d.index_type, - HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, + /*csrRowOffsets=*/buffers[2], + /*csrColInd=*/buffers[1], + /*csrValues=*/buffers[0], d.index_type, d.index_type, + GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW))); + /*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipsparseSparseToDense(handle.get(), mat_a, mat_b, - HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); + gpusparseSparseToDense(handle.get(), mat_a, mat_b, + GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + 
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); return absl::OkStatus(); } -void CsrToDense(hipStream_t stream, void** buffers, const char* opaque, +void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = CsrToDense_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -190,30 +192,30 @@ std::pair BuildCsrFromDenseDescriptor( BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz, /*batch_count=*/1, /*batch_stride=*/0); - hipsparseDnMatDescr_t mat_a = 0; - hipsparseSpMatDescr_t mat_b = 0; + gpusparseDnMatDescr_t mat_a = 0; + gpusparseSpMatDescr_t mat_b = 0; // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_a, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr( + /*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr( &mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type, - d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); + d.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); size_t buffer_size; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize( - handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT, + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize( + handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrFromDense_(hipStream_t stream, void** buffers, +absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -222,29 +224,29 @@ absl::Status CsrFromDense_(hipStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - hipsparseDnMatDescr_t mat_a = 0; - hipsparseSpMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + gpusparseDnMatDescr_t mat_a = 0; + gpusparseSpMatDescr_t mat_b = 0; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW))); + /*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[3], - /*csrColInd=*/buffers[2], - /*csrValues=*/buffers[1], d.index_type, d.index_type, - HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis( - handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT, + gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, + /*csrRowOffsets=*/buffers[3], + /*csrColInd=*/buffers[2], + /*csrValues=*/buffers[1], d.index_type, d.index_type, + GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis( + handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, buffers[4]))); - 
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert( - handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT, + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert( + handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b))); return absl::OkStatus(); } -void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque, +void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -271,32 +273,32 @@ std::pair BuildCsrMatvecDescriptor( DenseVecDescriptor y = BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows); - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnVecDescr_t vec_x = 0; - hipsparseDnVecDescr_t vec_y = 0; - hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE - : HIPSPARSE_OPERATION_NON_TRANSPOSE; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnVecDescr_t vec_x = 0; + gpusparseDnVecDescr_t vec_y = 0; + gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE + : GPUSPARSE_OPERATION_NON_TRANSPOSE; // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr( &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type, - A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type))); + A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type))); + JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type))); + JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - HipConst alpha = HipOne(y.type); - HipConst beta = HipZero(y.type); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize( + SparseConst alpha = ConstOne(y.type); + SparseConst beta = ConstZero(y.type); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - HIPSPARSE_MV_ALG_DEFAULT, &buffer_size))); + GPUSPARSE_MV_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y))); return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})}; } @@ -320,35 +322,35 @@ std::pair BuildCsrMatmatDescriptor( DenseMatDescriptor C = BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols, /*batch_count=*/1, /*batch_stride=*/0); - hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE - : HIPSPARSE_OPERATION_NON_TRANSPOSE; + gpusparseOperation_t op_A = transpose ? 
GPUSPARSE_OPERATION_TRANSPOSE + : GPUSPARSE_OPERATION_NON_TRANSPOSE; - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnMatDescr_t mat_b = 0; - hipsparseDnMatDescr_t mat_c = 0; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnMatDescr_t mat_b = 0; + gpusparseDnMatDescr_t mat_c = 0; // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr( &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type, - A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type))); + A.index_type, GPUSPARSE_INDEX_BASE_ZERO, A.value_type))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, - empty, B.type, HIPSPARSE_ORDER_ROW))); + JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, + empty, B.type, GPUSPARSE_ORDER_ROW))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, - empty, C.type, HIPSPARSE_ORDER_ROW))); + JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, + empty, C.type, GPUSPARSE_ORDER_ROW))); size_t buffer_size; - HipConst alpha = HipOne(C.type); - HipConst beta = HipZero(C.type); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize( - handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, - mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size))); + SparseConst alpha = ConstOne(C.type); + SparseConst beta = ConstZero(C.type); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( + handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, + mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c))); return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})}; } @@ -366,26 +368,26 @@ std::pair BuildCooToDenseDescriptor( BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz, /*batch_count=*/1, /*batch_stride=*/0); - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnMatDescr_t mat_b = 0; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnMatDescr_t mat_b = 0; // bufferSize does not reference these pointers, but does error on NULL. 
int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, - d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo( + &mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type, + GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_b, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW))); + /*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW))); size_t buffer_size; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize( - handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSparseToDense_bufferSize( + handle.get(), mat_a, mat_b, GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } @@ -403,25 +405,25 @@ std::pair BuildCooFromDenseDescriptor( BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz, /*batch_count=*/1, /*batch_stride=*/0); - hipsparseDnMatDescr_t mat_a = 0; - hipsparseSpMatDescr_t mat_b = 0; + gpusparseDnMatDescr_t mat_a = 0; + gpusparseSpMatDescr_t mat_b = 0; // bufferSize does not reference these pointers, but does error on NULL. int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_a, d.rows, d.cols, - /*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, - d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); + /*ld=*/d.cols, empty, d.value_type, GPUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo( + &mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type, + GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); size_t buffer_size; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize( - handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT, + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_bufferSize( + handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b))); return {buffer_size, PackDescriptor(d)}; } @@ -444,32 +446,32 @@ std::pair BuildCooMatvecDescriptor( DenseVecDescriptor y = BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows); - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnVecDescr_t vec_x = 0; - hipsparseDnVecDescr_t vec_y = 0; - hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE - : HIPSPARSE_OPERATION_NON_TRANSPOSE; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnVecDescr_t vec_x = 0; + gpusparseDnVecDescr_t vec_y = 0; + gpusparseOperation_t op = transpose ? GPUSPARSE_OPERATION_TRANSPOSE + : GPUSPARSE_OPERATION_NON_TRANSPOSE; // bufferSize does not reference these pointers, but does error on NULL. 
int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, - A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo( + &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type, + GPUSPARSE_INDEX_BASE_ZERO, A.value_type))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type))); + JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, x.size, empty, x.type))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type))); + JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - HipConst alpha = HipOne(y.type); - HipConst beta = HipZero(y.type); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize( + SparseConst alpha = ConstOne(y.type); + SparseConst beta = ConstZero(y.type); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, - HIPSPARSE_MV_ALG_DEFAULT, &buffer_size))); + GPUSPARSE_MV_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y))); return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})}; } @@ -490,63 +492,61 @@ std::pair BuildCooMatmatDescriptor( auto h = SparseHandlePool::Borrow(); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; - SparseMatDescriptor A = - BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz, - batch_count, lhs_batch_stride); - DenseMatDescriptor B = - BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols, - batch_count, rhs_batch_stride); + + SparseMatDescriptor A = BuildSparseMatDescriptor( + data_dtype, index_dtype, rows, cols, nnz, batch_count, lhs_batch_stride); + DenseMatDescriptor B = BuildDenseMatDescriptor( + b_dtype, transpose ? rows : cols, BCcols, batch_count, rhs_batch_stride); int C_rows = (transpose == true) ? cols : rows; // TODO(tianjianlu): enable the selection of batch stride. - // The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643) + // The issue + // (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643) // in cusparse library does not allow batch_stride = 0. // int C_batch_stride = (batch_count > 1)? C_rows * BCcols : 0; int C_batch_stride = C_rows * BCcols; DenseMatDescriptor C = BuildDenseMatDescriptor(compute_dtype, /*rows=*/C_rows, /*cols=*/BCcols, batch_count, C_batch_stride); - hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE - : HIPSPARSE_OPERATION_NON_TRANSPOSE; + gpusparseOperation_t op_A = transpose ? GPUSPARSE_OPERATION_TRANSPOSE + : GPUSPARSE_OPERATION_NON_TRANSPOSE; - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnMatDescr_t mat_b = 0; - hipsparseDnMatDescr_t mat_c = 0; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnMatDescr_t mat_b = 0; + gpusparseDnMatDescr_t mat_c = 0; // bufferSize does not reference these pointers, but does error on NULL. 
int val = 0; void* empty = &val; - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, - A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - hipsparseCooSetStridedBatch( - mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo( + &mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type, + GPUSPARSE_INDEX_BASE_ZERO, A.value_type))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseCooSetStridedBatch( + mat_a, /*batchCount=*/batch_count, /*batchStride=*/A.batch_stride))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, - empty, B.type, HIPSPARSE_ORDER_ROW))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - hipsparseDnMatSetStridedBatch( - mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride))); + JAX_AS_STATUS(gpusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols, + empty, B.type, GPUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch( + mat_b, /*batchCount=*/batch_count, /*batchStride=*/B.batch_stride))); JAX_THROW_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, - empty, C.type, HIPSPARSE_ORDER_ROW))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS( - hipsparseDnMatSetStridedBatch( - mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride))); + JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, + empty, C.type, GPUSPARSE_ORDER_ROW))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch( + mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride))); size_t buffer_size; - HipConst alpha = HipOne(C.type); - HipConst beta = HipZero(C.type); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize( - handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, - mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size))); + SparseConst alpha = ConstOne(C.type); + SparseConst beta = ConstZero(C.type); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( + handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, + mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b))); - JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); + JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c))); return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})}; } +#endif // if JAX_GPU_HAVE_SPARSE py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) { return PackDescriptor(Gtsv2Descriptor{m, n, ldb}); @@ -565,32 +565,37 @@ size_t Gtsv2BufferSize(F f, int m, int n, int ldb) { } size_t Gtsv2BufferSizeF32(int m, int n, int ldb) { - return Gtsv2BufferSize(hipsparseSgtsv2_bufferSizeExt, m, n, ldb); + return Gtsv2BufferSize(gpusparseSgtsv2_bufferSizeExt, m, n, ldb); } size_t Gtsv2BufferSizeF64(int m, int n, int ldb) { - return Gtsv2BufferSize(hipsparseDgtsv2_bufferSizeExt, m, n, ldb); + return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb); } py::dict Registrations() { py::dict dict; - dict["hipsparse_csr_todense"] = EncapsulateFunction(CsrToDense); - dict["hipsparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense); 
- dict["hipsparse_csr_matvec"] = EncapsulateFunction(CsrMatvec); - dict["hipsparse_csr_matmat"] = EncapsulateFunction(CsrMatmat); - dict["hipsparse_coo_todense"] = EncapsulateFunction(CooToDense); - dict["hipsparse_coo_fromdense"] = EncapsulateFunction(CooFromDense); - dict["hipsparse_coo_matvec"] = EncapsulateFunction(CooMatvec); - dict["hipsparse_coo_matmat"] = EncapsulateFunction(CooMatmat); - dict["hipsparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32); - dict["hipsparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64); +#if JAX_GPU_HAVE_SPARSE + dict[JAX_GPU_PREFIX "sparse_csr_todense"] = EncapsulateFunction(CsrToDense); + dict[JAX_GPU_PREFIX "sparse_csr_fromdense"] = + EncapsulateFunction(CsrFromDense); + dict[JAX_GPU_PREFIX "sparse_csr_matvec"] = EncapsulateFunction(CsrMatvec); + dict[JAX_GPU_PREFIX "sparse_csr_matmat"] = EncapsulateFunction(CsrMatmat); + dict[JAX_GPU_PREFIX "sparse_coo_todense"] = EncapsulateFunction(CooToDense); + dict[JAX_GPU_PREFIX "sparse_coo_fromdense"] = + EncapsulateFunction(CooFromDense); + dict[JAX_GPU_PREFIX "sparse_coo_matvec"] = EncapsulateFunction(CooMatvec); + dict[JAX_GPU_PREFIX "sparse_coo_matmat"] = EncapsulateFunction(CooMatmat); +#endif + dict[JAX_GPU_PREFIX "sparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32); + dict[JAX_GPU_PREFIX "sparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64); // TODO(tomhennigan): Add support for gtsv2 complex 32/64. return dict; } -PYBIND11_MODULE(_hipsparse, m) { - m.attr("hipsparse_supported") = py::bool_(true); +PYBIND11_MODULE(_sparse, m) { + m.attr("sparse_supported") = py::bool_(JAX_GPU_HAVE_SPARSE); m.def("registrations", &Registrations); +#if JAX_GPU_HAVE_SPARSE m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor); m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor); m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor); @@ -599,10 +604,12 @@ PYBIND11_MODULE(_hipsparse, m) { m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor); m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor); m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor); +#endif m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32); m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64); m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor); } } // namespace +} // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/rocm/hipsparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc similarity index 57% rename from jaxlib/rocm/hipsparse_kernels.cc rename to jaxlib/gpu/sparse_kernels.cc index b9e061c9d..9f07a7c90 100644 --- a/jaxlib/rocm/hipsparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/rocm/hipsparse_kernels.h" +#include "jaxlib/gpu/sparse_kernels.h" #include #include @@ -24,62 +24,128 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/handle_pool.h" -#include "jaxlib/rocm/hip_gpu_kernel_helpers.h" #include "jaxlib/kernel_helpers.h" -#include "rocm/include/hip/hip_complex.h" -#include "rocm/include/hip/hip_runtime_api.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" namespace jax { template <> -/*static*/ absl::StatusOr -SparseHandlePool::Borrow(hipStream_t stream) { +/*static*/ absl::StatusOr SparseHandlePool::Borrow( + gpuStream_t stream) { SparseHandlePool* pool = Instance(); absl::MutexLock lock(&pool->mu_); - hipsparseHandle_t handle; + gpusparseHandle_t handle; if (pool->handles_[stream].empty()) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreate(&handle))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreate(&handle))); } else { handle = pool->handles_[stream].back(); pool->handles_[stream].pop_back(); } if (stream) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSetStream(handle, stream))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSetStream(handle, stream))); } return Handle(pool, handle, stream); } -HipConst HipZero(hipDataType type) { - HipConst c; +namespace JAX_GPU_NAMESPACE { + +SparseConst ConstZero(gpuDataType type) { + SparseConst c; std::memset(&c, 0, sizeof(c)); return c; } -HipConst HipOne(hipDataType type) { - HipConst c; +SparseConst ConstOne(gpuDataType type) { + SparseConst c; std::memset(&c, 0, sizeof(c)); - // TODO(rocm): add more data type if new rocm support switch (type) { +#ifdef JAX_GPU_CUDA +#if JAX_GPU_HAVE_SPARSE + // TODO(jakevdp): 4I/4U here might break on big endian platforms. + case CUDA_R_4I: + case CUDA_C_4I: +#endif + case CUDA_R_8I: + case CUDA_C_8I: + c.i8[0] = 1; + break; +#if JAX_GPU_HAVE_SPARSE + case CUDA_R_4U: + case CUDA_C_4U: +#endif + case CUDA_R_8U: + case CUDA_C_8U: + c.u8[0] = 1; + break; +#if JAX_GPU_HAVE_SPARSE + case CUDA_R_16I: + case CUDA_C_16I: + c.i16[0] = 1; + break; + case CUDA_R_16U: + case CUDA_C_16U: + c.u16[0] = 1; + break; +#endif + case CUDA_R_32I: + case CUDA_C_32I: + c.i32[0] = 1; + break; + case CUDA_R_32U: + case CUDA_C_32U: + c.u32[0] = 1; + break; +#if JAX_GPU_HAVE_SPARSE + case CUDA_R_64I: + case CUDA_C_64I: + c.i64[0] = 1; + break; + case CUDA_R_64U: + case CUDA_C_64U: + c.u64[0] = 1; + break; +#endif +#if JAX_CUDA_11080 + case CUDA_R_8F_E4M3: + c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E4M3); + break; + case CUDA_R_8F_E5M2: + c.u8[0] = __nv_cvt_float_to_fp8(1.0f, __NV_NOSAT, __NV_E5M2); + break; +#endif +#if JAX_GPU_HAVE_SPARSE + case CUDA_R_16BF: + case CUDA_C_16BF: + c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16 + break; +#endif +#endif // JAX_GPU_CUDA + // TODO(rocm): add more data types if new rocm supports them. + // TODO(jakevdp): 16F/16BF here might break on big endian platforms. 
- case HIP_R_16F: - case HIP_C_16F: + case GPU_R_16F: + case GPU_C_16F: c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16 break; - case HIP_R_32F: - case HIP_C_32F: + case GPU_R_32F: + case GPU_C_32F: c.f32[0] = 1.0; break; - case HIP_R_64F: - case HIP_C_64F: + case GPU_R_64F: + case GPU_C_64F: c.f64[0] = 1.0; break; } return c; } -static absl::Status CsrToDense_(hipStream_t stream, void** buffers, +#if JAX_GPU_HAVE_SPARSE +// CsrToDense: Convert CSR matrix to dense matrix + +static absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -88,28 +154,28 @@ static absl::Status CsrToDense_(hipStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnMatDescr_t mat_b = 0; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnMatDescr_t mat_b = 0; JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, + gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, /*csrRowOffsets=*/buffers[2], /*csrColInd=*/buffers[1], /*csrValues=*/buffers[0], d.index_type, d.index_type, - HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW))); + /*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipsparseSparseToDense(handle.get(), mat_a, mat_b, - HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); + gpusparseSparseToDense(handle.get(), mat_a, mat_b, + GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); return absl::OkStatus(); } -void CsrToDense(hipStream_t stream, void** buffers, const char* opaque, +void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = CsrToDense_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -120,7 +186,7 @@ void CsrToDense(hipStream_t stream, void** buffers, const char* opaque, // CsrFromDense: Convert dense matrix to CSR matrix -static absl::Status CsrFromDense_(hipStream_t stream, void** buffers, +static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -129,29 +195,29 @@ static absl::Status CsrFromDense_(hipStream_t stream, void** buffers, JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; - hipsparseDnMatDescr_t mat_a = 0; - hipsparseSpMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat( + gpusparseDnMatDescr_t mat_a = 0; + gpusparseSpMatDescr_t mat_b = 0; + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( &mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW))); + /*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, + gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, /*csrRowOffsets=*/buffers[3], /*csrColInd=*/buffers[2], 
/*csrValues=*/buffers[1], d.index_type, d.index_type, - HIPSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis( - handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT, + GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis( + handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert( - handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT, + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert( + handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b))); return absl::OkStatus(); } -void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque, +void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -162,7 +228,7 @@ void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque, // CsrMatvec: Product of CSR matrix and dense vector. -static absl::Status CsrMatvec_(hipStream_t stream, void** buffers, +static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -178,38 +244,37 @@ static absl::Status CsrMatvec_(hipStream_t stream, void** buffers, void* ybuf = buffers[4]; void* buf = buffers[5]; - // TODO(rocm): check the following statement for rocm // TODO(jakevdp): alpha and beta should be user-specifiable, but constants // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. 
- HipConst alpha = HipOne(d.y.type); - HipConst beta = HipZero(d.y.type); + SparseConst alpha = ConstOne(d.y.type); + SparseConst beta = ConstZero(d.y.type); - hipsparseSpMatDescr_t mat_a = 0; - hipsparseDnVecDescr_t vec_x = 0; - hipsparseDnVecDescr_t vec_y = 0; + gpusparseSpMatDescr_t mat_a = 0; + gpusparseDnVecDescr_t vec_x = 0; + gpusparseDnVecDescr_t vec_y = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr( + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr( &mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind, - csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, + csr_values, d.A.index_type, d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.A.value_type))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type))); + JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type))); JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type))); + JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, - d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf))); + gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y, + d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y))); return absl::OkStatus(); } -void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque, +void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { auto s = CsrMatvec_(stream, buffers, opaque, opaque_len); if (!s.ok()) { @@ -220,7 +285,7 @@ void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque, // CsrMatmat: Product of CSR matrix and dense matrix. -static absl::Status CsrMatmat_(hipStream_t stream, void** buffers, +static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); @@ -240,34 +305,34 @@ static absl::Status CsrMatmat_(hipStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. 
-  HipConst alpha = HipOne(d.C.type);
-  HipConst beta = HipZero(d.C.type);
+  SparseConst alpha = ConstOne(d.C.type);
+  SparseConst beta = ConstZero(d.C.type);

-  hipsparseSpMatDescr_t mat_a = 0;
-  hipsparseDnMatDescr_t mat_b = 0;
-  hipsparseDnMatDescr_t mat_c = 0;
+  gpusparseSpMatDescr_t mat_a = 0;
+  gpusparseDnMatDescr_t mat_b = 0;
+  gpusparseDnMatDescr_t mat_c = 0;

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCsr(
       &mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind,
-      csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO,
+      csr_values, d.A.index_type, d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO,
       d.A.value_type)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
       &mat_b, d.B.rows, d.B.cols,
-      /*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
+      /*ld=*/d.B.cols, Bbuf, d.B.type, GPUSPARSE_ORDER_ROW)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
       &mat_c, d.C.rows, d.C.cols,
-      /*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
-      handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
-      mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
+      /*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
+      handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
+      mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
   return absl::OkStatus();
 }

-void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
+void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status) {
   auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
   if (!s.ok()) {
@@ -278,7 +343,7 @@ void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,

 // CooToDense: Convert COO matrix to dense matrix

-static absl::Status CooToDense_(hipStream_t stream, void** buffers,
+static absl::Status CooToDense_(gpuStream_t stream, void** buffers,
                                 const char* opaque, size_t opaque_len) {
   auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
   JAX_RETURN_IF_ERROR(s.status());
@@ -287,28 +352,28 @@ static absl::Status CooToDense_(hipStream_t stream, void** buffers,
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;

-  hipsparseSpMatDescr_t mat_a = 0;
-  hipsparseDnMatDescr_t mat_b = 0;
+  gpusparseSpMatDescr_t mat_a = 0;
+  gpusparseDnMatDescr_t mat_b = 0;
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
+      gpusparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
                          /*cooRowInd=*/buffers[1],
                          /*cooColInd=*/buffers[2],
                          /*cooValues=*/buffers[0], d.index_type,
-                         HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
+                         GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
       &mat_b, d.rows, d.cols,
-      /*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
+      /*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      hipsparseSparseToDense(handle.get(), mat_a, mat_b,
-                             HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
+      gpusparseSparseToDense(handle.get(), mat_a, mat_b,
+                             GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
   return absl::OkStatus();
 }

-void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
+void CooToDense(gpuStream_t stream, void** buffers, const char* opaque,
                 size_t opaque_len, XlaCustomCallStatus* status) {
   auto s = CooToDense_(stream, buffers, opaque, opaque_len);
   if (!s.ok()) {
@@ -319,7 +384,7 @@ void CooToDense(hipStream_t stream, void** buffers, const char* opaque,

 // CooFromDense: Convert dense matrix to COO matrix

-static absl::Status CooFromDense_(hipStream_t stream, void** buffers,
+static absl::Status CooFromDense_(gpuStream_t stream, void** buffers,
                                   const char* opaque, size_t opaque_len) {
   auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
   JAX_RETURN_IF_ERROR(s.status());
@@ -328,29 +393,29 @@ static absl::Status CooFromDense_(hipStream_t stream, void** buffers,
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;

-  hipsparseDnMatDescr_t mat_a = 0;
-  hipsparseSpMatDescr_t mat_b = 0;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
+  gpusparseDnMatDescr_t mat_a = 0;
+  gpusparseSpMatDescr_t mat_b = 0;
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
       &mat_a, d.rows, d.cols,
-      /*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
+      /*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
+      gpusparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
                          /*cooRowInd=*/buffers[2],
                          /*cooColInd=*/buffers[3],
                          /*cooValues=*/buffers[1], d.index_type,
-                         HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
-      handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
+                         GPUSPARSE_INDEX_BASE_ZERO, d.value_type)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis(
+      handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
       buffers[4])));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
-      handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert(
+      handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
       buffers[4])));

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b)));
   return absl::OkStatus();
 }

-void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
+void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque,
                   size_t opaque_len, XlaCustomCallStatus* status) {
   auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
   if (!s.ok()) {
@@ -361,7 +426,7 @@ void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,

 // CooMatvec: Product of COO matrix and dense vector.

-static absl::Status CooMatvec_(hipStream_t stream, void** buffers,
+static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,
                                const char* opaque, size_t opaque_len) {
   auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
   JAX_RETURN_IF_ERROR(s.status());
@@ -377,37 +442,36 @@ static absl::Status CooMatvec_(hipStream_t stream, void** buffers,
   void* ybuf = buffers[4];
   void* buf = buffers[5];

-  // TODO(rocm): check the following statement for rocm
   // TODO(jakevdp): alpha and beta should be user-specifiable, but constants
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
-  HipConst alpha = HipOne(d.y.type);
-  HipConst beta = HipZero(d.y.type);
+  SparseConst alpha = ConstOne(d.y.type);
+  SparseConst beta = ConstZero(d.y.type);

-  hipsparseSpMatDescr_t mat_a = 0;
-  hipsparseDnVecDescr_t vec_x = 0;
-  hipsparseDnVecDescr_t vec_y = 0;
+  gpusparseSpMatDescr_t mat_a = 0;
+  gpusparseDnVecDescr_t vec_x = 0;
+  gpusparseDnVecDescr_t vec_y = 0;

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
       &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
-      d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
+      d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
   JAX_RETURN_IF_ERROR(
-      JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
+      JAX_AS_STATUS(gpusparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
   JAX_RETURN_IF_ERROR(
-      JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
+      JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
-                    d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf)));
+      gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
+                    d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_y)));
   return absl::OkStatus();
 }

-void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
+void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
   auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
   if (!s.ok()) {
@@ -418,7 +482,7 @@ void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,

 // CooMatmat: Product of COO matrix and dense matrix.

-static absl::Status CooMatmat_(hipStream_t stream, void** buffers,
+static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
                                const char* opaque, size_t opaque_len) {
   auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
   JAX_RETURN_IF_ERROR(s.status());
@@ -434,47 +498,46 @@ static absl::Status CooMatmat_(hipStream_t stream, void** buffers,
   void* Cbuf = buffers[4];
   void* buf = buffers[5];

-  // TODO(rocm): check the following statement for rocm
   // TODO(jakevdp): alpha and beta should be user-specifiable, but constants
   // are sufficient for basic matvec operations.
   // Note that, contrary to cusparse docs, alpha and beta must be host pointers
   // or else the operation will segfault.
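The host-pointer caveat in the comment above is load-bearing: the generic gpusparseSpMV/SpMM entry points read `alpha` and `beta` through host addresses under the default pointer mode, which is why these kernels stage the scalars in a small POD union on the stack rather than in a device buffer. A simplified, self-contained sketch of that idea follows; the real `SparseConst`/`ConstOne`/`ConstZero` (declared in `sparse_kernels.h` later in this patch) dispatch on `gpuDataType`, while this version keys on a bool purely for illustration.

```c++
// Illustrative sketch only -- not part of the patch. A host-side scalar
// constant wide enough for every element type the kernels support.
#include <cstdint>

union SparseConst {
  std::int32_t i32[2];
  std::int64_t i64[2];
  float f32[2];
  double f64[2];  // two slots so a complex (re, im) pair fits the same union
};

SparseConst ConstOne(bool is_double_precision) {
  SparseConst c{};  // zero-initialized; ConstZero would return this directly
  if (is_double_precision) {
    c.f64[0] = 1.0;
  } else {
    c.f32[0] = 1.0f;
  }
  return c;
}

// The kernels then pass &alpha / &beta -- host addresses -- straight into
// gpusparseSpMV/gpusparseSpMM, exactly as the diff continues below.
```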
-  HipConst alpha = HipOne(d.C.type);
-  HipConst beta = HipZero(d.C.type);
+  SparseConst alpha = ConstOne(d.C.type);
+  SparseConst beta = ConstZero(d.C.type);

-  hipsparseSpMatDescr_t mat_a = 0;
-  hipsparseDnMatDescr_t mat_b = 0;
-  hipsparseDnMatDescr_t mat_c = 0;
+  gpusparseSpMatDescr_t mat_a = 0;
+  gpusparseDnMatDescr_t mat_b = 0;
+  gpusparseDnMatDescr_t mat_c = 0;

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateCoo(
       &mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
-      d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
+      d.A.index_type, GPUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      hipsparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
-                                  /*batchStride=*/d.A.batch_stride)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
+      gpusparseCooSetStridedBatch(mat_a, /*batchCount=*/d.A.batch_count,
+                                  /*batchStride=*/d.A.batch_stride)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
       &mat_b, d.B.rows, d.B.cols,
-      /*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
+      /*ld=*/d.B.cols, Bbuf, d.B.type, GPUSPARSE_ORDER_ROW)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      hipsparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
-                                    /*batchStride=*/d.B.batch_stride)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
+      gpusparseDnMatSetStridedBatch(mat_b, /*batchCount=*/d.B.batch_count,
+                                    /*batchStride=*/d.B.batch_stride)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat(
       &mat_c, d.C.rows, d.C.cols,
-      /*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
+      /*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-      hipsparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
-                                    /*batchStride=*/d.C.batch_stride)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
-      handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
-      mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
+      gpusparseDnMatSetStridedBatch(mat_c, /*batchCount=*/d.C.batch_count,
+                                    /*batchStride=*/d.C.batch_stride)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
+      handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
+      mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));

-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
+  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_c)));
   return absl::OkStatus();
 }

-void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
+void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status) {
   auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
   if (!s.ok()) {
@@ -482,9 +545,10 @@ void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
                                   s.message().length());
   }
 }
+#endif  // if JAX_GPU_HAVE_SPARSE

 template <typename T, typename F>
-static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
+static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
                           const char* opaque, std::size_t opaque_len) {
   auto h = SparseHandlePool::Borrow();
   JAX_RETURN_IF_ERROR(h.status());
@@ -513,7 +577,7 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
   if (X != B) {
     size_t B_bytes = ldb * n * sizeof(T);
     JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-        hipMemcpyAsync(X, B, B_bytes, hipMemcpyDeviceToDevice, stream)));
+        gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream)));
   }

   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
@@ -521,22 +585,23 @@ static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
   return absl::OkStatus();
 }

-void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
+void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = gtsv2<float>(hipsparseSgtsv2, stream, buffers, opaque, opaque_len);
+  auto s = gtsv2<float>(gpusparseSgtsv2, stream, buffers, opaque, opaque_len);
   if (!s.ok()) {
     XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
                                   s.message().length());
   }
 }

-void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
+void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = gtsv2<double>(hipsparseDgtsv2, stream, buffers, opaque, opaque_len);
+  auto s = gtsv2<double>(gpusparseDgtsv2, stream, buffers, opaque, opaque_len);
   if (!s.ok()) {
     XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
                                   s.message().length());
   }
 }

+}  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
diff --git a/jaxlib/rocm/hipsparse_kernels.h b/jaxlib/gpu/sparse_kernels.h
similarity index 67%
rename from jaxlib/rocm/hipsparse_kernels.h
rename to jaxlib/gpu/sparse_kernels.h
index eef866a26..8880119b7 100644
--- a/jaxlib/rocm/hipsparse_kernels.h
+++ b/jaxlib/gpu/sparse_kernels.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef JAXLIB_HIPSPARSE_KERNELS_H_
-#define JAXLIB_HIPSPARSE_KERNELS_H_
+#ifndef JAXLIB_GPU_SPARSE_KERNELS_H_
+#define JAXLIB_GPU_SPARSE_KERNELS_H_

 #include
 #include
@@ -23,23 +23,21 @@ limitations under the License.

 #include
 #include "absl/status/statusor.h"
+#include "jaxlib/gpu/vendor.h"
 #include "jaxlib/handle_pool.h"
-#include "rocm/include/hip/hip_runtime_api.h"
-#include "rocm/include/hipsparse.h"
 #include "tensorflow/compiler/xla/service/custom_call_status.h"

-// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
-#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
-
 namespace jax {

-using SparseHandlePool = HandlePool<hipsparseHandle_t, hipStream_t>;
+using SparseHandlePool = HandlePool<gpusparseHandle_t, gpuStream_t>;

 template <>
-/*static*/ absl::StatusOr<SparseHandlePool::Handle>
-SparseHandlePool::Borrow(hipStream_t stream);
+/*static*/ absl::StatusOr<SparseHandlePool::Handle> SparseHandlePool::Borrow(
+    gpuStream_t stream);

-union HipConst {
+namespace JAX_GPU_NAMESPACE {
+
+union SparseConst {
   int8_t i8[2];
   int16_t i16[2];
   int32_t i32[2];
@@ -52,37 +50,38 @@ union HipConst {
   double f64[2];
 };

-HipConst HipZero(hipDataType type);
-HipConst HipOne(hipDataType type);
+SparseConst ConstZero(gpuDataType type);
+SparseConst ConstOne(gpuDataType type);

 struct SparseMatDescriptor {
-  hipDataType value_type;
-  hipsparseIndexType_t index_type;
+  gpuDataType value_type;
+  gpusparseIndexType_t index_type;
   int rows, cols, nnz;
   int batch_count = 1;
   int batch_stride = 0;
 };

 struct DenseMatDescriptor {
-  hipDataType type;
+  gpuDataType type;
   int rows, cols;
   int batch_count = 1;
   int batch_stride = 0;
 };

 struct DenseVecDescriptor {
-  hipDataType type;
+  gpuDataType type;
   int size;
 };

+#if JAX_GPU_HAVE_SPARSE
 // CsrToDense: Convert CSR matrix to dense matrix
-void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
+void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
                 size_t opaque_len, XlaCustomCallStatus* status);

 // CsrFromDense: Convert dense matrix to CSR matrix
-void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
+void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque,
                   size_t opaque_len, XlaCustomCallStatus* status);

 // CsrMatvec: Product of CSR matrix and dense vector.
@@ -90,10 +89,10 @@ void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
 struct CsrMatvecDescriptor {
   SparseMatDescriptor A;
   DenseVecDescriptor x, y;
-  hipsparseOperation_t op;
+  gpusparseOperation_t op;
 };

-void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
+void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status);

 // CsrMatmat: Product of CSR matrix and dense matrix.
@@ -101,20 +100,20 @@ void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
 struct CsrMatmatDescriptor {
   SparseMatDescriptor A;
   DenseMatDescriptor B, C;
-  hipsparseOperation_t op_A;
+  gpusparseOperation_t op_A;
 };

-void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
+void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status);

 // CooToDense: Convert COO matrix to dense matrix
-void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
+void CooToDense(gpuStream_t stream, void** buffers, const char* opaque,
                 size_t opaque_len, XlaCustomCallStatus* status);

 // CooFromDense: Convert dense matrix to COO matrix
-void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
+void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque,
                   size_t opaque_len, XlaCustomCallStatus* status);

 // CooMatvec: Product of COO matrix and dense vector.
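All of the entry points declared above share one calling convention: device buffers arrive through `buffers`, and shape/type metadata arrives through XLA's `opaque` byte string, which is just one of these trivially copyable descriptor structs serialized on the lowering side (the `CooMatvecDescriptor` that follows is one such struct). A standalone sketch of that round trip, with simplified field types; the real helpers live in `jaxlib/kernel_helpers.h` and return `absl::Status` on a length mismatch rather than asserting.

```c++
// Illustrative sketch only -- not part of the patch.
#include <cassert>
#include <cstddef>
#include <cstring>
#include <string>

struct SparseMatDescriptor {
  int value_type;  // stand-in for gpuDataType
  int index_type;  // stand-in for gpusparseIndexType_t
  int rows, cols, nnz;
  int batch_count = 1;
  int batch_stride = 0;
};

// Lowering side: serialize the descriptor and attach it to the custom call.
std::string PackDescriptorAsString(const SparseMatDescriptor& d) {
  return std::string(reinterpret_cast<const char*>(&d), sizeof(d));
}

// Kernel side: recover the descriptor from the opaque bytes.
SparseMatDescriptor UnpackDescriptor(const char* opaque,
                                     std::size_t opaque_len) {
  assert(opaque_len == sizeof(SparseMatDescriptor));
  SparseMatDescriptor d;
  std::memcpy(&d, opaque, sizeof(d));
  return d;
}
```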
@@ -122,10 +121,10 @@ void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
 struct CooMatvecDescriptor {
   SparseMatDescriptor A;
   DenseVecDescriptor x, y;
-  hipsparseOperation_t op;
+  gpusparseOperation_t op;
 };

-void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
+void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status);

 // CooMatmat: Product of COO matrix and dense matrix.
@@ -133,22 +132,24 @@ void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
 struct CooMatmatDescriptor {
   SparseMatDescriptor A;
   DenseMatDescriptor B, C;
-  hipsparseOperation_t op_A;
+  gpusparseOperation_t op_A;
 };

-void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
+void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
               size_t opaque_len, XlaCustomCallStatus* status);
+#endif  // JAX_GPU_HAVE_SPARSE

 struct Gtsv2Descriptor {
   int m, n, ldb;
 };

-void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
+void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len, XlaCustomCallStatus* status);

-void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
+void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len, XlaCustomCallStatus* status);

+}  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax

-#endif  // JAXLIB_HIPSPARSE_KERNELS_H_
+#endif  // JAXLIB_GPU_SPARSE_KERNELS_H_
diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h
new file mode 100644
index 000000000..87297cf01
--- /dev/null
+++ b/jaxlib/gpu/vendor.h
@@ -0,0 +1,442 @@
+/* Copyright 2022 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header is a shim that manages differences between CUDA and ROCM APIs.
+// Jaxlib GPU kernels can be compiled for either CUDA or ROCM by defining
+// JAX_GPU_CUDA or JAX_GPU_HIP respectively.
+
+#ifndef JAXLIB_GPU_VENDOR_H_
+#define JAXLIB_GPU_VENDOR_H_
+
+#if defined(JAX_GPU_CUDA)
+
+#include "third_party/gpus/cuda/include/cuComplex.h"
+#include "third_party/gpus/cuda/include/cublas_v2.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "third_party/gpus/cuda/include/cusolverDn.h"
+#include "third_party/gpus/cuda/include/cusparse.h"
+
+// Some sparse functionality is only available in CUSPARSE 11.3 or newer.
+#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
+
+// CUDA-11.8 introduces FP8 E4M3/E5M2 types.
+#define JAX_GPU_HAVE_FP8 (CUDA_VERSION >= 11080)
+
+#if JAX_GPU_HAVE_FP8
+#include "third_party/gpus/cuda/include/cuda_fp8.h"
+#endif
+
+// cuSPARSE generic APIs are not supported on Windows until 11.0.
+// cusparseIndexType_t is used in a very limited scope, so defining it manually
+// works around the compilation issue without harm.
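The manual definition is harmless because the enum that follows simply mirrors cuSPARSE's own values. More generally, everything in this header is consumed the same way: kernel sources are written once against the `gpu`-prefixed aliases, and the BUILD rules select the vendor with a single define (`JAX_GPU_CUDA=1` in `jaxlib/cuda`, `JAX_GPU_HIP=1` in `jaxlib/rocm`). A minimal sketch of a consumer; `CopyBuffer` is a hypothetical helper, and `JAX_AS_STATUS`/`JAX_RETURN_IF_ERROR` come from the shared `gpu_kernel_helpers.h` introduced by this patch.

```c++
// Illustrative sketch only -- one translation unit that builds for either
// vendor. Under JAX_GPU_CUDA the aliases expand to cudaMemcpyAsync etc.;
// under JAX_GPU_HIP they expand to hipMemcpyAsync etc.
#include <cstddef>

#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {  // expands to `cuda` or `hip`

// Hypothetical helper: enqueue a device-to-device copy on `stream`.
absl::Status CopyBuffer(gpuStream_t stream, void* dst, const void* src,
                        std::size_t bytes) {
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      gpuMemcpyAsync(dst, src, bytes, gpuMemcpyDeviceToDevice, stream)));
  return absl::OkStatus();
}

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax
```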
+#if defined(_WIN32) && (CUSPARSE_VERSION < 11000) +typedef enum { + CUSPARSE_INDEX_16U = 1, + CUSPARSE_INDEX_32I = 2, + CUSPARSE_INDEX_64I = 3 +} cusparseIndexType_t; +#endif + +#define JAX_GPU_NAMESPACE cuda +#define JAX_GPU_PREFIX "cu" + +typedef cuComplex gpuComplex; +typedef cuDoubleComplex gpuDoubleComplex; + +typedef cuComplex gpublasComplex; +typedef cuDoubleComplex gpublasDoubleComplex; +typedef cublasFillMode_t gpusolverFillMode_t; +typedef cublasStatus_t gpublasStatus_t; +typedef cublasHandle_t gpublasHandle_t; +typedef cudaDataType gpuDataType; +typedef cudaStream_t gpuStream_t; +typedef cudaError_t gpuError_t; +typedef cusolverDnHandle_t gpusolverDnHandle_t; +typedef cusolverStatus_t gpusolverStatus_t; +typedef cusolverEigMode_t gpusolverEigMode_t; +typedef syevjInfo gpuSyevjInfo; +typedef syevjInfo_t gpuSyevjInfo_t; +typedef cusparseIndexType_t gpusparseIndexType_t; +typedef cusparseHandle_t gpusparseHandle_t; +typedef cusparseOperation_t gpusparseOperation_t; +typedef cusparseStatus_t gpusparseStatus_t; +typedef cusparseSpMatDescr_t gpusparseSpMatDescr_t; +typedef cusparseDnMatDescr_t gpusparseDnMatDescr_t; +typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; + +#define GPU_C_16F CUDA_C_16F +#define GPU_R_16F CUDA_R_16F +#define GPU_C_32F CUDA_C_32F +#define GPU_R_32F CUDA_R_32F +#define GPU_C_64F CUDA_C_64F +#define GPU_R_64F CUDA_R_64F + +#define gpublasCreate cublasCreate +#define gpublasSetStream cublasSetStream +#define gpublasSgeqrfBatched cublasSgeqrfBatched +#define gpublasDgeqrfBatched cublasDgeqrfBatched +#define gpublasCgeqrfBatched cublasCgeqrfBatched +#define gpublasZgeqrfBatched cublasZgeqrfBatched +#define gpublasSgetrfBatched cublasSgetrfBatched +#define gpublasDgetrfBatched cublasDgetrfBatched +#define gpublasCgetrfBatched cublasCgetrfBatched +#define gpublasZgetrfBatched cublasZgetrfBatched + +#define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS + +#define gpusolverDnCreate cusolverDnCreate +#define gpusolverDnSetStream cusolverDnSetStream +#define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo +#define gpusolverDnDestroySyevjInfo cusolverDnDestroySyevjInfo +#define gpusolverDnSpotrf cusolverDnSpotrf +#define gpusolverDnDpotrf cusolverDnDpotrf +#define gpusolverDnCpotrf cusolverDnCpotrf +#define gpusolverDnZpotrf cusolverDnZpotrf +#define gpusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + cusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, info, batch) +#define gpusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + cusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, info, batch) +#define gpusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + cusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, info, batch) +#define gpusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + cusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, info, batch) +#define gpusolverDnSpotrf_bufferSize cusolverDnSpotrf_bufferSize +#define gpusolverDnDpotrf_bufferSize cusolverDnDpotrf_bufferSize +#define gpusolverDnCpotrf_bufferSize cusolverDnCpotrf_bufferSize +#define gpusolverDnZpotrf_bufferSize cusolverDnZpotrf_bufferSize +#define gpusolverDnSgeqrf cusolverDnSgeqrf +#define gpusolverDnDgeqrf cusolverDnDgeqrf +#define gpusolverDnCgeqrf cusolverDnCgeqrf +#define gpusolverDnZgeqrf cusolverDnZgeqrf +#define gpusolverDnSgeqrf_bufferSize cusolverDnSgeqrf_bufferSize +#define gpusolverDnDgeqrf_bufferSize cusolverDnDgeqrf_bufferSize +#define gpusolverDnCgeqrf_bufferSize cusolverDnCgeqrf_bufferSize 
+#define gpusolverDnZgeqrf_bufferSize cusolverDnZgeqrf_bufferSize +#define gpusolverDnSgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + cusolverDnSgetrf(h, m, n, a, lda, work, ipiv, info) +#define gpusolverDnDgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + cusolverDnDgetrf(h, m, n, a, lda, work, ipiv, info) +#define gpusolverDnCgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + cusolverDnCgetrf(h, m, n, a, lda, work, ipiv, info) +#define gpusolverDnZgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + cusolverDnZgetrf(h, m, n, a, lda, work, ipiv, info) +#define gpusolverDnSgetrf_bufferSize cusolverDnSgetrf_bufferSize +#define gpusolverDnDgetrf_bufferSize cusolverDnDgetrf_bufferSize +#define gpusolverDnCgetrf_bufferSize cusolverDnCgetrf_bufferSize +#define gpusolverDnZgetrf_bufferSize cusolverDnZgetrf_bufferSize +#define gpusolverDnSorgqr cusolverDnSorgqr +#define gpusolverDnDorgqr cusolverDnDorgqr +#define gpusolverDnCungqr cusolverDnCungqr +#define gpusolverDnZungqr cusolverDnZungqr +#define gpusolverDnSorgqr_bufferSize cusolverDnSorgqr_bufferSize +#define gpusolverDnDorgqr_bufferSize cusolverDnDorgqr_bufferSize +#define gpusolverDnCungqr_bufferSize cusolverDnCungqr_bufferSize +#define gpusolverDnZungqr_bufferSize cusolverDnZungqr_bufferSize +#define gpusolverDnSsyevd cusolverDnSsyevd +#define gpusolverDnDsyevd cusolverDnDsyevd +#define gpusolverDnCheevd cusolverDnCheevd +#define gpusolverDnZheevd cusolverDnZheevd +#define gpusolverDnSsyevd_bufferSize cusolverDnSsyevd_bufferSize +#define gpusolverDnDsyevd_bufferSize cusolverDnDsyevd_bufferSize +#define gpusolverDnCheevd_bufferSize cusolverDnCheevd_bufferSize +#define gpusolverDnZheevd_bufferSize cusolverDnZheevd_bufferSize +#define gpusolverDnSsyevj cusolverDnSsyevj +#define gpusolverDnDsyevj cusolverDnDsyevj +#define gpusolverDnCheevj cusolverDnCheevj +#define gpusolverDnZheevj cusolverDnZheevj +#define gpusolverDnSsyevj_bufferSize cusolverDnSsyevj_bufferSize +#define gpusolverDnDsyevj_bufferSize cusolverDnDsyevj_bufferSize +#define gpusolverDnCheevj_bufferSize cusolverDnCheevj_bufferSize +#define gpusolverDnZheevj_bufferSize cusolverDnZheevj_bufferSize +#define gpusolverDnSsyevjBatched cusolverDnSsyevjBatched +#define gpusolverDnDsyevjBatched cusolverDnDsyevjBatched +#define gpusolverDnCheevjBatched cusolverDnCheevjBatched +#define gpusolverDnZheevjBatched cusolverDnZheevjBatched +#define gpusolverDnSsyevjBatched_bufferSize cusolverDnSsyevjBatched_bufferSize +#define gpusolverDnDsyevjBatched_bufferSize cusolverDnDsyevjBatched_bufferSize +#define gpusolverDnCheevjBatched_bufferSize cusolverDnCheevjBatched_bufferSize +#define gpusolverDnZheevjBatched_bufferSize cusolverDnZheevjBatched_bufferSize +#define gpusolverDnSgesvd cusolverDnSgesvd +#define gpusolverDnDgesvd cusolverDnDgesvd +#define gpusolverDnCgesvd cusolverDnCgesvd +#define gpusolverDnZgesvd cusolverDnZgesvd +#define gpusolverDnSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + cusolverDnSgesvd_bufferSize(h, m, n, lwork) +#define gpusolverDnDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + cusolverDnDgesvd_bufferSize(h, m, n, lwork) +#define gpusolverDnCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + cusolverDnCgesvd_bufferSize(h, m, n, lwork) +#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + cusolverDnZgesvd_bufferSize(h, m, n, lwork) + +#define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER +#define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER +#define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR +#define GPUSOLVER_STATUS_SUCCESS 
CUSOLVER_STATUS_SUCCESS + +#define gpusparseCooSetStridedBatch cusparseCooSetStridedBatch +#define gpusparseCreate cusparseCreate +#define gpusparseCreateCoo cusparseCreateCoo +#define gpusparseCreateCsr cusparseCreateCsr +#define gpusparseCreateDnMat cusparseCreateDnMat +#define gpusparseCreateDnVec cusparseCreateDnVec +#define gpusparseDenseToSparse_analysis cusparseDenseToSparse_analysis +#define gpusparseDenseToSparse_bufferSize cusparseDenseToSparse_bufferSize +#define gpusparseDenseToSparse_convert cusparseDenseToSparse_convert +#define gpusparseDestroySpMat cusparseDestroySpMat +#define gpusparseDestroyDnMat cusparseDestroyDnMat +#define gpusparseDestroyDnVec cusparseDestroyDnVec +#define gpusparseDnMatSetStridedBatch cusparseDnMatSetStridedBatch +#define gpusparseSetStream cusparseSetStream +#define gpusparseSparseToDense cusparseSparseToDense +#define gpusparseSparseToDense_bufferSize cusparseSparseToDense_bufferSize +#define gpusparseSpMM cusparseSpMM +#define gpusparseSpMM_bufferSize cusparseSpMM_bufferSize +#define gpusparseSpMV cusparseSpMV +#define gpusparseSpMV_bufferSize cusparseSpMV_bufferSize +#define gpusparseSgtsv2 cusparseSgtsv2 +#define gpusparseDgtsv2 cusparseDgtsv2 +#define gpusparseSgtsv2_bufferSizeExt cusparseSgtsv2_bufferSizeExt +#define gpusparseDgtsv2_bufferSizeExt cusparseDgtsv2_bufferSizeExt + +#define GPUSPARSE_INDEX_16U CUSPARSE_INDEX_16U +#define GPUSPARSE_INDEX_32I CUSPARSE_INDEX_32I +#define GPUSPARSE_INDEX_64I CUSPARSE_INDEX_64I +#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT CUSPARSE_DENSETOSPARSE_ALG_DEFAULT +#define GPUSPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO +#define GPUSPARSE_MV_ALG_DEFAULT CUSPARSE_MV_ALG_DEFAULT +#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE +#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE +#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW +#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT +#define GPUSPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT +#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS + +#define gpuGetLastError cudaGetLastError +#define gpuGetErrorString cudaGetErrorString +#define gpuMemcpyAsync cudaMemcpyAsync +#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice +#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice +#define gpuStreamSynchronize cudaStreamSynchronize +#define gpuSuccess cudaSuccess + +#elif defined(JAX_GPU_HIP) + +#include "rocm/include/hip/hip_runtime_api.h" +#include "rocm/include/hipblas.h" +#include "rocm/include/hipsolver.h" +#include "rocm/include/hipsparse.h" + +#define JAX_GPU_NAMESPACE hip +#define JAX_GPU_PREFIX "hip" + +#define JAX_GPU_HAVE_SPARSE 1 +#define JAX_GPU_HAVE_FP8 0 + +typedef hipFloatComplex gpuComplex; +typedef hipDoubleComplex gpuDoubleComplex; + +typedef hipblasComplex gpublasComplex; +typedef hipblasDoubleComplex gpublasDoubleComplex; +typedef hipsolverHandle_t gpusolverDnHandle_t; +typedef hipblasFillMode_t gpublasFillMode_t; +typedef hipsolverFillMode_t gpusolverFillMode_t; +typedef hipblasHandle_t gpublasHandle_t; +typedef hipblasStatus_t gpublasStatus_t; +typedef hipDataType gpuDataType; +typedef hipStream_t gpuStream_t; +typedef hipError_t gpuError_t; +typedef void gpuSyevjInfo; +typedef hipsolverSyevjInfo_t gpuSyevjInfo_t; +typedef hipsolverEigMode_t gpusolverEigMode_t; +typedef hipsolverStatus_t gpusolverStatus_t; +typedef hipsparseIndexType_t gpusparseIndexType_t; +typedef hipsparseHandle_t gpusparseHandle_t; +typedef hipsparseOperation_t gpusparseOperation_t; +typedef 
hipsparseStatus_t gpusparseStatus_t; +typedef hipsparseSpMatDescr_t gpusparseSpMatDescr_t; +typedef hipsparseDnMatDescr_t gpusparseDnMatDescr_t; +typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; + +#define GPU_C_16F HIP_C_16F +#define GPU_R_16F HIP_R_16F +#define GPU_C_32F HIP_C_32F +#define GPU_R_32F HIP_R_32F +#define GPU_C_64F HIP_C_64F +#define GPU_R_64F HIP_R_64F + +#define gpublasCreate hipblasCreate +#define gpublasSetStream hipblasSetStream +#define gpublasSgeqrfBatched hipblasSgeqrfBatched +#define gpublasDgeqrfBatched hipblasDgeqrfBatched +#define gpublasCgeqrfBatched hipblasCgeqrfBatched +#define gpublasZgeqrfBatched hipblasZgeqrfBatched +#define gpublasSgetrfBatched hipblasSgetrfBatched +#define gpublasDgetrfBatched hipblasDgetrfBatched +#define gpublasCgetrfBatched hipblasCgetrfBatched +#define gpublasZgetrfBatched hipblasZgetrfBatched + +#define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS + +#define gpusolverDnCreate hipsolverCreate +#define gpusolverDnSetStream hipsolverSetStream +#define gpusolverDnCreateSyevjInfo hipsolverCreateSyevjInfo +#define gpusolverDnDestroySyevjInfo hipsolverDestroySyevjInfo +#define gpusolverDnSpotrf hipsolverSpotrf +#define gpusolverDnDpotrf hipsolverDpotrf +#define gpusolverDnCpotrf hipsolverCpotrf +#define gpusolverDnZpotrf hipsolverZpotrf +#define gpusolverDnSpotrf_bufferSize hipsolverSpotrf_bufferSize +#define gpusolverDnDpotrf_bufferSize hipsolverDpotrf_bufferSize +#define gpusolverDnCpotrf_bufferSize hipsolverCpotrf_bufferSize +#define gpusolverDnZpotrf_bufferSize hipsolverZpotrf_bufferSize +#define gpusolverDnSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + hipsolverSpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch) +#define gpusolverDnDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + hipsolverDpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch) +#define gpusolverDnCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + hipsolverCpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch) +#define gpusolverDnZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, \ + batch) \ + hipsolverZpotrfBatched(h, uplo, n, ptrs, lda, work, lwork, info, batch) +#define gpusolverDnSgeqrf hipsolverSgeqrf +#define gpusolverDnDgeqrf hipsolverDgeqrf +#define gpusolverDnCgeqrf hipsolverCgeqrf +#define gpusolverDnZgeqrf hipsolverZgeqrf +#define gpusolverDnSgeqrf_bufferSize hipsolverSgeqrf_bufferSize +#define gpusolverDnDgeqrf_bufferSize hipsolverDgeqrf_bufferSize +#define gpusolverDnCgeqrf_bufferSize hipsolverCgeqrf_bufferSize +#define gpusolverDnZgeqrf_bufferSize hipsolverZgeqrf_bufferSize +#define gpusolverDnSgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + hipsolverSgetrf(h, m, n, a, lda, work, lwork, ipiv, info) +#define gpusolverDnDgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + hipsolverDgetrf(h, m, n, a, lda, work, lwork, ipiv, info) +#define gpusolverDnCgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + hipsolverCgetrf(h, m, n, a, lda, work, lwork, ipiv, info) +#define gpusolverDnZgetrf(h, m, n, a, lda, work, lwork, ipiv, info) \ + hipsolverZgetrf(h, m, n, a, lda, work, lwork, ipiv, info) +#define gpusolverDnSgetrf_bufferSize hipsolverSgetrf_bufferSize +#define gpusolverDnDgetrf_bufferSize hipsolverDgetrf_bufferSize +#define gpusolverDnCgetrf_bufferSize hipsolverCgetrf_bufferSize +#define gpusolverDnZgetrf_bufferSize hipsolverZgetrf_bufferSize +#define gpusolverDnSorgqr hipsolverSorgqr +#define gpusolverDnDorgqr hipsolverDorgqr +#define gpusolverDnCungqr 
hipsolverCungqr +#define gpusolverDnZungqr hipsolverZungqr +#define gpusolverDnSorgqr_bufferSize hipsolverSorgqr_bufferSize +#define gpusolverDnDorgqr_bufferSize hipsolverDorgqr_bufferSize +#define gpusolverDnCungqr_bufferSize hipsolverCungqr_bufferSize +#define gpusolverDnZungqr_bufferSize hipsolverZungqr_bufferSize +#define gpusolverDnSsyevd hipsolverSsyevd +#define gpusolverDnDsyevd hipsolverDsyevd +#define gpusolverDnCheevd hipsolverCheevd +#define gpusolverDnZheevd hipsolverZheevd +#define gpusolverDnSsyevd_bufferSize hipsolverSsyevd_bufferSize +#define gpusolverDnDsyevd_bufferSize hipsolverDsyevd_bufferSize +#define gpusolverDnCheevd_bufferSize hipsolverCheevd_bufferSize +#define gpusolverDnZheevd_bufferSize hipsolverZheevd_bufferSize +#define gpusolverDnSsyevj hipsolverSsyevj +#define gpusolverDnDsyevj hipsolverDsyevj +#define gpusolverDnCheevj hipsolverCheevj +#define gpusolverDnZheevj hipsolverZheevj +#define gpusolverDnSsyevj_bufferSize hipsolverSsyevj_bufferSize +#define gpusolverDnDsyevj_bufferSize hipsolverDsyevj_bufferSize +#define gpusolverDnCheevj_bufferSize hipsolverCheevj_bufferSize +#define gpusolverDnZheevj_bufferSize hipsolverZheevj_bufferSize +#define gpusolverDnSsyevjBatched hipsolverSsyevjBatched +#define gpusolverDnDsyevjBatched hipsolverDsyevjBatched +#define gpusolverDnCheevjBatched hipsolverCheevjBatched +#define gpusolverDnZheevjBatched hipsolverZheevjBatched +#define gpusolverDnSsyevjBatched_bufferSize hipsolverSsyevjBatched_bufferSize +#define gpusolverDnDsyevjBatched_bufferSize hipsolverDsyevjBatched_bufferSize +#define gpusolverDnCheevjBatched_bufferSize hipsolverCheevjBatched_bufferSize +#define gpusolverDnZheevjBatched_bufferSize hipsolverZheevjBatched_bufferSize +#define gpusolverDnSgesvd hipsolverSgesvd +#define gpusolverDnDgesvd hipsolverDgesvd +#define gpusolverDnCgesvd hipsolverCgesvd +#define gpusolverDnZgesvd hipsolverZgesvd +#define gpusolverDnSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + hipsolverSgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) +#define gpusolverDnDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + hipsolverDgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) +#define gpusolverDnCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + hipsolverCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) +#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ + hipsolverZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) + +#define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER +#define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER +#define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR +#define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS + +#define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch +#define gpusparseCreate hipsparseCreate +#define gpusparseSetStream hipsparseSetStream +#define gpusparseCreateCoo hipsparseCreateCoo +#define gpusparseCreateCsr hipsparseCreateCsr +#define gpusparseCreateDnMat hipsparseCreateDnMat +#define gpusparseCreateDnVec hipsparseCreateDnVec +#define gpusparseDenseToSparse_analysis hipsparseDenseToSparse_analysis +#define gpusparseDenseToSparse_bufferSize hipsparseDenseToSparse_bufferSize +#define gpusparseDenseToSparse_convert hipsparseDenseToSparse_convert +#define gpusparseDestroySpMat hipsparseDestroySpMat +#define gpusparseDestroyDnMat hipsparseDestroyDnMat +#define gpusparseDestroyDnVec hipsparseDestroyDnVec +#define gpusparseDnMatSetStridedBatch hipsparseDnMatSetStridedBatch +#define gpusparseSparseToDense hipsparseSparseToDense +#define gpusparseSparseToDense_bufferSize 
hipsparseSparseToDense_bufferSize +#define gpusparseSpMM hipsparseSpMM +#define gpusparseSpMM_bufferSize hipsparseSpMM_bufferSize +#define gpusparseSpMV hipsparseSpMV +#define gpusparseSpMV_bufferSize hipsparseSpMV_bufferSize +#define gpusparseSgtsv2 hipsparseSgtsv2 +#define gpusparseDgtsv2 hipsparseDgtsv2 +#define gpusparseSgtsv2_bufferSizeExt hipsparseSgtsv2_bufferSizeExt +#define gpusparseDgtsv2_bufferSizeExt hipsparseDgtsv2_bufferSizeExt + +#define GPUSPARSE_INDEX_16U HIPSPARSE_INDEX_16U +#define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I +#define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I +#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT +#define GPUSPARSE_MV_ALG_DEFAULT HIPSPARSE_MV_ALG_DEFAULT +#define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO +#define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE +#define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE +#define GPUSPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW +#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT +#define GPUSPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT +#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS + +#define gpuGetLastError hipGetLastError +#define gpuGetErrorString hipGetErrorString +#define gpuMemcpyAsync hipMemcpyAsync +#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define gpuMemcpyHostToDevice hipMemcpyHostToDevice +#define gpuStreamSynchronize hipStreamSynchronize +#define gpuSuccess hipSuccess + +#else // defined(GPU vendor) +#error "Either JAX_GPU_CUDA or JAX_GPU_HIP must be defined" +#endif // defined(GPU vendor) + +#endif // JAXLIB_GPU_VENDOR_H_ diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index a9c432da7..1a0755c00 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -23,14 +23,14 @@ from .mhlo_helpers import custom_call from jaxlib import xla_client try: - from .cuda import _cuda_linalg + from .cuda import _linalg as _cuda_linalg for _name, _value in _cuda_linalg.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: _cuda_linalg = None try: - from .rocm import _hip_linalg + from .rocm import _linalg as _hip_linalg for _name, _value in _hip_linalg.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: @@ -65,7 +65,7 @@ def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_ operand_layouts=[pivots_layout], result_layouts=[permutations_layout]) -cuda_lu_pivots_to_permutation = partial( - _lu_pivots_to_permutation_mhlo, "cuda", _cuda_linalg) +cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_mhlo, "cu", + _cuda_linalg) hip_lu_pivots_to_permutation = partial( _lu_pivots_to_permutation_mhlo, "hip", _hip_linalg) diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index d7f5965de..7b52eadaf 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -25,14 +25,14 @@ from jaxlib import xla_client from .mhlo_helpers import custom_call try: - from .cuda import _cuda_prng + from .cuda import _prng as _cuda_prng for _name, _value in _cuda_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: _cuda_prng = None try: - from .rocm import _hip_prng + from .rocm import _prng as _hip_prng for _name, _value in _hip_prng.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: @@ -64,5 +64,6 @@ def 
_threefry2x32_lowering(prng, platform, keys, data): operand_layouts=[layout] * 4, result_layouts=[layout] * 2) -cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cuda") + +cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu") rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip") diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 048d575c0..d22375565 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -27,14 +27,14 @@ from jaxlib import xla_client from .mhlo_helpers import custom_call try: - from .cuda import _cublas + from .cuda import _blas as _cublas for _name, _value in _cublas.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: _cublas = None try: - from .cuda import _cusolver + from .cuda import _solver as _cusolver for _name, _value in _cusolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: @@ -42,14 +42,14 @@ except ImportError: try: - from .rocm import _hipblas + from .rocm import _blas as _hipblas for _name, _value in _hipblas.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: _hipblas = None try: - from .rocm import _hipsolver + from .rocm import _solver as _hipsolver for _name, _value in _hipsolver.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") except ImportError: diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index 99d459d6a..dd374bb8a 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -26,7 +26,7 @@ from jaxlib import xla_client from .mhlo_helpers import custom_call try: - from .cuda import _cusparse + from .cuda import _sparse as _cusparse except ImportError: _cusparse = None else: @@ -34,7 +34,7 @@ else: xla_client.register_custom_call_target(_name, _value, platform="CUDA") try: - from .rocm import _hipsparse + from .rocm import _sparse as _hipsparse except ImportError: _hipsparse = None else: @@ -42,8 +42,8 @@ else: xla_client.register_custom_call_target(_name, _value, platform="ROCM") -cuda_is_supported : bool = _cusparse and _cusparse.cusparse_supported -rocm_is_supported : bool = _hipsparse and _hipsparse.hipsparse_supported +cuda_is_supported : bool = _cusparse and _cusparse.sparse_supported +rocm_is_supported : bool = _hipsparse and _hipsparse.sparse_supported def _validate_csr_mhlo(data, indices, indptr, shape): diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index e4344cfa6..50eb4f258 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -25,16 +25,27 @@ licenses(["notice"]) package(default_visibility = ["//:__subpackages__"]) +cc_library( + name = "hip_vendor", + hdrs = [ + "//jaxlib/gpu:vendor.h", + ], + defines = ["JAX_GPU_HIP=1"], + deps = [ + "@local_config_rocm//rocm:rocm_headers", + ], +) cc_library( name = "hip_gpu_kernel_helpers", - srcs = if_rocm_is_configured(["hip_gpu_kernel_helpers.cc"]), - hdrs = if_rocm_is_configured(["hip_gpu_kernel_helpers.h"]), + srcs = if_rocm_is_configured(["//jaxlib/gpu:gpu_kernel_helpers.cc"]), + hdrs = if_rocm_is_configured(["//jaxlib/gpu:gpu_kernel_helpers.h"]), copts = [ "-fexceptions", ], features = ["-use_header_modules"], deps = [ + ":hip_vendor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -46,11 +57,12 @@ cc_library( cc_library( name = "hipblas_kernels", - srcs = ["hipblas_kernels.cc"], - 
hdrs = ["hipblas_kernels.h"], + srcs = ["//jaxlib/gpu:blas_kernels.cc"], + hdrs = ["//jaxlib/gpu:blas_kernels.h"], deps = [ - "//jaxlib:handle_pool", ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -68,15 +80,16 @@ cc_library( ) pybind_extension( - name = "_hipblas", - srcs = ["hipblas.cc"], + name = "_blas", + srcs = ["//jaxlib/gpu:blas.cc"], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], - module_name = "_hipblas", + module_name = "_blas", deps = [ + ":hip_vendor", ":hipblas_kernels", "//jaxlib:kernel_pybind11_helpers", "@com_google_absl//absl/container:flat_hash_map", @@ -89,11 +102,12 @@ pybind_extension( cc_library( name = "hipsolver_kernels", - srcs = ["hipsolver_kernels.cc"], - hdrs = ["hipsolver_kernels.h"], + srcs = ["//jaxlib/gpu:solver_kernels.cc"], + hdrs = ["//jaxlib/gpu:solver_kernels.h"], deps = [ - "//jaxlib:handle_pool", ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -105,16 +119,17 @@ cc_library( ) pybind_extension( - name = "_hipsolver", - srcs = ["hipsolver.cc"], + name = "_solver", + srcs = ["//jaxlib/gpu:solver.cc"], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], - module_name = "_hipsolver", + module_name = "_solver", deps = [ ":hip_gpu_kernel_helpers", + ":hip_vendor", ":hipsolver_kernels", "//jaxlib:kernel_pybind11_helpers", "@com_google_absl//absl/container:flat_hash_map", @@ -127,11 +142,12 @@ pybind_extension( cc_library( name = "hipsparse_kernels", - srcs = ["hipsparse_kernels.cc"], - hdrs = ["hipsparse_kernels.h"], + srcs = ["//jaxlib/gpu:sparse_kernels.cc"], + hdrs = ["//jaxlib/gpu:sparse_kernels.h"], deps = [ - "//jaxlib:handle_pool", ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -143,16 +159,17 @@ cc_library( ) pybind_extension( - name = "_hipsparse", - srcs = ["hipsparse.cc"], + name = "_sparse", + srcs = ["//jaxlib/gpu:sparse.cc"], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], - module_name = "_hipsparse", + module_name = "_sparse", deps = [ ":hip_gpu_kernel_helpers", + ":hip_vendor", ":hipsparse_kernels", "//jaxlib:kernel_pybind11_helpers", "@com_google_absl//absl/algorithm:container", @@ -172,13 +189,12 @@ pybind_extension( cc_library( name = "hip_lu_pivot_kernels", - srcs = [ - "hip_lu_pivot_kernels.cc", - ], - hdrs = ["hip_lu_pivot_kernels.h"], + srcs = ["//jaxlib/gpu:lu_pivot_kernels.cc"], + hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"], deps = [ ":hip_gpu_kernel_helpers", ":hip_lu_pivot_kernels_impl", + ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", @@ -187,12 +203,11 @@ cc_library( rocm_library( name = "hip_lu_pivot_kernels_impl", - srcs = [ - "hip_lu_pivot_kernels.hip.cc", - ], - hdrs = ["hip_lu_pivot_kernels.h"], + srcs = ["//jaxlib/gpu:lu_pivot_kernels.cu.cc"], + hdrs = ["//jaxlib/gpu:lu_pivot_kernels.h"], deps = [ ":hip_gpu_kernel_helpers", + ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", @@ -200,18 +215,19 @@ rocm_library( ) 
pybind_extension( - name = "_hip_linalg", - srcs = ["hip_linalg.cc"], + name = "_linalg", + srcs = ["//jaxlib/gpu:linalg.cc"], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], - module_name = "_hip_linalg", + module_name = "_linalg", deps = [ ":hip_gpu_kernel_helpers", ":hip_lu_pivot_kernels", ":hip_lu_pivot_kernels_impl", + ":hip_vendor", "//jaxlib:kernel_pybind11_helpers", "@local_config_rocm//rocm:rocm_headers", "@pybind11", @@ -220,13 +236,12 @@ pybind_extension( cc_library( name = "hip_prng_kernels", - srcs = [ - "hip_prng_kernels.cc", - ], - hdrs = ["hip_prng_kernels.h"], + srcs = ["//jaxlib/gpu:prng_kernels.cc"], + hdrs = ["//jaxlib/gpu:prng_kernels.h"], deps = [ ":hip_gpu_kernel_helpers", ":hip_prng_kernels_impl", + ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", @@ -235,12 +250,11 @@ cc_library( rocm_library( name = "hip_prng_kernels_impl", - srcs = [ - "hip_prng_kernels.hip.cc", - ], - hdrs = ["hip_prng_kernels.h"], + srcs = ["//jaxlib/gpu:prng_kernels.cu.cc"], + hdrs = ["//jaxlib/gpu:prng_kernels.h"], deps = [ ":hip_gpu_kernel_helpers", + ":hip_vendor", "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", @@ -248,17 +262,18 @@ rocm_library( ) pybind_extension( - name = "_hip_prng", - srcs = ["hip_prng.cc"], + name = "_prng", + srcs = ["//jaxlib/gpu:prng.cc"], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], - module_name = "_hip_prng", + module_name = "_prng", deps = [ ":hip_gpu_kernel_helpers", ":hip_prng_kernels", + ":hip_vendor", "//jaxlib:kernel_pybind11_helpers", "@local_config_rocm//rocm:rocm_headers", "@pybind11", @@ -268,11 +283,10 @@ pybind_extension( py_library( name = "rocm_gpu_support", deps = [ - ":_hip_linalg", - ":_hip_prng", - ":_hipblas", - ":_hipsolver", - ":_hipsparse", + ":_blas", + ":_linalg", + ":_prng", + ":_solver", + ":_sparse", ], ) - diff --git a/jaxlib/rocm/hip_gpu_kernel_helpers.h b/jaxlib/rocm/hip_gpu_kernel_helpers.h deleted file mode 100644 index 42bd4b11b..000000000 --- a/jaxlib/rocm/hip_gpu_kernel_helpers.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef JAXLIB_HIP_GPU_KERNEL_HELPERS_H_ -#define JAXLIB_HIP_GPU_KERNEL_HELPERS_H_ - -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "rocm/include/hip/hip_runtime_api.h" -#include "rocm/include/hipblas.h" -#include "rocm/include/hipsolver.h" -#include "rocm/include/hipsparse.h" - -#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr) - -#define JAX_THROW_IF_ERROR(expr) \ - { \ - auto s___ = (expr); \ - if (!s___.ok()) \ - throw std::runtime_error(std::string(s___.message())); \ - } - -#define JAX_RETURN_IF_ERROR(expr) \ - { \ - auto s___ = (expr); \ - if (!s___.ok()) \ - return s___; \ - } - -namespace jax { - -// Used via JAX_AS_STATUS(expr) macro. -absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line, - const char* expr); -absl::Status AsStatus(hipsolverStatus_t status, const char* file, - std::int64_t line, const char* expr); -absl::Status AsStatus(hipsparseStatus_t status, const char* file, - std::int64_t line, const char* expr); -absl::Status AsStatus(hipblasStatus_t status, const char* file, - std::int64_t line, const char* expr); - -// Builds an array of pointers to each array in a batch, in device memory. -// Caution: the return value must be kept alive (e.g., via a stream -// synchronization) until the copy enqueued by MakeBatchPointers on `stream` -// completes. -absl::StatusOr> -MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch, - int batch_elem_size); - -} // namespace jax - -#endif // JAXLIB_HIP_GPU_KERNEL_HELPERS_H_ \ No newline at end of file diff --git a/jaxlib/rocm/hip_linalg.cc b/jaxlib/rocm/hip_linalg.cc deleted file mode 100644 index b891f256b..000000000 --- a/jaxlib/rocm/hip_linalg.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "include/pybind11/pybind11.h" -#include "jaxlib/rocm/hip_gpu_kernel_helpers.h" -#include "jaxlib/rocm/hip_lu_pivot_kernels.h" -#include "jaxlib/kernel_pybind11_helpers.h" - -namespace jax { -namespace { - -std::string -BuildHipLuPivotsToPermutationDescriptor(std::int64_t batch_size, - std::int32_t pivot_size, - std::int32_t permutation_size) { - return PackDescriptorAsString(LuPivotsToPermutationDescriptor{ - batch_size, pivot_size, permutation_size}); -} - -pybind11::dict Registrations() { - pybind11::dict dict; - dict["hip_lu_pivots_to_permutation"] = - EncapsulateFunction(HipLuPivotsToPermutation); - return dict; -} - -PYBIND11_MODULE(_hip_linalg, m) { - m.def("registrations", &Registrations); - m.def("lu_pivots_to_permutation_descriptor", - [](std::int64_t batch_size, std::int32_t pivot_size, - std::int32_t permutation_size) { - std::string result = BuildHipLuPivotsToPermutationDescriptor( - batch_size, pivot_size, permutation_size); - return pybind11::bytes(result); - }); -} - -} // namespace -} // namespace jax diff --git a/jaxlib/rocm/hip_prng.cc b/jaxlib/rocm/hip_prng.cc deleted file mode 100644 index ba7784c52..000000000 --- a/jaxlib/rocm/hip_prng.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/rocm/hip_prng_kernels.h" - -#include "jaxlib/rocm/hip_gpu_kernel_helpers.h" -#include "jaxlib/kernel_pybind11_helpers.h" -#include "include/pybind11/pybind11.h" - -namespace jax { -namespace { - -std::string BuildHipThreeFry2x32Descriptor(std::int64_t n) { - return PackDescriptorAsString(ThreeFry2x32Descriptor{n}); -} -pybind11::dict Registrations() { - pybind11::dict dict; - dict["hip_threefry2x32"] = EncapsulateFunction(HipThreeFry2x32); - return dict; -} - -PYBIND11_MODULE(_hip_prng, m) { - m.def("registrations", &Registrations); - m.def("threefry2x32_descriptor", [](std::int64_t n) { - std::string result = BuildHipThreeFry2x32Descriptor(n); - return pybind11::bytes(result); - }); -} - -} // namespace -} // namespace jax diff --git a/jaxlib/rocm/hip_prng_kernels.cc b/jaxlib/rocm/hip_prng_kernels.cc deleted file mode 100644 index a347c0276..000000000 --- a/jaxlib/rocm/hip_prng_kernels.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "jaxlib/rocm/hip_prng_kernels.h" - -#include - -#include "jaxlib/rocm/hip_gpu_kernel_helpers.h" -#include "jaxlib/kernel_helpers.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" - -namespace jax { -namespace { - -absl::Status HipThreeFry2x32_(hipStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - LaunchThreeFry2x32Kernel(stream, buffers, **s); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError())); - return absl::OkStatus(); -} - -} // namespace - -void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = HipThreeFry2x32_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - std::string_view message = s.message(); - XlaCustomCallStatusSetFailure(status, message.data(), message.length()); - } -} - -} // namespace jax diff --git a/jaxlib/rocm/hip_prng_kernels.h b/jaxlib/rocm/hip_prng_kernels.h deleted file mode 100644 index 4bc17253a..000000000 --- a/jaxlib/rocm/hip_prng_kernels.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_HIP_PRNG_KERNELS_H_ -#define JAXLIB_HIP_PRNG_KERNELS_H_ - -#include -#include - -#include "rocm/include/hip/hip_runtime_api.h" -#include "tensorflow/compiler/xla/service/custom_call_status.h" - -namespace jax { - -struct ThreeFry2x32Descriptor { - std::int64_t n; -}; - -void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers, - ThreeFry2x32Descriptor descriptor); - -void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -} // namespace jax - -#endif // JAXLIB_HIP_PRNG_KERNELS_H_ \ No newline at end of file diff --git a/jaxlib/rocm/hip_prng_kernels.hip.cc b/jaxlib/rocm/hip_prng_kernels.hip.cc deleted file mode 100644 index 080195e3d..000000000 --- a/jaxlib/rocm/hip_prng_kernels.hip.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2021 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
diff --git a/jaxlib/rocm/hip_prng_kernels.hip.cc b/jaxlib/rocm/hip_prng_kernels.hip.cc
deleted file mode 100644
index 080195e3d..000000000
--- a/jaxlib/rocm/hip_prng_kernels.hip.cc
+++ /dev/null
@@ -1,116 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/rocm/hip_prng_kernels.h"
-
-#include <array>
-#include <cstdint>
-
-namespace jax {
-namespace {
-
-__global__ void
-ThreeFry2x32Kernel(const std::uint32_t* key0, const std::uint32_t* key1,
-                   const std::uint32_t* data0, const std::uint32_t* data1,
-                   std::uint32_t* out0, std::uint32_t* out1, std::int64_t n) {
-  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
-       idx += blockDim.x * gridDim.x) {
-    // Rotation distances specified by the Threefry2x32 algorithm.
-    std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
-    std::uint32_t x[2];
-    std::uint32_t ks[3];
-
-    // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
-    ks[2] = 0x1BD11BDA;
-
-    ks[0] = key0[idx];
-    x[0] = data0[idx];
-    ks[2] = ks[2] ^ key0[idx];
-
-    ks[1] = key1[idx];
-    x[1] = data1[idx];
-    ks[2] = ks[2] ^ key1[idx];
-
-    auto rotate_left = [](std::uint32_t v, std::uint32_t distance) {
-      return (v << distance) | (v >> (32 - distance));
-    };
-
-    // Performs a single round of the Threefry2x32 algorithm, with a rotation
-    // amount 'rotation'.
-    auto round = [&](std::uint32_t* v, std::uint32_t rotation) {
-      v[0] += v[1];
-      v[1] = rotate_left(v[1], rotation);
-      v[1] ^= v[0];
-    };
-
-    // There are no known statistical flaws with 13 rounds of Threefry2x32.
-    // We are conservative and use 20 rounds.
-    x[0] = x[0] + ks[0];
-    x[1] = x[1] + ks[1];
-    for (int i = 0; i < 4; ++i) {
-      round(x, rotations[i]);
-    }
-
-    x[0] = x[0] + ks[1];
-    x[1] = x[1] + ks[2] + 1u;
-    for (int i = 4; i < 8; ++i) {
-      round(x, rotations[i]);
-    }
-
-    x[0] = x[0] + ks[2];
-    x[1] = x[1] + ks[0] + 2u;
-    for (int i = 0; i < 4; ++i) {
-      round(x, rotations[i]);
-    }
-
-    x[0] = x[0] + ks[0];
-    x[1] = x[1] + ks[1] + 3u;
-    for (int i = 4; i < 8; ++i) {
-      round(x, rotations[i]);
-    }
-
-    x[0] = x[0] + ks[1];
-    x[1] = x[1] + ks[2] + 4u;
-    for (int i = 0; i < 4; ++i) {
-      round(x, rotations[i]);
-    }
-
-    out0[idx] = x[0] + ks[2];
-    out1[idx] = x[1] + ks[0] + 5u;
-  }
-}
-
-}  // namespace
-
-void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers,
-                              ThreeFry2x32Descriptor descriptor) {
-  std::array<const std::uint32_t*, 2> keys;
-  keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
-  keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
-  std::array<const std::uint32_t*, 2> data;
-  data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
-  data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
-  std::array<std::uint32_t*, 2> out;
-  out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
-  out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
-  const int block_dim = 128;
-  const std::int64_t grid_dim =
-      std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
-  ThreeFry2x32Kernel<<<grid_dim, block_dim, 0, stream>>>(
-      keys[0], keys[1], data[0], data[1], out[0], out[1], descriptor.n);
-}
-
-}  // namespace jax
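The deleted kernel is a grid-stride loop around the 20-round Threefry-2x32 block cipher. For reference, and for unit-testing the device output, here is a host-side transcription of the same rounds; it mirrors the kernel above one-to-one and has no GPU dependencies:

#include <cstdint>
#include <utility>

std::pair<std::uint32_t, std::uint32_t> ThreeFry2x32Host(
    std::uint32_t key0, std::uint32_t key1, std::uint32_t data0,
    std::uint32_t data1) {
  // Rotation distances specified by the Threefry2x32 algorithm.
  const std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
  // 0x1BD11BDA is the key-schedule parity constant.
  std::uint32_t ks[3] = {key0, key1, 0x1BD11BDA ^ key0 ^ key1};
  std::uint32_t x[2] = {data0 + ks[0], data1 + ks[1]};
  auto rotate_left = [](std::uint32_t v, std::uint32_t distance) {
    return (v << distance) | (v >> (32 - distance));
  };
  auto round = [&](std::uint32_t rotation) {
    x[0] += x[1];
    x[1] = rotate_left(x[1], rotation);
    x[1] ^= x[0];
  };
  // Five groups of four rounds, with a key injection after each group.
  for (int i = 0; i < 4; ++i) round(rotations[i]);
  x[0] += ks[1]; x[1] += ks[2] + 1u;
  for (int i = 4; i < 8; ++i) round(rotations[i]);
  x[0] += ks[2]; x[1] += ks[0] + 2u;
  for (int i = 0; i < 4; ++i) round(rotations[i]);
  x[0] += ks[0]; x[1] += ks[1] + 3u;
  for (int i = 4; i < 8; ++i) round(rotations[i]);
  x[0] += ks[1]; x[1] += ks[2] + 4u;
  for (int i = 0; i < 4; ++i) round(rotations[i]);
  return {x[0] + ks[2], x[1] + ks[0] + 5u};
}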
diff --git a/jaxlib/rocm/hipblas_kernels.h b/jaxlib/rocm/hipblas_kernels.h
deleted file mode 100644
index f377665dc..000000000
--- a/jaxlib/rocm/hipblas_kernels.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef JAXLIB_HIPBLAS_KERNELS_H_
-#define JAXLIB_HIPBLAS_KERNELS_H_
-
-#include <cstddef>
-
-#include "rocm/include/hip/hip_runtime_api.h"
-#include "rocm/include/hipblas.h"
-#include "tensorflow/compiler/xla/service/custom_call_status.h"
-
-namespace jax {
-
-// Set of types known to Hipblas.
-enum class HipblasType {
-  F32,
-  F64,
-  C64,
-  C128,
-};
-
-// Batched LU decomposition: getrfbatched
-
-struct GetrfBatchedDescriptor {
-  HipblasType type;
-  int batch, n;
-};
-
-void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
-                  size_t opaque_len, XlaCustomCallStatus* status);
-
-// Batched QR decomposition: geqrfbatched
-
-struct GeqrfBatchedDescriptor {
-  HipblasType type;
-  int batch, m, n;
-};
-
-void GeqrfBatched(hipStream_t stream, void** buffers, const char* opaque,
-                  size_t opaque_len, XlaCustomCallStatus* status);
-
-}  // namespace jax
-
-#endif  // JAXLIB_HIPBLAS_KERNELS_H_
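The hipsolver bindings that follow begin by classifying the incoming NumPy dtype into one of four enum values. A dependency-free sketch of that lookup, with std::map standing in for absl::flat_hash_map and a plain (kind, itemsize) pair standing in for py::dtype (names here are illustrative):

#include <map>
#include <stdexcept>
#include <utility>

enum class GpuDtype { F32, F64, C64, C128 };

GpuDtype DtypeToGpuDtype(char kind, int itemsize) {
  // 'f' = float, 'c' = complex, keyed by NumPy's kind/itemsize convention.
  static const std::map<std::pair<char, int>, GpuDtype> kTypes = {
      {{'f', 4}, GpuDtype::F32},
      {{'f', 8}, GpuDtype::F64},
      {{'c', 8}, GpuDtype::C64},
      {{'c', 16}, GpuDtype::C128},
  };
  auto it = kTypes.find({kind, itemsize});
  if (it == kTypes.end()) throw std::invalid_argument("Unsupported dtype");
  return it->second;
}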
diff --git a/jaxlib/rocm/hipsolver.cc b/jaxlib/rocm/hipsolver.cc
deleted file mode 100644
index c07ae754a..000000000
--- a/jaxlib/rocm/hipsolver.cc
+++ /dev/null
@@ -1,435 +0,0 @@
-/* Copyright 2019 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <cstdint>
-#include <memory>
-#include <stdexcept>
-#include <utility>
-#include <vector>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/str_format.h"
-#include "include/pybind11/numpy.h"
-#include "include/pybind11/pybind11.h"
-#include "include/pybind11/stl.h"
-#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
-#include "jaxlib/rocm/hipsolver_kernels.h"
-#include "jaxlib/kernel_pybind11_helpers.h"
-#include "rocm/include/hip/hip_runtime_api.h"
-#include "rocm/include/hipsolver.h"
-
-namespace jax {
-namespace {
-namespace py = pybind11;
-
-// Converts a NumPy dtype to a Type.
-HipsolverType DtypeToHipsolverType(const py::dtype& np_type) {
-  static auto* types =
-      new absl::flat_hash_map<std::pair<char, int>, HipsolverType>({
-          {{'f', 4}, HipsolverType::F32},
-          {{'f', 8}, HipsolverType::F64},
-          {{'c', 8}, HipsolverType::C64},
-          {{'c', 16}, HipsolverType::C128},
-      });
-  auto it = types->find({np_type.kind(), np_type.itemsize()});
-  if (it == types->end()) {
-    throw std::invalid_argument(
-        absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
-  }
-  return it->second;
-}
-
-// potrf: Cholesky decomposition
-
-// Returns the workspace size and a descriptor for a potrf operation.
-std::pair<std::int64_t, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
-                                                        bool lower, int b,
-                                                        int n) {
-  HipsolverType type = DtypeToHipsolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  int lwork;
-  std::int64_t workspace_size;
-  hipsolverFillMode_t uplo =
-      lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
-  if (b == 1) {
-    switch (type) {
-      case HipsolverType::F32:
-        JAX_THROW_IF_ERROR(
-            JAX_AS_STATUS(hipsolverSpotrf_bufferSize(handle.get(), uplo, n,
-                                                     /*A=*/nullptr,
-                                                     /*lda=*/n, &lwork)));
-        workspace_size = lwork * sizeof(float);
-        break;
-      case HipsolverType::F64:
-        JAX_THROW_IF_ERROR(
-            JAX_AS_STATUS(hipsolverDpotrf_bufferSize(handle.get(), uplo, n,
-                                                     /*A=*/nullptr,
-                                                     /*lda=*/n, &lwork)));
-        workspace_size = lwork * sizeof(double);
-        break;
-      case HipsolverType::C64:
-        JAX_THROW_IF_ERROR(
-            JAX_AS_STATUS(hipsolverCpotrf_bufferSize(handle.get(), uplo, n,
-                                                     /*A=*/nullptr,
-                                                     /*lda=*/n, &lwork)));
-        workspace_size = lwork * sizeof(hipComplex);
-        break;
-      case HipsolverType::C128:
-        JAX_THROW_IF_ERROR(
-            JAX_AS_STATUS(hipsolverZpotrf_bufferSize(handle.get(), uplo, n,
-                                                     /*A=*/nullptr,
-                                                     /*lda=*/n, &lwork)));
-        workspace_size = lwork * sizeof(hipDoubleComplex);
-        break;
-    }
-  } else {
-    // TODO(rocm): remove this once CUDA and HIP have the same API for batched
-    // potrf. HIP's batched potrf has a different API from CUDA's: in HIP we
-    // still need to create the workspace, plus additional space to copy the
-    // batch array pointers.
-    switch (type) {
-      case HipsolverType::F32:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-            hipsolverSpotrfBatched_bufferSize(handle.get(), uplo, n,
-                                              /*A=*/nullptr,
-                                              /*lda=*/n, &lwork, b)));
-        workspace_size = (lwork * sizeof(float)) + (b * sizeof(float*));
-        break;
-      case HipsolverType::F64:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-            hipsolverDpotrfBatched_bufferSize(handle.get(), uplo, n,
-                                              /*A=*/nullptr,
-                                              /*lda=*/n, &lwork, b)));
-        workspace_size = (lwork * sizeof(double)) + (b * sizeof(double*));
-        break;
-      case HipsolverType::C64:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-            hipsolverCpotrfBatched_bufferSize(handle.get(), uplo, n,
-                                              /*A=*/nullptr,
-                                              /*lda=*/n, &lwork, b)));
-        workspace_size =
-            (lwork * sizeof(hipComplex)) + (b * sizeof(hipComplex*));
-        break;
-      case HipsolverType::C128:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-            hipsolverZpotrfBatched_bufferSize(handle.get(), uplo, n,
-                                              /*A=*/nullptr,
-                                              /*lda=*/n, &lwork, b)));
-        workspace_size =
-            (lwork * sizeof(hipDoubleComplex)) + (b * sizeof(hipDoubleComplex*));
-        break;
-    }
-  }
-  return {workspace_size,
-          PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})};
-}
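The TODO in BuildPotrfDescriptor is worth spelling out: unlike CUDA, ROCm's batched potrf takes an array of per-matrix device pointers, so the single workspace allocation has to hold both the solver scratch space and that pointer table. A small sketch of the size computation used above (the function name is illustrative):

#include <cstddef>
#include <cstdint>

// workspace = lwork elements of scratch + batch pointers, matching the
// (lwork * sizeof(T)) + (b * sizeof(T*)) expressions in the builder above.
std::int64_t BatchedPotrfWorkspaceBytes(int lwork, int batch,
                                        std::size_t element_size) {
  return static_cast<std::int64_t>(lwork) * element_size +
         static_cast<std::int64_t>(batch) * sizeof(void*);
}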
-
-// getrf: LU decomposition
-
-// Returns the workspace size and a descriptor for a getrf operation.
-std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
-                                               int m, int n) {
-  HipsolverType type = DtypeToHipsolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  int lwork;
-  switch (type) {
-    case HipsolverType::F32:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverSgetrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-    case HipsolverType::F64:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverDgetrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-    case HipsolverType::C64:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverCgetrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-    case HipsolverType::C128:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverZgetrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-  }
-  return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})};
-}
-
-// geqrf: QR decomposition
-
-// Returns the workspace size and a descriptor for a geqrf operation.
-std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
-                                               int m, int n) {
-  HipsolverType type = DtypeToHipsolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  int lwork;
-  switch (type) {
-    case HipsolverType::F32:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverSgeqrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-    case HipsolverType::F64:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverDgeqrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-    case HipsolverType::C64:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverCgeqrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-    case HipsolverType::C128:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverZgeqrf_bufferSize(handle.get(), m, n,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m, &lwork)));
-      break;
-  }
-  return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})};
-}
-
-// orgqr/ungqr: apply elementary Householder transformations
-
-// Returns the workspace size and a descriptor for an orgqr operation.
-std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
-                                               int m, int n, int k) {
-  HipsolverType type = DtypeToHipsolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  int lwork;
-  switch (type) {
-    case HipsolverType::F32:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverSorgqr_bufferSize(handle.get(), m, n, k,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m,
-                                                   /*tau=*/nullptr, &lwork)));
-      break;
-    case HipsolverType::F64:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverDorgqr_bufferSize(handle.get(), m, n, k,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m,
-                                                   /*tau=*/nullptr, &lwork)));
-      break;
-    case HipsolverType::C64:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverCungqr_bufferSize(handle.get(), m, n, k,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m,
-                                                   /*tau=*/nullptr, &lwork)));
-      break;
-    case HipsolverType::C128:
-      JAX_THROW_IF_ERROR(
-          JAX_AS_STATUS(hipsolverZungqr_bufferSize(handle.get(), m, n, k,
-                                                   /*A=*/nullptr,
-                                                   /*lda=*/m,
-                                                   /*tau=*/nullptr, &lwork)));
-      break;
-  }
-  return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})};
-}
-
-// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
-
-// Returns the workspace size and a descriptor for a syevd operation.
-std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
-                                               bool lower, int b, int n) {
-  HipsolverType type = DtypeToHipsolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  int lwork;
-  hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
-  hipsolverFillMode_t uplo =
-      lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
-  switch (type) {
-    case HipsolverType::F32:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-          hipsolverSsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
-                                     /*lda=*/n, /*W=*/nullptr, &lwork)));
-      break;
-    case HipsolverType::F64:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-          hipsolverDsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
-                                     /*lda=*/n, /*W=*/nullptr, &lwork)));
-      break;
-    case HipsolverType::C64:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-          hipsolverCheevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
-                                     /*lda=*/n, /*W=*/nullptr, &lwork)));
-      break;
-    case HipsolverType::C128:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd_bufferSize(
-          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
-          &lwork)));
-      break;
-  }
-  return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
-}
-
-// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
-// Supports batches of matrices up to size 32.
-
-// Returns the workspace size and a descriptor for a syevj_batched operation.
-std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
-                                               bool lower, int batch, int n) {
-  HipsolverType type = DtypeToHipsolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  int lwork;
-  hipsolverSyevjInfo_t params;
-  JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
-  std::unique_ptr<hipsolverSyevjInfo, void (*)(hipsolverSyevjInfo_t)>
-      params_cleanup(
-          params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
-  hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
-  hipsolverFillMode_t uplo =
-      lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
-  if (batch == 1) {
-    switch (type) {
-      case HipsolverType::F32:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params)));
-        break;
-      case HipsolverType::F64:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params)));
-        break;
-      case HipsolverType::C64:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params)));
-        break;
-      case HipsolverType::C128:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params)));
-        break;
-    }
-  } else {
-    switch (type) {
-      case HipsolverType::F32:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params, batch)));
-        break;
-      case HipsolverType::F64:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params, batch)));
-        break;
-      case HipsolverType::C64:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params, batch)));
-        break;
-      case HipsolverType::C128:
-        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevjBatched_bufferSize(
-            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
-            /*W=*/nullptr, &lwork, params, batch)));
-        break;
-    }
-  }
-  return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})};
-}
-
-// Singular value decomposition using QR algorithm: gesvd
-
-// Returns the workspace size and a descriptor for a gesvd operation.
-std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
-                                               int m, int n, bool compute_uv,
-                                               bool full_matrices) {
-  HipsolverType type = DtypeToHipsolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
-  JAX_THROW_IF_ERROR(h.status());
-  auto& handle = *h;
-  int lwork;
-  signed char jobu, jobvt;
-  if (compute_uv) {
-    if (full_matrices) {
-      jobu = jobvt = 'A';
-    } else {
-      jobu = jobvt = 'S';
-    }
-  } else {
-    jobu = jobvt = 'N';
-  }
-  switch (type) {
-    case HipsolverType::F32:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-          hipsolverSgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
-      break;
-    case HipsolverType::F64:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-          hipsolverDgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
-      break;
-    case HipsolverType::C64:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-          hipsolverCgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
-      break;
-    case HipsolverType::C128:
-      JAX_THROW_IF_ERROR(JAX_AS_STATUS(
-          hipsolverZgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
-      break;
-  }
-  return {lwork,
-          PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
-}
-
-py::dict Registrations() {
-  py::dict dict;
-  dict["hipsolver_potrf"] = EncapsulateFunction(Potrf);
-  dict["hipsolver_getrf"] = EncapsulateFunction(Getrf);
-  dict["hipsolver_geqrf"] = EncapsulateFunction(Geqrf);
-  dict["hipsolver_orgqr"] = EncapsulateFunction(Orgqr);
-  dict["hipsolver_syevd"] = EncapsulateFunction(Syevd);
-  dict["hipsolver_syevj"] = EncapsulateFunction(Syevj);
-  dict["hipsolver_gesvd"] = EncapsulateFunction(Gesvd);
-  // dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); not supported by
-  // ROCm yet
-  return dict;
-}
-
-PYBIND11_MODULE(_hipsolver, m) {
-  m.def("registrations", &Registrations);
-  m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
-  m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
-  m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
-  m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
-  m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
-  m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
-  m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
-  // m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); not supported by
-  // ROCm yet
-}
-
-}  // namespace
-}  // namespace jax
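Registrations() hands Python a dict from custom-call target names to PyCapsules wrapping raw function pointers; jax's lowering rules then register those capsules with XLA. A sketch of what the EncapsulateFunction helper plausibly looks like (the real one lives in jaxlib/kernel_pybind11_helpers.h; treat the capsule name and the cast details here as assumptions):

#include "include/pybind11/pybind11.h"

template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
  // XLA looks up custom-call targets by this capsule name.
  return pybind11::capsule(reinterpret_cast<void*>(fn),
                           "xla._CUSTOM_CALL_TARGET");
}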
diff --git a/jaxlib/rocm/hipsolver_kernels.cc b/jaxlib/rocm/hipsolver_kernels.cc
deleted file mode 100644
index 9f5f14128..000000000
--- a/jaxlib/rocm/hipsolver_kernels.cc
+++ /dev/null
@@ -1,721 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "jaxlib/rocm/hipsolver_kernels.h"
-
-#include <algorithm>
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "absl/status/status.h"
-#include "absl/status/statusor.h"
-#include "absl/synchronization/mutex.h"
-#include "jaxlib/handle_pool.h"
-#include "jaxlib/rocm/hip_gpu_kernel_helpers.h"
-#include "jaxlib/kernel_helpers.h"
-#include "rocm/include/hip/hip_runtime_api.h"
-#include "rocm/include/hipsolver.h"
-#include "tensorflow/compiler/xla/service/custom_call_status.h"
-
-namespace jax {
-
-template <>
-/*static*/ absl::StatusOr<SolverHandlePool::Handle>
-SolverHandlePool::Borrow(hipStream_t stream) {
-  SolverHandlePool* pool = Instance();
-  absl::MutexLock lock(&pool->mu_);
-  hipsolverHandle_t handle;
-  if (pool->handles_[stream].empty()) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreate(&handle)));
-  } else {
-    handle = pool->handles_[stream].back();
-    pool->handles_[stream].pop_back();
-  }
-  if (stream) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSetStream(handle, stream)));
-  }
-  return Handle(pool, handle, stream);
-}
-
-static int SizeOfHipsolverType(HipsolverType type) {
-  switch (type) {
-    case HipsolverType::F32:
-      return sizeof(float);
-    case HipsolverType::F64:
-      return sizeof(double);
-    case HipsolverType::C64:
-      return sizeof(hipFloatComplex);
-    case HipsolverType::C128:
-      return sizeof(hipDoubleComplex);
-  }
-}
-
-// potrf: Cholesky decomposition
-
-static absl::Status Potrf_(hipStream_t stream, void** buffers,
-                           const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const PotrfDescriptor& d = **s;
-  auto h = SolverHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  if (buffers[1] != buffers[0]) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-        hipMemcpyAsync(buffers[1], buffers[0],
-                       SizeOfHipsolverType(d.type) * d.batch * d.n * d.n,
-                       hipMemcpyDeviceToDevice, stream)));
-  }
-
-  int* info = static_cast<int*>(buffers[2]);
-  void* workspace = buffers[3];
-  if (d.batch == 1) {
-    switch (d.type) {
-      case HipsolverType::F32: {
-        float* a = static_cast<float*>(buffers[1]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-            hipsolverSpotrf(handle.get(), d.uplo, d.n, a, d.n,
-                            static_cast<float*>(workspace), d.lwork, info)));
-        break;
-      }
-      case HipsolverType::F64: {
-        double* a = static_cast<double*>(buffers[1]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-            hipsolverDpotrf(handle.get(), d.uplo, d.n, a, d.n,
-                            static_cast<double*>(workspace), d.lwork, info)));
-        break;
-      }
-      case HipsolverType::C64: {
-        hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrf(
-            handle.get(), d.uplo, d.n, a, d.n,
-            static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
-        break;
-      }
-      case HipsolverType::C128: {
-        hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrf(
-            handle.get(), d.uplo, d.n, a, d.n,
-            static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
-        break;
-      }
-    }
-  } else {
-    auto buffer_ptrs_host =
-        MakeBatchPointers(stream, buffers[1], workspace, d.batch,
-                          SizeOfHipsolverType(d.type) * d.n * d.n);
-    JAX_RETURN_IF_ERROR(buffer_ptrs_host.status());
-    // Make sure that accesses to buffer_ptrs_host complete before we delete
-    // it.
-    // TODO(phawkins): avoid synchronization here.
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
-    switch (d.type) {
-      case HipsolverType::F32: {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSpotrfBatched(
-            handle.get(), d.uplo, d.n, static_cast<float**>(workspace), d.n,
-            reinterpret_cast<float*>(static_cast<float**>(workspace) +
-                                     d.batch),
-            d.lwork, info, d.batch)));
-        break;
-      }
-      case HipsolverType::F64: {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDpotrfBatched(
-            handle.get(), d.uplo, d.n, static_cast<double**>(workspace), d.n,
-            reinterpret_cast<double*>(static_cast<double**>(workspace) +
-                                      d.batch),
-            d.lwork, info, d.batch)));
-        break;
-      }
-      case HipsolverType::C64: {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrfBatched(
-            handle.get(), d.uplo, d.n,
-            static_cast<hipFloatComplex**>(workspace), d.n,
-            reinterpret_cast<hipFloatComplex*>(
-                static_cast<hipFloatComplex**>(workspace) + d.batch),
-            d.lwork, info, d.batch)));
-        break;
-      }
-      case HipsolverType::C128: {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
-            handle.get(), d.uplo, d.n,
-            static_cast<hipDoubleComplex**>(workspace), d.n,
-            reinterpret_cast<hipDoubleComplex*>(
-                static_cast<hipDoubleComplex**>(workspace) + d.batch),
-            d.lwork, info, d.batch)));
-        break;
-      }
-    }
-  }
-  return absl::OkStatus();
-}
-
-void Potrf(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = Potrf_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
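MakeBatchPointers, used in the batched branch above, fills the head of the workspace with the address of each matrix in the contiguous batch. What it has to produce is easy to state on the host (the real helper writes the table into device memory on the given stream; this sketch is host-only and the name is illustrative):

#include <cstddef>
#include <vector>

std::vector<void*> MakeBatchPointersHost(void* base, int batch,
                                         std::size_t bytes_per_matrix) {
  std::vector<void*> pointers(batch);
  for (int i = 0; i < batch; ++i) {
    // Matrix i starts i * bytes_per_matrix past the start of the batch.
    pointers[i] = static_cast<char*>(base) + i * bytes_per_matrix;
  }
  return pointers;
}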
-
-// getrf: LU decomposition
-
-static absl::Status Getrf_(hipStream_t stream, void** buffers,
-                           const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const GetrfDescriptor& d = **s;
-  auto h = SolverHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  if (buffers[1] != buffers[0]) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
-        buffers[1], buffers[0],
-        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-        hipMemcpyDeviceToDevice, stream)));
-  }
-
-  int* ipiv = static_cast<int*>(buffers[2]);
-  int* info = static_cast<int*>(buffers[3]);
-  void* workspace = buffers[4];
-  switch (d.type) {
-    case HipsolverType::F32: {
-      float* a = static_cast<float*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSgetrf(
-            handle.get(), d.m, d.n, a, d.m, static_cast<float*>(workspace),
-            d.lwork, ipiv, info)));
-        a += d.m * d.n;
-        ipiv += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::F64: {
-      double* a = static_cast<double*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDgetrf(
-            handle.get(), d.m, d.n, a, d.m, static_cast<double*>(workspace),
-            d.lwork, ipiv, info)));
-        a += d.m * d.n;
-        ipiv += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C64: {
-      hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgetrf(
-            handle.get(), d.m, d.n, a, d.m,
-            static_cast<hipFloatComplex*>(workspace), d.lwork, ipiv, info)));
-        a += d.m * d.n;
-        ipiv += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C128: {
-      hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgetrf(
-            handle.get(), d.m, d.n, a, d.m,
-            static_cast<hipDoubleComplex*>(workspace), d.lwork, ipiv, info)));
-        a += d.m * d.n;
-        ipiv += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-  }
-  return absl::OkStatus();
-}
-
-void Getrf(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = Getrf_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// geqrf: QR decomposition
-
-static absl::Status Geqrf_(hipStream_t stream, void** buffers,
-                           const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const GeqrfDescriptor& d = **s;
-  auto h = SolverHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  if (buffers[1] != buffers[0]) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
-        buffers[1], buffers[0],
-        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-        hipMemcpyDeviceToDevice, stream)));
-  }
-
-  int* info = static_cast<int*>(buffers[3]);
-
-  void* workspace = buffers[4];
-  switch (d.type) {
-    case HipsolverType::F32: {
-      float* a = static_cast<float*>(buffers[1]);
-      float* tau = static_cast<float*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSgeqrf(
-            handle.get(), d.m, d.n, a, d.m, tau,
-            static_cast<float*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::F64: {
-      double* a = static_cast<double*>(buffers[1]);
-      double* tau = static_cast<double*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDgeqrf(
-            handle.get(), d.m, d.n, a, d.m, tau,
-            static_cast<double*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C64: {
-      hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
-      hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgeqrf(
-            handle.get(), d.m, d.n, a, d.m, tau,
-            static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C128: {
-      hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
-      hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgeqrf(
-            handle.get(), d.m, d.n, a, d.m, tau,
-            static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += std::min(d.m, d.n);
-        ++info;
-      }
-      break;
-    }
-  }
-  return absl::OkStatus();
-}
-
-void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = Geqrf_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// orgqr/ungqr: apply elementary Householder transformations
-
-static absl::Status Orgqr_(hipStream_t stream, void** buffers,
-                           const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const OrgqrDescriptor& d = **s;
-  auto h = SolverHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  if (buffers[2] != buffers[0]) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
-        buffers[2], buffers[0],
-        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-        hipMemcpyDeviceToDevice, stream)));
-  }
-
-  int* info = static_cast<int*>(buffers[3]);
-
-  void* workspace = buffers[4];
-  switch (d.type) {
-    case HipsolverType::F32: {
-      float* a = static_cast<float*>(buffers[2]);
-      float* tau = static_cast<float*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSorgqr(
-            handle.get(), d.m, d.n, d.k, a, d.m, tau,
-            static_cast<float*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += d.k;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::F64: {
-      double* a = static_cast<double*>(buffers[2]);
-      double* tau = static_cast<double*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDorgqr(
-            handle.get(), d.m, d.n, d.k, a, d.m, tau,
-            static_cast<double*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += d.k;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C64: {
-      hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[2]);
-      hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCungqr(
-            handle.get(), d.m, d.n, d.k, a, d.m, tau,
-            static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += d.k;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C128: {
-      hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[2]);
-      hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[1]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZungqr(
-            handle.get(), d.m, d.n, d.k, a, d.m, tau,
-            static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
-        a += d.m * d.n;
-        tau += d.k;
-        ++info;
-      }
-      break;
-    }
-  }
-  return absl::OkStatus();
-}
-
-void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = Orgqr_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
-
-static absl::Status Syevd_(hipStream_t stream, void** buffers,
-                           const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<SyevdDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const SyevdDescriptor& d = **s;
-  auto h = SolverHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
-      buffers[1], buffers[0],
-      SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-          static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
-      hipMemcpyDeviceToDevice, stream)));
-  hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
-  int* info = static_cast<int*>(buffers[3]);
-  void* work = buffers[4];
-  switch (d.type) {
-    case HipsolverType::F32: {
-      float* a = static_cast<float*>(buffers[1]);
-      float* w = static_cast<float*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevd(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<float*>(work), d.lwork, info)));
-        a += d.n * d.n;
-        w += d.n;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::F64: {
-      double* a = static_cast<double*>(buffers[1]);
-      double* w = static_cast<double*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevd(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<double*>(work), d.lwork, info)));
-        a += d.n * d.n;
-        w += d.n;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C64: {
-      hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
-      float* w = static_cast<float*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevd(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<float*>(work), d.lwork, info)));
-        a += d.n * d.n;
-        w += d.n;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C128: {
-      hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
-      double* w = static_cast<double*>(buffers[2]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<double*>(work), d.lwork, info)));
-        a += d.n * d.n;
-        w += d.n;
-        ++info;
-      }
-      break;
-    }
-  }
-  return absl::OkStatus();
-}
-
-void Syevd(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = Syevd_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
-// Supports batches of matrices up to size 32.
-
-absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
-                    size_t opaque_len) {
-  auto s = UnpackDescriptor<SyevjDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const SyevjDescriptor& d = **s;
-  auto h = SolverHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  if (buffers[1] != buffers[0]) {
-    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
-        buffers[1], buffers[0],
-        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-            static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
-        hipMemcpyDeviceToDevice, stream)));
-  }
-  hipsolverSyevjInfo_t params;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
-  std::unique_ptr<hipsolverSyevjInfo, void (*)(hipsolverSyevjInfo_t)>
-      params_cleanup(
-          params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
-  hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
-  int* info = static_cast<int*>(buffers[3]);
-  void* work = buffers[4];
-  if (d.batch == 1) {
-    switch (d.type) {
-      case HipsolverType::F32: {
-        float* a = static_cast<float*>(buffers[1]);
-        float* w = static_cast<float*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<float*>(work), d.lwork, info, params)));
-        break;
-      }
-      case HipsolverType::F64: {
-        double* a = static_cast<double*>(buffers[1]);
-        double* w = static_cast<double*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<double*>(work), d.lwork, info, params)));
-        break;
-      }
-      case HipsolverType::C64: {
-        hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
-        float* w = static_cast<float*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<float*>(work), d.lwork, info, params)));
-        break;
-      }
-      case HipsolverType::C128: {
-        hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
-        double* w = static_cast<double*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<double*>(work), d.lwork, info, params)));
-        break;
-      }
-    }
-  } else {
-    switch (d.type) {
-      case HipsolverType::F32: {
-        float* a = static_cast<float*>(buffers[1]);
-        float* w = static_cast<float*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<float*>(work), d.lwork, info, params, d.batch)));
-        break;
-      }
-      case HipsolverType::F64: {
-        double* a = static_cast<double*>(buffers[1]);
-        double* w = static_cast<double*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<double*>(work), d.lwork, info, params, d.batch)));
-        break;
-      }
-      case HipsolverType::C64: {
-        hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
-        float* w = static_cast<float*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<float*>(work),
-            d.lwork, info, params, d.batch)));
-        break;
-      }
-      case HipsolverType::C128: {
-        hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
-        double* w = static_cast<double*>(buffers[2]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevjBatched(
-            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
-            static_cast<double*>(work), d.lwork, info, params, d.batch)));
-        break;
-      }
-    }
-  }
-  return absl::OkStatus();
-}
-
-void Syevj(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = Syevj_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// Singular value decomposition using QR algorithm: gesvd
-
-static absl::Status Gesvd_(hipStream_t stream, void** buffers,
-                           const char* opaque, size_t opaque_len) {
-  auto s = UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
-  JAX_RETURN_IF_ERROR(s.status());
-  const GesvdDescriptor& d = **s;
-  auto h = SolverHandlePool::Borrow(stream);
-  JAX_RETURN_IF_ERROR(h.status());
-  auto& handle = *h;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
-      buffers[1], buffers[0],
-      SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-      hipMemcpyDeviceToDevice, stream)));
-  int* info = static_cast<int*>(buffers[5]);
-  void* work = buffers[6];
-  switch (d.type) {
-    case HipsolverType::F32: {
-      float* a = static_cast<float*>(buffers[1]);
-      float* s = static_cast<float*>(buffers[2]);
-      float* u = static_cast<float*>(buffers[3]);
-      float* vt = static_cast<float*>(buffers[4]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
-            hipsolverSgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
-                            u, d.m, vt, d.n, static_cast<float*>(work), d.lwork,
-                            /*rwork=*/nullptr, info)));
-        a += d.m * d.n;
-        s += std::min(d.m, d.n);
-        u += d.m * d.m;
-        vt += d.n * d.n;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::F64: {
-      double* a = static_cast<double*>(buffers[1]);
-      double* s = static_cast<double*>(buffers[2]);
-      double* u = static_cast<double*>(buffers[3]);
-      double* vt = static_cast<double*>(buffers[4]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDgesvd(
-            handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
-            static_cast<double*>(work), d.lwork,
-            /*rwork=*/nullptr, info)));
-        a += d.m * d.n;
-        s += std::min(d.m, d.n);
-        u += d.m * d.m;
-        vt += d.n * d.n;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C64: {
-      hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
-      float* s = static_cast<float*>(buffers[2]);
-      hipFloatComplex* u = static_cast<hipFloatComplex*>(buffers[3]);
-      hipFloatComplex* vt = static_cast<hipFloatComplex*>(buffers[4]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgesvd(
-            handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
-            static_cast<hipFloatComplex*>(work), d.lwork, /*rwork=*/nullptr,
-            info)));
-        a += d.m * d.n;
-        s += std::min(d.m, d.n);
-        u += d.m * d.m;
-        vt += d.n * d.n;
-        ++info;
-      }
-      break;
-    }
-    case HipsolverType::C128: {
-      hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
-      double* s = static_cast<double*>(buffers[2]);
-      hipDoubleComplex* u = static_cast<hipDoubleComplex*>(buffers[3]);
-      hipDoubleComplex* vt = static_cast<hipDoubleComplex*>(buffers[4]);
-      for (int i = 0; i < d.batch; ++i) {
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgesvd(
-            handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
-            static_cast<hipDoubleComplex*>(work), d.lwork,
-            /*rwork=*/nullptr, info)));
-        a += d.m * d.n;
-        s += std::min(d.m, d.n);
-        u += d.m * d.m;
-        vt += d.n * d.n;
-        ++info;
-      }
-      break;
-    }
-  }
-  return absl::OkStatus();
-}
-
-void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status) {
-  auto s = Gesvd_(stream, buffers, opaque, opaque_len);
-  if (!s.ok()) {
-    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
-                                  s.message().length());
-  }
-}
-
-// TODO(rocm): add Gesvdj_ APIs when support from hipsolver is ready
-
-}  // namespace jax
diff --git a/jaxlib/rocm/hipsolver_kernels.h b/jaxlib/rocm/hipsolver_kernels.h
deleted file mode 100644
index c21e94c92..000000000
--- a/jaxlib/rocm/hipsolver_kernels.h
+++ /dev/null
@@ -1,122 +0,0 @@
-/* Copyright 2021 The JAX Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef JAXLIB_HIPSOLVER_KERNELS_H_
-#define JAXLIB_HIPSOLVER_KERNELS_H_
-
-#include "absl/status/statusor.h"
-#include "jaxlib/handle_pool.h"
-#include "rocm/include/hip/hip_runtime_api.h"
-#include "rocm/include/hipblas.h"
-#include "rocm/include/hipsolver.h"
-#include "tensorflow/compiler/xla/service/custom_call_status.h"
-
-namespace jax {
-
-using SolverHandlePool = HandlePool<hipsolverHandle_t, hipStream_t>;
-
-template <>
-absl::StatusOr<SolverHandlePool::Handle>
-SolverHandlePool::Borrow(hipStream_t stream);
-
-// Set of types known to Hipsolver.
-enum class HipsolverType {
-  F32,
-  F64,
-  C64,
-  C128,
-};
-
-// potrf: Cholesky decomposition
-
-struct PotrfDescriptor {
-  HipsolverType type;
-  hipsolverFillMode_t uplo;
-  std::int64_t batch, n;
-  int lwork;
-};
-
-void Potrf(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status);
-// getrf: LU decomposition
-
-struct GetrfDescriptor {
-  HipsolverType type;
-  int batch, m, n, lwork;
-};
-
-void Getrf(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status);
-
-// geqrf: QR decomposition
-
-struct GeqrfDescriptor {
-  HipsolverType type;
-  int batch, m, n, lwork;
-};
-
-void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status);
-
-// orgqr/ungqr: apply elementary Householder transformations
-
-struct OrgqrDescriptor {
-  HipsolverType type;
-  int batch, m, n, k, lwork;
-};
-
-void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status);
-
-// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
-
-struct SyevdDescriptor {
-  HipsolverType type;
-  hipsolverFillMode_t uplo;
-  int batch, n;
-  int lwork;
-};
-
-void Syevd(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status);
-
-// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
-// Supports batches of matrices up to size 32.
-
-struct SyevjDescriptor {
-  HipsolverType type;
-  hipsolverFillMode_t uplo;
-  int batch, n;
-  int lwork;
-};
-
-void Syevj(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status);
-
-// Singular value decomposition using QR algorithm: gesvd
-
-struct GesvdDescriptor {
-  HipsolverType type;
-  int batch, m, n;
-  int lwork;
-  signed char jobu, jobvt;
-};
-
-void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
-           size_t opaque_len, XlaCustomCallStatus* status);
-
-}  // namespace jax
-
-#endif  // JAXLIB_HIPSOLVER_KERNELS_H_
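The SolverHandlePool specialization deleted above exists because hipsolver handles are expensive to create, so handles are cached per stream and reused. A toy, non-thread-safe version of the pattern (the real HandlePool in jaxlib/handle_pool.h adds a mutex and an RAII Handle that returns itself to the pool; everything here is a sketch):

#include <map>
#include <vector>

template <typename HandleT, typename StreamT>
class MiniHandlePool {
 public:
  // Reuses a cached handle for this stream if one exists.
  HandleT Borrow(StreamT stream) {
    auto& free_list = free_lists_[stream];
    if (free_list.empty()) {
      return HandleT{};  // the real pool calls hipsolverCreate here
    }
    HandleT handle = free_list.back();
    free_list.pop_back();
    return handle;
  }

  // Returns a handle to the per-stream cache instead of destroying it.
  void Return(StreamT stream, HandleT handle) {
    free_lists_[stream].push_back(handle);
  }

 private:
  std::map<StreamT, std::vector<HandleT>> free_lists_;
};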