diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 5cf85f369..34e40d12d 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -227,6 +227,22 @@ cc_library( ], ) +cc_library( + name = "cusolver_interface", + srcs = ["//jaxlib/gpu:solver_interface.cc"], + hdrs = ["//jaxlib/gpu:solver_interface.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", + ], +) + cc_library( name = "cusolver_kernels_ffi", srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], @@ -237,6 +253,7 @@ cc_library( ":cuda_make_batch_pointers", ":cuda_solver_handle_pool", ":cuda_vendor", + ":cusolver_interface", "//jaxlib:ffi_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 8c4144974..048ea23a9 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -53,6 +53,8 @@ exports_files(srcs = [ "solver.cc", "solver_handle_pool.cc", "solver_handle_pool.h", + "solver_interface.cc", + "solver_interface.h", "solver_kernels.cc", "solver_kernels.h", "solver_kernels_ffi.cc", diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc new file mode 100644 index 000000000..3c8282ec6 --- /dev/null +++ b/jaxlib/gpu/solver_interface.cc @@ -0,0 +1,237 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/solver_interface.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace solver { + +// LU decomposition: getrf + +#define JAX_GPU_DEFINE_GETRF(Type, Name) \ + template <> \ + absl::StatusOr GetrfBufferSize(gpusolverDnHandle_t handle, int m, \ + int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Getrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \ + Type *workspace, int lwork, int *ipiv, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \ + } + +JAX_GPU_DEFINE_GETRF(float, gpusolverDnSgetrf); +JAX_GPU_DEFINE_GETRF(double, gpusolverDnDgetrf); +JAX_GPU_DEFINE_GETRF(gpuComplex, gpusolverDnCgetrf); +JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf); +#undef JAX_GPU_DEFINE_GETRF + +#define JAX_GPU_DEFINE_GETRF_BATCHED(Type, Name) \ + template <> \ + absl::Status GetrfBatched(gpublasHandle_t handle, int n, Type **a, \ + int lda, int *ipiv, int *info, int batch) { \ + return JAX_AS_STATUS(Name(handle, n, a, lda, ipiv, info, batch)); \ + } + +JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched); +#undef JAX_GPU_DEFINE_GETRF_BATCHED + +// QR decomposition: geqrf + +#define JAX_GPU_DEFINE_GEQRF(Type, Name) \ + template <> \ + absl::StatusOr GeqrfBufferSize(gpusolverDnHandle_t handle, int m, \ + int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Geqrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \ + Type *tau, Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, a, m, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_GEQRF(float, gpusolverDnSgeqrf); +JAX_GPU_DEFINE_GEQRF(double, gpusolverDnDgeqrf); +JAX_GPU_DEFINE_GEQRF(gpuComplex, gpusolverDnCgeqrf); +JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf); +#undef JAX_GPU_DEFINE_GEQRF + +#define JAX_GPU_DEFINE_GEQRF_BATCHED(Type, Name) \ + template <> \ + absl::Status GeqrfBatched(gpublasHandle_t handle, int m, int n, \ + Type **a, Type **tau, int *info, \ + int batch) { \ + return JAX_AS_STATUS(Name(handle, m, n, a, m, tau, info, batch)); \ + } + +JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched); +#undef JAX_GPU_DEFINE_GEQRF_BATCHED + +// Householder transformations: orgqr + +#define JAX_GPU_DEFINE_ORGQR(Type, Name) \ + template <> \ + absl::StatusOr OrgqrBufferSize(gpusolverDnHandle_t handle, int m, \ + int n, int k) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \ + handle, m, n, k, /*A=*/nullptr, /*lda=*/m, /*tau=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Orgqr(gpusolverDnHandle_t handle, int m, int n, int k, \ + Type *a, Type *tau, Type *workspace, int lwork, \ + int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_ORGQR(float, gpusolverDnSorgqr); +JAX_GPU_DEFINE_ORGQR(double, gpusolverDnDorgqr); +JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr); +JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr); +#undef JAX_GPU_DEFINE_ORGQR + +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd + +#define JAX_GPU_DEFINE_SYEVJ(Type, Name) \ + template <> \ + absl::StatusOr SyevjBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Syevj( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info, gpuSyevjInfo_t params) { \ + return JAX_AS_STATUS( \ + Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info, params)); \ + } + +JAX_GPU_DEFINE_SYEVJ(float, gpusolverDnSsyevj); +JAX_GPU_DEFINE_SYEVJ(double, gpusolverDnDsyevj); +JAX_GPU_DEFINE_SYEVJ(gpuComplex, gpusolverDnCheevj); +JAX_GPU_DEFINE_SYEVJ(gpuDoubleComplex, gpusolverDnZheevj); +#undef JAX_GPU_DEFINE_SYEVJ + +#define JAX_GPU_DEFINE_SYEVJ_BATCHED(Type, Name) \ + template <> \ + absl::StatusOr SyevjBatchedBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params, batch))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status SyevjBatched( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info, gpuSyevjInfo_t params, \ + int batch) { \ + return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, n, w, workspace, \ + lwork, info, params, batch)); \ + } + +JAX_GPU_DEFINE_SYEVJ_BATCHED(float, gpusolverDnSsyevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(double, gpusolverDnDsyevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuComplex, gpusolverDnCheevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuDoubleComplex, gpusolverDnZheevjBatched); +#undef JAX_GPU_DEFINE_SYEVJ_BATCHED + +#define JAX_GPU_DEFINE_SYEVD(Type, Name) \ + template <> \ + absl::StatusOr SyevdBufferSize(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR( \ + JAX_AS_STATUS(Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, \ + /*lda=*/n, /*w=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Syevd(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ + int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_SYEVD(float, gpusolverDnSsyevd); +JAX_GPU_DEFINE_SYEVD(double, gpusolverDnDsyevd); +JAX_GPU_DEFINE_SYEVD(gpuComplex, gpusolverDnCheevd); +JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd); +#undef JAX_GPU_DEFINE_SYEVD + +// Symmetric rank-k update: syrk + +#define JAX_GPU_DEFINE_SYRK(Type, Name) \ + template <> \ + absl::Status Syrk(gpublasHandle_t handle, gpublasFillMode_t uplo, \ + gpublasOperation_t trans, int n, int k, \ + const Type *alpha, const Type *a, const Type *beta, \ + Type *c) { \ + int lda = trans == GPUBLAS_OP_N ? n : k; \ + return JAX_AS_STATUS( \ + Name(handle, uplo, trans, n, k, alpha, a, lda, beta, c, n)); \ + } + +JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk); +JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk); +JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); +JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); +#undef JAX_GPU_DEFINE_SYRK + +} // namespace solver +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h new file mode 100644 index 000000000..5072be984 --- /dev/null +++ b/jaxlib/gpu/solver_interface.h @@ -0,0 +1,174 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines a standard interface to the GPU linear algebra libraries. + +#ifndef JAXLIB_GPU_SOLVER_INTERFACE_H_ +#define JAXLIB_GPU_SOLVER_INTERFACE_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace solver { + +template +struct RealType { + using value = T; +}; + +template <> +struct RealType { + using value = float; +}; + +template <> +struct RealType { + using value = double; +}; + +#define JAX_GPU_SOLVER_EXPAND_DEFINITION(ReturnType, FunctionName) \ + template \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(T, typename RealType::value)) { \ + return absl::UnimplementedError(absl::StrFormat( \ + #FunctionName " not implemented for type %s", typeid(T).name())); \ + } \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(float, float)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(double, double)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuComplex, float)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuDoubleComplex, double)) + +// LU decomposition: getrf + +#define JAX_GPU_SOLVER_GetrfBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GetrfBufferSize); +#undef JAX_GPU_SOLVER_GetrfBufferSize_ARGS + +#define JAX_GPU_SOLVER_Getrf_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, Type *a, Type *workspace, \ + int lwork, int *ipiv, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Getrf); +#undef JAX_GPU_SOLVER_Getrf_ARGS + +#define JAX_GPU_SOLVER_GetrfBatched_ARGS(Type, ...) \ + gpublasHandle_t handle, int n, Type **a, int lda, int *ipiv, int *info, \ + int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GetrfBatched); +#undef JAX_GPU_SOLVER_GetrfBatched_ARGS + +// QR decomposition: geqrf + +#define JAX_GPU_SOLVER_GeqrfBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GeqrfBufferSize); +#undef JAX_GPU_SOLVER_GeqrfBufferSize_ARGS + +#define JAX_GPU_SOLVER_Geqrf_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, Type *a, Type *tau, \ + Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Geqrf); +#undef JAX_GPU_SOLVER_Geqrf_ARGS + +#define JAX_GPU_SOLVER_GeqrfBatched_ARGS(Type, ...) \ + gpublasHandle_t handle, int m, int n, Type **a, Type **tau, int *info, \ + int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GeqrfBatched); +#undef JAX_GPU_SOLVER_GeqrfBatched_ARGS + +// Householder transformations: orgqr + +#define JAX_GPU_SOLVER_OrgqrBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, int k +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, OrgqrBufferSize); +#undef JAX_GPU_SOLVER_OrgqrBufferSize_ARGS + +#define JAX_GPU_SOLVER_Orgqr_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, int k, Type *a, Type *tau, \ + Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr); +#undef JAX_GPU_SOLVER_Orgqr_ARGS + +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd + +#define JAX_GPU_SOLVER_SyevjBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevjBufferSize); +#undef JAX_GPU_SOLVER_SyevjBufferSize_ARGS + +#define JAX_GPU_SOLVER_Syevj_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info, gpuSyevjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevj); +#undef JAX_GPU_SOLVER_Syevj_ARGS + +#define JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevjBatchedBufferSize); +#undef JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS + +#define JAX_GPU_SOLVER_SyevjBatched_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info, gpuSyevjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, SyevjBatched); +#undef JAX_GPU_SOLVER_SyevjBatched_ARGS + +#define JAX_GPU_SOLVER_SyevdBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevdBufferSize); +#undef JAX_GPU_SOLVER_SyevdBufferSize_ARGS + +#define JAX_GPU_SOLVER_Syevd_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd); +#undef JAX_GPU_SOLVER_Syevd_ARGS + +// Symmetric rank-k update: syrk + +#define JAX_GPU_SOLVER_Syrk_ARGS(Type, ...) \ + gpublasHandle_t handle, gpublasFillMode_t uplo, gpublasOperation_t trans, \ + int n, int k, const Type *alpha, const Type *a, const Type *beta, \ + Type *c +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk); +#undef JAX_GPU_SOLVER_Syrk_ARGS + +#undef JAX_GPU_SOLVER_EXPAND_DEFINITION + +} // namespace solver +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_SOLVER_INTERFACE_H_ diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 3c74b8519..e3f63234f 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -29,9 +29,13 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/make_batch_pointers.h" #include "jaxlib/gpu/solver_handle_pool.h" +#include "jaxlib/gpu/solver_interface.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" +#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) + XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::JAX_GPU_NAMESPACE::SyevdAlgorithm); namespace jax { @@ -39,7 +43,6 @@ namespace JAX_GPU_NAMESPACE { namespace ffi = ::xla::ffi; -namespace { template inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, int64_t size, @@ -53,22 +56,6 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, return static_cast(maybe_workspace.value()); } -template -struct RealType { - using Type = T; -}; - -template <> -struct RealType { - using Type = float; -}; - -template <> -struct RealType { - using Type = double; -}; -} // namespace - #define SOLVER_DISPATCH_IMPL(impl, ...) \ if (dataType == ffi::F32) { \ return impl(__VA_ARGS__); \ @@ -93,33 +80,6 @@ struct RealType { // LU decomposition: getrf -namespace { -#define GETRF_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ - int n) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \ - type* workspace, int lwork, int* ipiv, \ - int* info) { \ - return JAX_AS_STATUS( \ - name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \ - } \ - } - -template -struct GetrfKernel; -GETRF_KERNEL_IMPL(float, gpusolverDnSgetrf); -GETRF_KERNEL_IMPL(double, gpusolverDnDgetrf); -GETRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgetrf); -GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf); -#undef GETRF_KERNEL_IMPL - template ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -131,7 +91,7 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, - GetrfKernel::BufferSize(handle.get(), m, n)); + solver::GetrfBufferSize(handle.get(), m, n)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "getrf")); @@ -140,13 +100,13 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } int ipiv_step = std::min(m, n); for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(GetrfKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::Getrf( handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data)); out_data += m * n; ipiv_data += ipiv_step; @@ -155,23 +115,6 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, return ffi::Error::Success(); } -#define GETRF_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfBatchedKernel { \ - static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \ - int* ipiv, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \ - } \ - } - -template -struct GetrfBatchedKernel; -GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched); -#undef GETRF_BATCHED_KERNEL_IMPL - template ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, @@ -188,15 +131,15 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch, sizeof(T) * n * n); - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); - FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::GetrfBatched( handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch)); return ffi::Error::Success(); @@ -228,7 +171,6 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in getrf", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch, ffi::Ffi::Bind() @@ -242,33 +184,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch, // QR decomposition: geqrf -namespace { -#define GEQRF_KERNEL_IMPL(type, name) \ - template <> \ - struct GeqrfKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ - int n) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \ - type* tau, type* workspace, int lwork, \ - int* info) { \ - return JAX_AS_STATUS( \ - name(handle, m, n, a, m, tau, workspace, lwork, info)); \ - } \ - } - -template -struct GeqrfKernel; -GEQRF_KERNEL_IMPL(float, gpusolverDnSgeqrf); -GEQRF_KERNEL_IMPL(double, gpusolverDnDgeqrf); -GEQRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgeqrf); -GEQRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgeqrf); -#undef GEQRF_KERNEL_IMPL - template ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -279,7 +194,7 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, - GeqrfKernel::BufferSize(handle.get(), m, n)); + solver::GeqrfBufferSize(handle.get(), m, n)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "geqrf")); @@ -292,14 +207,14 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = static_cast(out->untyped_data()); auto tau_data = static_cast(tau->untyped_data()); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } int out_step = m * n; int tau_step = std::min(m, n); for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::Geqrf( handle.get(), m, n, out_data, tau_data, workspace, lwork, info)); out_data += out_step; tau_data += tau_step; @@ -307,23 +222,6 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, return ffi::Error::Success(); } -#define GEQRF_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct GeqrfBatchedKernel { \ - static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \ - type** tau, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \ - } \ - } - -template -struct GeqrfBatchedKernel; -GEQRF_BATCHED_KERNEL_IMPL(float, gpublasSgeqrfBatched); -GEQRF_BATCHED_KERNEL_IMPL(double, gpublasDgeqrfBatched); -GEQRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgeqrfBatched); -GEQRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgeqrfBatched); -#undef GEQRF_BATCHED_KERNEL_IMPL - template ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -341,21 +239,21 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = out->untyped_data(); auto tau_data = tau->untyped_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch, sizeof(T) * m * n); - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch, sizeof(T) * std::min(m, n)); - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); // We ignore the output value of `info` because it is only used for shape // checking. int info; - FFI_RETURN_IF_ERROR_STATUS(GeqrfBatchedKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::GeqrfBatched( handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch)); return ffi::Error::Success(); @@ -385,7 +283,6 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in geqrf", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, ffi::Ffi::Bind() @@ -398,34 +295,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, // Householder transformations: orgqr -namespace { -#define ORGQR_KERNEL_IMPL(type, name) \ - template <> \ - struct OrgqrKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ - int n, int k) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \ - /*tau=*/nullptr, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \ - type* a, type* tau, type* workspace, int lwork, \ - int* info) { \ - return JAX_AS_STATUS( \ - name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \ - } \ - } - -template -struct OrgqrKernel; -ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr); -ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr); -ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr); -ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr); -#undef ORGQR_KERNEL_IMPL - template ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, gpuStream_t stream, ffi::ScratchAllocator& scratch, @@ -437,7 +306,7 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, - OrgqrKernel::BufferSize(handle.get(), m, n, k)); + solver::OrgqrBufferSize(handle.get(), m, n, k)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "orgqr")); @@ -450,13 +319,13 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, auto tau_data = static_cast(tau.untyped_data()); auto out_data = static_cast(out->untyped_data()); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } int out_step = m * n; for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel::Run( + FFI_RETURN_IF_ERROR_STATUS(solver::Orgqr( handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info)); out_data += out_step; tau_data += k; @@ -492,7 +361,6 @@ ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in orgqr", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, ffi::Ffi::Bind() @@ -510,98 +378,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, // dispatches dynamically to both syevd and syevj depending on the problem // size and the algorithm selected by the user via the `algorithm` attribute. -namespace { -#define SYEVJ_KERNEL_IMPL(type, name) \ - template <> \ - struct SyevjKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, \ - gpusolverFillMode_t uplo, int n, \ - gpuSyevjInfo_t params) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ - /*w=*/nullptr, &lwork, params))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ - int n, type* a, RealType::Type* w, \ - type* workspace, int lwork, int* info, \ - gpuSyevjInfo_t params) { \ - return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \ - lwork, info, params)); \ - } \ - } - -template -struct SyevjKernel; -SYEVJ_KERNEL_IMPL(float, gpusolverDnSsyevj); -SYEVJ_KERNEL_IMPL(double, gpusolverDnDsyevj); -SYEVJ_KERNEL_IMPL(gpuComplex, gpusolverDnCheevj); -SYEVJ_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevj); -#undef SYEVJ_KERNEL_IMPL - -#define SYEVJ_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct SyevjBatchedKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, \ - gpusolverFillMode_t uplo, int n, \ - gpuSyevjInfo_t params, int batch) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ - /*w=*/nullptr, &lwork, params, batch))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ - int n, type* a, RealType::Type* w, \ - type* workspace, int lwork, int* info, \ - gpuSyevjInfo_t params, int batch) { \ - return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \ - lwork, info, params, batch)); \ - } \ - } - -template -struct SyevjBatchedKernel; -SYEVJ_BATCHED_KERNEL_IMPL(float, gpusolverDnSsyevjBatched); -SYEVJ_BATCHED_KERNEL_IMPL(double, gpusolverDnDsyevjBatched); -SYEVJ_BATCHED_KERNEL_IMPL(gpuComplex, gpusolverDnCheevjBatched); -SYEVJ_BATCHED_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevjBatched); -#undef SYEVJ_BATCHED_KERNEL_IMPL - -#define SYEVD_KERNEL_IMPL(type, name) \ - template <> \ - struct SyevdKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, \ - gpusolverFillMode_t uplo, int n) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ - /*w=*/nullptr, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, \ - gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ - int n, type* a, RealType::Type* w, \ - type* workspace, int lwork, int* info) { \ - return JAX_AS_STATUS( \ - name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \ - } \ - } - -template -struct SyevdKernel; -SYEVD_KERNEL_IMPL(float, gpusolverDnSsyevd); -SYEVD_KERNEL_IMPL(double, gpusolverDnDsyevd); -SYEVD_KERNEL_IMPL(gpuComplex, gpusolverDnCheevd); -SYEVD_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevd); -#undef SYEVD_KERNEL_IMPL - template ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm, @@ -618,49 +394,48 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); - auto w_data = static_cast::Type*>(w->untyped_data()); + auto w_data = static_cast::value*>(w->untyped_data()); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } if (algorithm == SyevdAlgorithm::kJacobi || (algorithm == SyevdAlgorithm::kDefault && size <= 32)) { gpuSyevjInfo_t params; - FFI_RETURN_IF_ERROR_STATUS( - JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms)); std::unique_ptr params_cleanup( params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); if (batch == 1) { - FFI_ASSIGN_OR_RETURN(int lwork, SyevjKernel::BufferSize( + FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize( handle.get(), jobz, uplo, n, params)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "syevj")); - FFI_RETURN_IF_ERROR_STATUS( - SyevjKernel::Run(handle.get(), jobz, uplo, n, out_data, w_data, - workspace, lwork, info_data, params)); + FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data, params)); } else { FFI_ASSIGN_OR_RETURN( - int lwork, SyevjBatchedKernel::BufferSize(handle.get(), jobz, uplo, + int lwork, solver::SyevjBatchedBufferSize(handle.get(), jobz, uplo, n, params, batch)); FFI_ASSIGN_OR_RETURN( auto workspace, AllocateWorkspace(scratch, lwork, "syevj_batched")); - FFI_RETURN_IF_ERROR_STATUS(SyevjBatchedKernel::Run( - handle.get(), jobz, uplo, n, out_data, w_data, workspace, lwork, - info_data, params, batch)); + FFI_RETURN_IF_ERROR_STATUS( + solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data, + workspace, lwork, info_data, params, batch)); } } else { FFI_ASSIGN_OR_RETURN( - int lwork, SyevdKernel::BufferSize(handle.get(), jobz, uplo, n)); + int lwork, solver::SyevdBufferSize(handle.get(), jobz, uplo, n)); FFI_ASSIGN_OR_RETURN(auto workspace, AllocateWorkspace(scratch, lwork, "syevd")); int out_step = n * n; for (auto i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS( - SyevdKernel::Run(handle.get(), jobz, uplo, n, out_data, w_data, - workspace, lwork, info_data)); + FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data)); out_data += out_step; w_data += n; ++info_data; @@ -695,7 +470,6 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, return ffi::Error::InvalidArgument(absl::StrFormat( "Unsupported dtype %s in syevd", absl::FormatStreamed(dataType))); } -} // namespace XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch, ffi::Ffi::Bind() @@ -709,110 +483,83 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch, .Ret>() // info ); -#define SYRK_KERNEL_IMPL(type, fn) \ - template <> \ - struct SyrkKernel { \ - static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \ - std::int64_t k, bool transpose, \ - const type* alpha, const type* beta, \ - const type* a_matrix, type* c_matrix) { \ - gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \ - gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \ - int lda = transpose ? n : k; \ - return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \ - alpha, a_matrix, lda, beta, \ - c_matrix, n)); \ - } \ - } +// Symmetric rank-k update: syrk template -struct SyrkKernel; - -SYRK_KERNEL_IMPL(float, gpublasSsyrk); -SYRK_KERNEL_IMPL(double, gpublasDsyrk); -SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk); -SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk); -#undef SYRK_KERNEL_IMPL - -template -ffi::Error SyrkImpl(gpuStream_t stream, - ffi::AnyBuffer a_matrix, - ffi::AnyBuffer c_matrix, - bool transpose, - ffi::AnyBuffer alpha, - ffi::AnyBuffer beta, - ffi::Result c_matrix_out) { +ffi::Error SyrkImpl(gpuStream_t stream, bool transpose, ffi::AnyBuffer a, + ffi::AnyBuffer c_in, ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, ffi::Result c_out) { FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), - SplitBatch2D(a_matrix.dimensions())); - FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]), - SplitBatch2D(c_matrix.dimensions())); - FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]), - SplitBatch2D(c_matrix_out->dimensions())); - if (batch != batch_c || batch != batch_out) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "a_matrix, c_matrix and c_matrix_out must have the same " - "batch size."); + SplitBatch2D(a.dimensions())); + if (alpha.element_count() != 1 || beta.element_count() != 1) { + return ffi::Error::InvalidArgument( + "The alpha and beta inputs to syrk must be scalars"); } - int n = transpose ? cols : rows; - int k = transpose ? rows : cols; - + auto size = transpose ? cols : rows; FFI_RETURN_IF_ERROR( - CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk")); + CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk")); FFI_RETURN_IF_ERROR( - CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk")); + CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk")); - const T* a_data = static_cast(a_matrix.untyped_data()); - T* c_data = static_cast(c_matrix.untyped_data()); - T* c_out_data = static_cast(c_matrix_out->untyped_data()); + FFI_ASSIGN_OR_RETURN(auto n, + MaybeCastNoOverflow(transpose ? cols : rows)); + FFI_ASSIGN_OR_RETURN(auto k, + MaybeCastNoOverflow(transpose ? rows : cols)); + gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; + gpublasOperation_t trans = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; + + const T* a_data = static_cast(a.untyped_data()); + T* c_data = static_cast(c_in.untyped_data()); + T* c_out_data = static_cast(c_out->untyped_data()); // with alpha or beta provided as device_pointers, cublassyrk will SIGSEGV T host_alpha; - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - &host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, - stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_alpha, alpha.untyped_data(), + sizeof(T), gpuMemcpyDeviceToHost, + stream)); T host_beta; - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - &host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost, - stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_beta, beta.untyped_data(), + sizeof(T), gpuMemcpyDeviceToHost, + stream)); if (c_data != c_out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( - c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice, - stream))); + JAX_FFI_RETURN_IF_GPU_ERROR( + gpuMemcpyAsync(c_out_data, c_data, c_in.size_bytes(), + gpuMemcpyDeviceToDevice, stream)); } FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); for (int i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(SyrkKernel::Run( - handle.get(), n, k, transpose, &host_alpha, &host_beta, - a_data + i * k * n, c_out_data + i * n * n)); + FFI_RETURN_IF_ERROR_STATUS(solver::Syrk(handle.get(), uplo, trans, n, k, + &host_alpha, a_data, &host_beta, + c_out_data)); + a_data += k * n; + c_out_data += n * n; } return ffi::Error::Success(); } -ffi::Error SyrkDispatch( - gpuStream_t stream, - ffi::AnyBuffer a_matrix, - ffi::AnyBuffer c_matrix, - bool transpose, - ffi::AnyBuffer alpha, - ffi::AnyBuffer beta, - ffi::Result c_matrix_out) { - auto dataType = a_matrix.element_type(); - SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose, - alpha, beta, c_matrix_out); - return ffi::Error::InvalidArgument("Unsupported element type for Syrk"); +ffi::Error SyrkDispatch(gpuStream_t stream, bool transpose, ffi::AnyBuffer a, + ffi::AnyBuffer c_in, ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, + ffi::Result c_out) { + auto dataType = a.element_type(); + SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, a, c_in, alpha, beta, + c_out); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType))); } XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, ffi::Ffi::Bind() .Ctx>() - .Arg() // a_matrix - .Arg() // c_matrix .Attr("transpose") // transpose - .Arg() // alpha - .Arg() // beta - .Ret()); // c_matrix_out + .Arg() // a + .Arg() // c_in + .Arg() // alpha + .Arg() // beta + .Ret() // c_out +); #undef SOLVER_DISPATCH_IMPL #undef SOLVER_BLAS_DISPATCH_IMPL diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index ce856ae5f..598741522 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -168,6 +168,21 @@ cc_library( ], ) +cc_library( + name = "hipsolver_interface", + srcs = ["//jaxlib/gpu:solver_interface.cc"], + hdrs = ["//jaxlib/gpu:solver_interface.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:hipblas", + "@local_config_rocm//rocm:hipsolver", + ], +) + cc_library( name = "hipsolver_kernels_ffi", srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], @@ -178,6 +193,7 @@ cc_library( ":hip_make_batch_pointers", ":hip_solver_handle_pool", ":hip_vendor", + ":hipsolver_interface", "//jaxlib:ffi_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor",