Refactor gpusolver kernel definitions into separate build target.

There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the library wrappers in the same build target was getting unwieldy. This change adds a new "interface" target that contains just the shims wrapping the cuSolver/cuBLAS functions; these are then used from `solver_kernels_ffi`, where the FFI logic lives.
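In outline, the new layering looks like this (a condensed sketch with signatures taken from the diff below; getrf shown as a representative routine):

    // solver_interface.h: one dtype-templated shim per wrapped routine.
    template <typename T>
    absl::StatusOr<int> GetrfBufferSize(gpusolverDnHandle_t handle, int m, int n);
    template <typename T>
    absl::Status Getrf(gpusolverDnHandle_t handle, int m, int n, T *a,
                       T *workspace, int lwork, int *ipiv, int *info);

    // solver_kernels_ffi.cc: the FFI kernels call the shims generically
    // instead of expanding their own per-kernel wrapper macros.
    FFI_ASSIGN_OR_RETURN(int lwork,
                         solver::GetrfBufferSize<T>(handle.get(), m, n));
    FFI_RETURN_IF_ERROR_STATUS(solver::Getrf<T>(
        handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data));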

PiperOrigin-RevId: 673832309
Dan Foreman-Mackey authored 2024-09-12 07:11:01 -07:00; committed by jax authors
parent 2067da818e
commit a3bf75e442
6 changed files with 537 additions and 344 deletions

jaxlib/cuda/BUILD.bazel

@@ -227,6 +227,22 @@ cc_library(
],
)
cc_library(
name = "cusolver_interface",
srcs = ["//jaxlib/gpu:solver_interface.cc"],
hdrs = ["//jaxlib/gpu:solver_interface.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
],
)
cc_library(
name = "cusolver_kernels_ffi",
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
@@ -237,6 +253,7 @@ cc_library(
":cuda_make_batch_pointers",
":cuda_solver_handle_pool",
":cuda_vendor",
":cusolver_interface",
"//jaxlib:ffi_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",

jaxlib/gpu/BUILD.bazel

@@ -53,6 +53,8 @@ exports_files(srcs = [
"solver.cc",
"solver_handle_pool.cc",
"solver_handle_pool.h",
"solver_interface.cc",
"solver_interface.h",
"solver_kernels.cc",
"solver_kernels.h",
"solver_kernels_ffi.cc",

jaxlib/gpu/solver_interface.cc

@@ -0,0 +1,237 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/solver_interface.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace solver {
// LU decomposition: getrf
#define JAX_GPU_DEFINE_GETRF(Type, Name) \
template <> \
absl::StatusOr<int> GetrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
int n) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \
return lwork; \
} \
\
template <> \
absl::Status Getrf<Type>(gpusolverDnHandle_t handle, int m, int n, Type *a, \
Type *workspace, int lwork, int *ipiv, int *info) { \
return JAX_AS_STATUS( \
Name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \
}
JAX_GPU_DEFINE_GETRF(float, gpusolverDnSgetrf);
JAX_GPU_DEFINE_GETRF(double, gpusolverDnDgetrf);
JAX_GPU_DEFINE_GETRF(gpuComplex, gpusolverDnCgetrf);
JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf);
#undef JAX_GPU_DEFINE_GETRF
#define JAX_GPU_DEFINE_GETRF_BATCHED(Type, Name) \
template <> \
absl::Status GetrfBatched<Type>(gpublasHandle_t handle, int n, Type **a, \
int lda, int *ipiv, int *info, int batch) { \
return JAX_AS_STATUS(Name(handle, n, a, lda, ipiv, info, batch)); \
}
JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched);
JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched);
JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched);
JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched);
#undef JAX_GPU_DEFINE_GETRF_BATCHED
// QR decomposition: geqrf
#define JAX_GPU_DEFINE_GEQRF(Type, Name) \
template <> \
absl::StatusOr<int> GeqrfBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
int n) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \
return lwork; \
} \
\
template <> \
absl::Status Geqrf<Type>(gpusolverDnHandle_t handle, int m, int n, Type *a, \
Type *tau, Type *workspace, int lwork, int *info) { \
return JAX_AS_STATUS( \
Name(handle, m, n, a, m, tau, workspace, lwork, info)); \
}
JAX_GPU_DEFINE_GEQRF(float, gpusolverDnSgeqrf);
JAX_GPU_DEFINE_GEQRF(double, gpusolverDnDgeqrf);
JAX_GPU_DEFINE_GEQRF(gpuComplex, gpusolverDnCgeqrf);
JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf);
#undef JAX_GPU_DEFINE_GEQRF
#define JAX_GPU_DEFINE_GEQRF_BATCHED(Type, Name) \
template <> \
absl::Status GeqrfBatched<Type>(gpublasHandle_t handle, int m, int n, \
Type **a, Type **tau, int *info, \
int batch) { \
return JAX_AS_STATUS(Name(handle, m, n, a, m, tau, info, batch)); \
}
JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched);
JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched);
JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched);
JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched);
#undef JAX_GPU_DEFINE_GEQRF_BATCHED
// Householder transformations: orgqr
#define JAX_GPU_DEFINE_ORGQR(Type, Name) \
template <> \
absl::StatusOr<int> OrgqrBufferSize<Type>(gpusolverDnHandle_t handle, int m, \
int n, int k) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \
handle, m, n, k, /*A=*/nullptr, /*lda=*/m, /*tau=*/nullptr, &lwork))); \
return lwork; \
} \
\
template <> \
absl::Status Orgqr<Type>(gpusolverDnHandle_t handle, int m, int n, int k, \
Type *a, Type *tau, Type *workspace, int lwork, \
int *info) { \
return JAX_AS_STATUS( \
Name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \
}
JAX_GPU_DEFINE_ORGQR(float, gpusolverDnSorgqr);
JAX_GPU_DEFINE_ORGQR(double, gpusolverDnDorgqr);
JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr);
JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr);
#undef JAX_GPU_DEFINE_ORGQR
// Symmetric (Hermitian) eigendecomposition:
// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
// * QR algorithm: syevd/heevd
#define JAX_GPU_DEFINE_SYEVJ(Type, Name) \
template <> \
absl::StatusOr<int> SyevjBufferSize<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
/*w=*/nullptr, &lwork, params))); \
return lwork; \
} \
\
template <> \
absl::Status Syevj<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w, \
Type *workspace, int lwork, int *info, gpuSyevjInfo_t params) { \
return JAX_AS_STATUS( \
Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info, params)); \
}
JAX_GPU_DEFINE_SYEVJ(float, gpusolverDnSsyevj);
JAX_GPU_DEFINE_SYEVJ(double, gpusolverDnDsyevj);
JAX_GPU_DEFINE_SYEVJ(gpuComplex, gpusolverDnCheevj);
JAX_GPU_DEFINE_SYEVJ(gpuDoubleComplex, gpusolverDnZheevj);
#undef JAX_GPU_DEFINE_SYEVJ
#define JAX_GPU_DEFINE_SYEVJ_BATCHED(Type, Name) \
template <> \
absl::StatusOr<int> SyevjBatchedBufferSize<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
/*w=*/nullptr, &lwork, params, batch))); \
return lwork; \
} \
\
template <> \
absl::Status SyevjBatched<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, Type *a, RealType<Type>::value *w, \
Type *workspace, int lwork, int *info, gpuSyevjInfo_t params, \
int batch) { \
return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, n, w, workspace, \
lwork, info, params, batch)); \
}
JAX_GPU_DEFINE_SYEVJ_BATCHED(float, gpusolverDnSsyevjBatched);
JAX_GPU_DEFINE_SYEVJ_BATCHED(double, gpusolverDnDsyevjBatched);
JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuComplex, gpusolverDnCheevjBatched);
JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuDoubleComplex, gpusolverDnZheevjBatched);
#undef JAX_GPU_DEFINE_SYEVJ_BATCHED
#define JAX_GPU_DEFINE_SYEVD(Type, Name) \
template <> \
absl::StatusOr<int> SyevdBufferSize<Type>(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n) { \
int lwork; \
JAX_RETURN_IF_ERROR( \
JAX_AS_STATUS(Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, \
/*lda=*/n, /*w=*/nullptr, &lwork))); \
return lwork; \
} \
\
template <> \
absl::Status Syevd<Type>(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
int n, Type *a, RealType<Type>::value *w, \
Type *workspace, int lwork, int *info) { \
return JAX_AS_STATUS( \
Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \
}
JAX_GPU_DEFINE_SYEVD(float, gpusolverDnSsyevd);
JAX_GPU_DEFINE_SYEVD(double, gpusolverDnDsyevd);
JAX_GPU_DEFINE_SYEVD(gpuComplex, gpusolverDnCheevd);
JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd);
#undef JAX_GPU_DEFINE_SYEVD
// Symmetric rank-k update: syrk
#define JAX_GPU_DEFINE_SYRK(Type, Name) \
template <> \
absl::Status Syrk<Type>(gpublasHandle_t handle, gpublasFillMode_t uplo, \
gpublasOperation_t trans, int n, int k, \
const Type *alpha, const Type *a, const Type *beta, \
Type *c) { \
int lda = trans == GPUBLAS_OP_N ? n : k; \
return JAX_AS_STATUS( \
Name(handle, uplo, trans, n, k, alpha, a, lda, beta, c, n)); \
}
JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk);
JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk);
JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk);
JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk);
#undef JAX_GPU_DEFINE_SYRK
} // namespace solver
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
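For concreteness, JAX_GPU_DEFINE_GETRF above expands — for the float instantiation, and modulo formatting — to:

    template <>
    absl::StatusOr<int> GetrfBufferSize<float>(gpusolverDnHandle_t handle,
                                               int m, int n) {
      int lwork;
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
          gpusolverDnSgetrf_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork)));
      return lwork;
    }

    template <>
    absl::Status Getrf<float>(gpusolverDnHandle_t handle, int m, int n,
                              float *a, float *workspace, int lwork, int *ipiv,
                              int *info) {
      return JAX_AS_STATUS(
          gpusolverDnSgetrf(handle, m, n, a, m, workspace, lwork, ipiv, info));
    }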

jaxlib/gpu/solver_interface.h

@@ -0,0 +1,174 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines a standard interface to the GPU linear algebra libraries.
#ifndef JAXLIB_GPU_SOLVER_INTERFACE_H_
#define JAXLIB_GPU_SOLVER_INTERFACE_H_
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/vendor.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace solver {
template <typename T>
struct RealType {
using value = T;
};
template <>
struct RealType<gpuComplex> {
using value = float;
};
template <>
struct RealType<gpuDoubleComplex> {
using value = double;
};
#define JAX_GPU_SOLVER_EXPAND_DEFINITION(ReturnType, FunctionName) \
template <typename T> \
ReturnType FunctionName( \
JAX_GPU_SOLVER_##FunctionName##_ARGS(T, typename RealType<T>::value)) { \
return absl::UnimplementedError(absl::StrFormat( \
#FunctionName " not implemented for type %s", typeid(T).name())); \
} \
template <> \
ReturnType FunctionName<float>( \
JAX_GPU_SOLVER_##FunctionName##_ARGS(float, float)); \
template <> \
ReturnType FunctionName<double>( \
JAX_GPU_SOLVER_##FunctionName##_ARGS(double, double)); \
template <> \
ReturnType FunctionName<gpuComplex>( \
JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuComplex, float)); \
template <> \
ReturnType FunctionName<gpuDoubleComplex>( \
JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuDoubleComplex, double))
// LU decomposition: getrf
#define JAX_GPU_SOLVER_GetrfBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, int m, int n
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GetrfBufferSize);
#undef JAX_GPU_SOLVER_GetrfBufferSize_ARGS
#define JAX_GPU_SOLVER_Getrf_ARGS(Type, ...) \
gpusolverDnHandle_t handle, int m, int n, Type *a, Type *workspace, \
int lwork, int *ipiv, int *info
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Getrf);
#undef JAX_GPU_SOLVER_Getrf_ARGS
#define JAX_GPU_SOLVER_GetrfBatched_ARGS(Type, ...) \
gpublasHandle_t handle, int n, Type **a, int lda, int *ipiv, int *info, \
int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GetrfBatched);
#undef JAX_GPU_SOLVER_GetrfBatched_ARGS
// QR decomposition: geqrf
#define JAX_GPU_SOLVER_GeqrfBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, int m, int n
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GeqrfBufferSize);
#undef JAX_GPU_SOLVER_GeqrfBufferSize_ARGS
#define JAX_GPU_SOLVER_Geqrf_ARGS(Type, ...) \
gpusolverDnHandle_t handle, int m, int n, Type *a, Type *tau, \
Type *workspace, int lwork, int *info
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Geqrf);
#undef JAX_GPU_SOLVER_Geqrf_ARGS
#define JAX_GPU_SOLVER_GeqrfBatched_ARGS(Type, ...) \
gpublasHandle_t handle, int m, int n, Type **a, Type **tau, int *info, \
int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GeqrfBatched);
#undef JAX_GPU_SOLVER_GeqrfBatched_ARGS
// Householder transformations: orgqr
#define JAX_GPU_SOLVER_OrgqrBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, int m, int n, int k
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, OrgqrBufferSize);
#undef JAX_GPU_SOLVER_OrgqrBufferSize_ARGS
#define JAX_GPU_SOLVER_Orgqr_ARGS(Type, ...) \
gpusolverDnHandle_t handle, int m, int n, int k, Type *a, Type *tau, \
Type *workspace, int lwork, int *info
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr);
#undef JAX_GPU_SOLVER_Orgqr_ARGS
// Symmetric (Hermitian) eigendecomposition:
// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32)
// * QR algorithm: syevd/heevd
#define JAX_GPU_SOLVER_SyevjBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevjBufferSize);
#undef JAX_GPU_SOLVER_SyevjBufferSize_ARGS
#define JAX_GPU_SOLVER_Syevj_ARGS(Type, Real) \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
int lwork, int *info, gpuSyevjInfo_t params
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevj);
#undef JAX_GPU_SOLVER_Syevj_ARGS
#define JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevjBatchedBufferSize);
#undef JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS
#define JAX_GPU_SOLVER_SyevjBatched_ARGS(Type, Real) \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
int lwork, int *info, gpuSyevjInfo_t params, int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, SyevjBatched);
#undef JAX_GPU_SOLVER_SyevjBatched_ARGS
#define JAX_GPU_SOLVER_SyevdBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, SyevdBufferSize);
#undef JAX_GPU_SOLVER_SyevdBufferSize_ARGS
#define JAX_GPU_SOLVER_Syevd_ARGS(Type, Real) \
gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \
int lwork, int *info
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd);
#undef JAX_GPU_SOLVER_Syevd_ARGS
// Symmetric rank-k update: syrk
#define JAX_GPU_SOLVER_Syrk_ARGS(Type, ...) \
gpublasHandle_t handle, gpublasFillMode_t uplo, gpublasOperation_t trans, \
int n, int k, const Type *alpha, const Type *a, const Type *beta, \
Type *c
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk);
#undef JAX_GPU_SOLVER_Syrk_ARGS
#undef JAX_GPU_SOLVER_EXPAND_DEFINITION
} // namespace solver
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_GPU_SOLVER_INTERFACE_H_
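To unpack the macro machinery: for GetrfBufferSize, the JAX_GPU_SOLVER_EXPAND_DEFINITION invocation above expands (approximately) to a generic fallback plus four explicit specializations, which are declared here and defined in solver_interface.cc:

    template <typename T>
    absl::StatusOr<int> GetrfBufferSize(gpusolverDnHandle_t handle, int m,
                                        int n) {
      // Fallback for element types without a specialization below.
      return absl::UnimplementedError(absl::StrFormat(
          "GetrfBufferSize not implemented for type %s", typeid(T).name()));
    }
    template <>
    absl::StatusOr<int> GetrfBufferSize<float>(gpusolverDnHandle_t handle,
                                               int m, int n);
    template <>
    absl::StatusOr<int> GetrfBufferSize<double>(gpusolverDnHandle_t handle,
                                                int m, int n);
    template <>
    absl::StatusOr<int> GetrfBufferSize<gpuComplex>(gpusolverDnHandle_t handle,
                                                    int m, int n);
    template <>
    absl::StatusOr<int> GetrfBufferSize<gpuDoubleComplex>(
        gpusolverDnHandle_t handle, int m, int n);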

jaxlib/gpu/solver_kernels_ffi.cc

@@ -29,9 +29,13 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/make_batch_pointers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/solver_interface.h"
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::JAX_GPU_NAMESPACE::SyevdAlgorithm);
namespace jax {
@@ -39,7 +43,6 @@ namespace JAX_GPU_NAMESPACE {
namespace ffi = ::xla::ffi;
namespace {
template <typename T>
inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
int64_t size,
@@ -53,22 +56,6 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
return static_cast<T*>(maybe_workspace.value());
}
template <typename T>
struct RealType {
using Type = T;
};
template <>
struct RealType<gpuComplex> {
using Type = float;
};
template <>
struct RealType<gpuDoubleComplex> {
using Type = double;
};
} // namespace
#define SOLVER_DISPATCH_IMPL(impl, ...) \
if (dataType == ffi::F32) { \
return impl<float>(__VA_ARGS__); \
@@ -93,33 +80,6 @@ struct RealType<gpuDoubleComplex> {
// LU decomposition: getrf
namespace {
#define GETRF_KERNEL_IMPL(type, name) \
template <> \
struct GetrfKernel<type> { \
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
int n) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
return lwork; \
} \
static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \
type* workspace, int lwork, int* ipiv, \
int* info) { \
return JAX_AS_STATUS( \
name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \
} \
}
template <typename T>
struct GetrfKernel;
GETRF_KERNEL_IMPL(float, gpusolverDnSgetrf);
GETRF_KERNEL_IMPL(double, gpusolverDnDgetrf);
GETRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgetrf);
GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf);
#undef GETRF_KERNEL_IMPL
template <typename T>
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -131,7 +91,7 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(int lwork,
GetrfKernel<T>::BufferSize(handle.get(), m, n));
solver::GetrfBufferSize<T>(handle.get(), m, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "getrf"));
@@ -140,13 +100,13 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
int ipiv_step = std::min(m, n);
for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(GetrfKernel<T>::Run(
FFI_RETURN_IF_ERROR_STATUS(solver::Getrf<T>(
handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data));
out_data += m * n;
ipiv_data += ipiv_step;
@@ -155,23 +115,6 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
return ffi::Error::Success();
}
#define GETRF_BATCHED_KERNEL_IMPL(type, name) \
template <> \
struct GetrfBatchedKernel<type> { \
static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \
int* ipiv, int* info, int batch) { \
return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \
} \
}
template <typename T>
struct GetrfBatchedKernel;
GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched);
GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched);
#undef GETRF_BATCHED_KERNEL_IMPL
template <typename T>
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
@@ -188,15 +131,15 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
sizeof(T) * n * n);
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel<T>::Run(
FFI_RETURN_IF_ERROR_STATUS(solver::GetrfBatched(
handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch));
return ffi::Error::Success();
@@ -228,7 +171,6 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in getrf", absl::FormatStreamed(dataType)));
}
} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
ffi::Ffi::Bind()
@@ -242,33 +184,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
// QR decomposition: geqrf
namespace {
#define GEQRF_KERNEL_IMPL(type, name) \
template <> \
struct GeqrfKernel<type> { \
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
int n) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \
return lwork; \
} \
static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \
type* tau, type* workspace, int lwork, \
int* info) { \
return JAX_AS_STATUS( \
name(handle, m, n, a, m, tau, workspace, lwork, info)); \
} \
}
template <typename T>
struct GeqrfKernel;
GEQRF_KERNEL_IMPL(float, gpusolverDnSgeqrf);
GEQRF_KERNEL_IMPL(double, gpusolverDnDgeqrf);
GEQRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgeqrf);
GEQRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgeqrf);
#undef GEQRF_KERNEL_IMPL
template <typename T>
ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -279,7 +194,7 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(int lwork,
GeqrfKernel<T>::BufferSize(handle.get(), m, n));
solver::GeqrfBufferSize<T>(handle.get(), m, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "geqrf"));
@@ -292,14 +207,14 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
auto out_data = static_cast<T*>(out->untyped_data());
auto tau_data = static_cast<T*>(tau->untyped_data());
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
int out_step = m * n;
int tau_step = std::min(m, n);
for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel<T>::Run(
FFI_RETURN_IF_ERROR_STATUS(solver::Geqrf<T>(
handle.get(), m, n, out_data, tau_data, workspace, lwork, info));
out_data += out_step;
tau_data += tau_step;
@@ -307,23 +222,6 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
return ffi::Error::Success();
}
#define GEQRF_BATCHED_KERNEL_IMPL(type, name) \
template <> \
struct GeqrfBatchedKernel<type> { \
static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \
type** tau, int* info, int batch) { \
return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \
} \
}
template <typename T>
struct GeqrfBatchedKernel;
GEQRF_BATCHED_KERNEL_IMPL(float, gpublasSgeqrfBatched);
GEQRF_BATCHED_KERNEL_IMPL(double, gpublasDgeqrfBatched);
GEQRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgeqrfBatched);
GEQRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgeqrfBatched);
#undef GEQRF_BATCHED_KERNEL_IMPL
template <typename T>
ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -341,21 +239,21 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
auto out_data = out->untyped_data();
auto tau_data = tau->untyped_data();
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch,
sizeof(T) * m * n);
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch,
sizeof(T) * std::min(m, n));
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError()));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError());
// We ignore the output value of `info` because it is only used for shape
// checking.
int info;
FFI_RETURN_IF_ERROR_STATUS(GeqrfBatchedKernel<T>::Run(
FFI_RETURN_IF_ERROR_STATUS(solver::GeqrfBatched<T>(
handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch));
return ffi::Error::Success();
@@ -385,7 +283,6 @@ ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in geqrf", absl::FormatStreamed(dataType)));
}
} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
ffi::Ffi::Bind()
@@ -398,34 +295,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
// Householder transformations: orgqr
namespace {
#define ORGQR_KERNEL_IMPL(type, name) \
template <> \
struct OrgqrKernel<type> { \
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
int n, int k) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \
/*tau=*/nullptr, &lwork))); \
return lwork; \
} \
static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \
type* a, type* tau, type* workspace, int lwork, \
int* info) { \
return JAX_AS_STATUS( \
name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \
} \
}
template <typename T>
struct OrgqrKernel;
ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr);
ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr);
ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr);
ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr);
#undef ORGQR_KERNEL_IMPL
template <typename T>
ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
@@ -437,7 +306,7 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(int lwork,
OrgqrKernel<T>::BufferSize(handle.get(), m, n, k));
solver::OrgqrBufferSize<T>(handle.get(), m, n, k));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "orgqr"));
@@ -450,13 +319,13 @@ ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
auto tau_data = static_cast<T*>(tau.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
int out_step = m * n;
for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel<T>::Run(
FFI_RETURN_IF_ERROR_STATUS(solver::Orgqr<T>(
handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info));
out_data += out_step;
tau_data += k;
@@ -492,7 +361,6 @@ ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in orgqr", absl::FormatStreamed(dataType)));
}
} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
ffi::Ffi::Bind()
@@ -510,98 +378,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
// dispatches dynamically to both syevd and syevj depending on the problem
// size and the algorithm selected by the user via the `algorithm` attribute.
namespace {
#define SYEVJ_KERNEL_IMPL(type, name) \
template <> \
struct SyevjKernel<type> { \
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, \
gpuSyevjInfo_t params) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
/*w=*/nullptr, &lwork, params))); \
return lwork; \
} \
static absl::Status Run(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
int n, type* a, RealType<type>::Type* w, \
type* workspace, int lwork, int* info, \
gpuSyevjInfo_t params) { \
return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \
lwork, info, params)); \
} \
}
template <typename T>
struct SyevjKernel;
SYEVJ_KERNEL_IMPL(float, gpusolverDnSsyevj);
SYEVJ_KERNEL_IMPL(double, gpusolverDnDsyevj);
SYEVJ_KERNEL_IMPL(gpuComplex, gpusolverDnCheevj);
SYEVJ_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevj);
#undef SYEVJ_KERNEL_IMPL
#define SYEVJ_BATCHED_KERNEL_IMPL(type, name) \
template <> \
struct SyevjBatchedKernel<type> { \
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n, \
gpuSyevjInfo_t params, int batch) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
/*w=*/nullptr, &lwork, params, batch))); \
return lwork; \
} \
static absl::Status Run(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
int n, type* a, RealType<type>::Type* w, \
type* workspace, int lwork, int* info, \
gpuSyevjInfo_t params, int batch) { \
return JAX_AS_STATUS(name(handle, jobz, uplo, n, a, n, w, workspace, \
lwork, info, params, batch)); \
} \
}
template <typename T>
struct SyevjBatchedKernel;
SYEVJ_BATCHED_KERNEL_IMPL(float, gpusolverDnSsyevjBatched);
SYEVJ_BATCHED_KERNEL_IMPL(double, gpusolverDnDsyevjBatched);
SYEVJ_BATCHED_KERNEL_IMPL(gpuComplex, gpusolverDnCheevjBatched);
SYEVJ_BATCHED_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevjBatched);
#undef SYEVJ_BATCHED_KERNEL_IMPL
#define SYEVD_KERNEL_IMPL(type, name) \
template <> \
struct SyevdKernel<type> { \
static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, \
gpusolverFillMode_t uplo, int n) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \
/*w=*/nullptr, &lwork))); \
return lwork; \
} \
static absl::Status Run(gpusolverDnHandle_t handle, \
gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \
int n, type* a, RealType<type>::Type* w, \
type* workspace, int lwork, int* info) { \
return JAX_AS_STATUS( \
name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \
} \
}
template <typename T>
struct SyevdKernel;
SYEVD_KERNEL_IMPL(float, gpusolverDnSsyevd);
SYEVD_KERNEL_IMPL(double, gpusolverDnDsyevd);
SYEVD_KERNEL_IMPL(gpuComplex, gpusolverDnCheevd);
SYEVD_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZheevd);
#undef SYEVD_KERNEL_IMPL
template <typename T>
ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm,
@@ -618,49 +394,48 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
auto a_data = static_cast<T*>(a.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
auto w_data = static_cast<RealType<T>::Type*>(w->untyped_data());
auto w_data = static_cast<solver::RealType<T>::value*>(w->untyped_data());
auto info_data = info->typed_data();
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
if (algorithm == SyevdAlgorithm::kJacobi ||
(algorithm == SyevdAlgorithm::kDefault && size <= 32)) {
gpuSyevjInfo_t params;
FFI_RETURN_IF_ERROR_STATUS(
JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(&params)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(&params));
std::unique_ptr<gpuSyevjInfo, void (*)(gpuSyevjInfo_t)> params_cleanup(
params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); });
if (batch == 1) {
FFI_ASSIGN_OR_RETURN(int lwork, SyevjKernel<T>::BufferSize(
FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize<T>(
handle.get(), jobz, uplo, n, params));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevj"));
FFI_RETURN_IF_ERROR_STATUS(
SyevjKernel<T>::Run(handle.get(), jobz, uplo, n, out_data, w_data,
workspace, lwork, info_data, params));
FFI_RETURN_IF_ERROR_STATUS(solver::Syevj<T>(handle.get(), jobz, uplo, n,
out_data, w_data, workspace,
lwork, info_data, params));
} else {
FFI_ASSIGN_OR_RETURN(
int lwork, SyevjBatchedKernel<T>::BufferSize(handle.get(), jobz, uplo,
int lwork, solver::SyevjBatchedBufferSize<T>(handle.get(), jobz, uplo,
n, params, batch));
FFI_ASSIGN_OR_RETURN(
auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevj_batched"));
FFI_RETURN_IF_ERROR_STATUS(SyevjBatchedKernel<T>::Run(
handle.get(), jobz, uplo, n, out_data, w_data, workspace, lwork,
info_data, params, batch));
FFI_RETURN_IF_ERROR_STATUS(
solver::SyevjBatched<T>(handle.get(), jobz, uplo, n, out_data, w_data,
workspace, lwork, info_data, params, batch));
}
} else {
FFI_ASSIGN_OR_RETURN(
int lwork, SyevdKernel<T>::BufferSize(handle.get(), jobz, uplo, n));
int lwork, solver::SyevdBufferSize<T>(handle.get(), jobz, uplo, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "syevd"));
int out_step = n * n;
for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(
SyevdKernel<T>::Run(handle.get(), jobz, uplo, n, out_data, w_data,
workspace, lwork, info_data));
FFI_RETURN_IF_ERROR_STATUS(solver::Syevd<T>(handle.get(), jobz, uplo, n,
out_data, w_data, workspace,
lwork, info_data));
out_data += out_step;
w_data += n;
++info_data;
@@ -695,7 +470,6 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in syevd", absl::FormatStreamed(dataType)));
}
} // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch,
ffi::Ffi::Bind()
@@ -709,110 +483,83 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch,
.Ret<ffi::Buffer<ffi::S32>>() // info
);
#define SYRK_KERNEL_IMPL(type, fn) \
template <> \
struct SyrkKernel<type> { \
static absl::Status Run(gpublasHandle_t handle, std::int64_t n, \
std::int64_t k, bool transpose, \
const type* alpha, const type* beta, \
const type* a_matrix, type* c_matrix) { \
gpublasOperation_t op = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; \
gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; \
int lda = transpose ? n : k; \
return JAX_AS_STATUS(fn(handle, uplo, op, n, k, \
alpha, a_matrix, lda, beta, \
c_matrix, n)); \
} \
}
// Symmetric rank-k update: syrk
template <typename T>
struct SyrkKernel;
SYRK_KERNEL_IMPL(float, gpublasSsyrk);
SYRK_KERNEL_IMPL(double, gpublasDsyrk);
SYRK_KERNEL_IMPL(gpublasComplex, gpublasCsyrk);
SYRK_KERNEL_IMPL(gpublasDoubleComplex, gpublasZsyrk);
#undef SYRK_KERNEL_IMPL
template <typename T>
ffi::Error SyrkImpl(gpuStream_t stream,
ffi::AnyBuffer a_matrix,
ffi::AnyBuffer c_matrix,
bool transpose,
ffi::AnyBuffer alpha,
ffi::AnyBuffer beta,
ffi::Result<ffi::AnyBuffer> c_matrix_out) {
ffi::Error SyrkImpl(gpuStream_t stream, bool transpose, ffi::AnyBuffer a,
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha,
ffi::AnyBuffer beta, ffi::Result<ffi::AnyBuffer> c_out) {
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(a_matrix.dimensions()));
FFI_ASSIGN_OR_RETURN((auto [batch_c, rows_c, cols_c]),
SplitBatch2D(c_matrix.dimensions()));
FFI_ASSIGN_OR_RETURN((auto [batch_out, rows_out, cols_out]),
SplitBatch2D(c_matrix_out->dimensions()));
if (batch != batch_c || batch != batch_out) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"a_matrix, c_matrix and c_matrix_out must have the same "
"batch size.");
SplitBatch2D(a.dimensions()));
if (alpha.element_count() != 1 || beta.element_count() != 1) {
return ffi::Error::InvalidArgument(
"The alpha and beta inputs to syrk must be scalars");
}
int n = transpose ? cols : rows;
int k = transpose ? rows : cols;
auto size = transpose ? cols : rows;
FFI_RETURN_IF_ERROR(
CheckShape(c_matrix_out->dimensions().last(2), {n, n}, "out", "Syrk"));
CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk"));
FFI_RETURN_IF_ERROR(
CheckShape(c_matrix.dimensions().last(2), {n, n}, "C", "Syrk"));
CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk"));
const T* a_data = static_cast<const T*>(a_matrix.untyped_data());
T* c_data = static_cast<T*>(c_matrix.untyped_data());
T* c_out_data = static_cast<T*>(c_matrix_out->untyped_data());
FFI_ASSIGN_OR_RETURN(auto n,
MaybeCastNoOverflow<int>(transpose ? cols : rows));
FFI_ASSIGN_OR_RETURN(auto k,
MaybeCastNoOverflow<int>(transpose ? rows : cols));
gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER;
gpublasOperation_t trans = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T;
const T* a_data = static_cast<const T*>(a.untyped_data());
T* c_data = static_cast<T*>(c_in.untyped_data());
T* c_out_data = static_cast<T*>(c_out->untyped_data());
// If alpha or beta are provided as device pointers, cublas<T>syrk will
// SIGSEGV, so copy them to the host first.
T host_alpha;
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
&host_alpha, alpha.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_alpha, alpha.untyped_data(),
sizeof(T), gpuMemcpyDeviceToHost,
stream));
T host_beta;
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
&host_beta, beta.untyped_data(), sizeof(T), gpuMemcpyDeviceToHost,
stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_beta, beta.untyped_data(),
sizeof(T), gpuMemcpyDeviceToHost,
stream));
if (c_data != c_out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
c_out_data, c_data, c_matrix.size_bytes(), gpuMemcpyDeviceToDevice,
stream)));
JAX_FFI_RETURN_IF_GPU_ERROR(
gpuMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
gpuMemcpyDeviceToDevice, stream));
}
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
for (int i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(SyrkKernel<T>::Run(
handle.get(), n, k, transpose, &host_alpha, &host_beta,
a_data + i * k * n, c_out_data + i * n * n));
FFI_RETURN_IF_ERROR_STATUS(solver::Syrk<T>(handle.get(), uplo, trans, n, k,
&host_alpha, a_data, &host_beta,
c_out_data));
a_data += k * n;
c_out_data += n * n;
}
return ffi::Error::Success();
}
ffi::Error SyrkDispatch(
gpuStream_t stream,
ffi::AnyBuffer a_matrix,
ffi::AnyBuffer c_matrix,
bool transpose,
ffi::AnyBuffer alpha,
ffi::AnyBuffer beta,
ffi::Result<ffi::AnyBuffer> c_matrix_out) {
auto dataType = a_matrix.element_type();
SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, a_matrix, c_matrix, transpose,
alpha, beta, c_matrix_out);
return ffi::Error::InvalidArgument("Unsupported element type for Syrk");
ffi::Error SyrkDispatch(gpuStream_t stream, bool transpose, ffi::AnyBuffer a,
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha,
ffi::AnyBuffer beta,
ffi::Result<ffi::AnyBuffer> c_out) {
auto dataType = a.element_type();
SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, a, c_in, alpha, beta,
c_out);
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in syrk", absl::FormatStreamed(dataType)));
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Arg<ffi::AnyBuffer>() // a_matrix
.Arg<ffi::AnyBuffer>() // c_matrix
.Attr<bool>("transpose") // transpose
.Arg<ffi::AnyBuffer>() // alpha
.Arg<ffi::AnyBuffer>() // beta
.Ret<ffi::AnyBuffer>()); // c_matrix_out
.Arg<ffi::AnyBuffer>() // a
.Arg<ffi::AnyBuffer>() // c_in
.Arg<ffi::AnyBuffer>() // alpha
.Arg<ffi::AnyBuffer>() // beta
.Ret<ffi::AnyBuffer>() // c_out
);
#undef SOLVER_DISPATCH_IMPL
#undef SOLVER_BLAS_DISPATCH_IMPL

jaxlib/rocm/BUILD.bazel

@@ -168,6 +168,21 @@ cc_library(
],
)
cc_library(
name = "hipsolver_interface",
srcs = ["//jaxlib/gpu:solver_interface.cc"],
hdrs = ["//jaxlib/gpu:solver_interface.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_vendor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:hipsolver",
],
)
cc_library(
name = "hipsolver_kernels_ffi",
srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"],
@@ -178,6 +193,7 @@ cc_library(
":hip_make_batch_pointers",
":hip_solver_handle_pool",
":hip_vendor",
":hipsolver_interface",
"//jaxlib:ffi_helpers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",