Port GPU kernels for SVD to the FFI.

Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel, and while it does support lowering without access to a GPU (no more descriptor!), it only supports dynamic shapes in the batch dimensions. There are two main technical challenges:

1. The main `gesvd` kernels in cuSolver/hipSolver only support matrices with shape `(m, n)` where `m >= n`. This means that we need to transpose the inputs and outputs as part of the lowering rule when `m < n`. (Note: we actually just use C layouts instead of Fortran layouts to implement this case.) While this could be handled in the kernel, that seemed like a lot of work for somewhat limited benefit, and it would probably have performance implications. (See the first sketch after this list.)

2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that we need logic in the lowering rule to handle transposition and slicing, which makes it hard to expose the algorithm selection as a parameter to the kernel. (See the second sketch after this list.)
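To make these concrete, here are two minimal NumPy sketches of the lowering-rule logic (hypothetical helper names, not the actual implementation). The first handles `m < n` by factoring the conjugate transpose: if `A^H = U S V^H`, then `A = V S U^H`, so we can swap the roles of `U` and `V^H`:

    import numpy as np

    def svd_via_transpose(a, full_matrices=True):
        # Sketch only: gesvd requires m >= n, so factor a^H when m < n and
        # swap (and conjugate-transpose) the U and V^H factors of the result.
        m, n = a.shape
        if m >= n:
            return np.linalg.svd(a, full_matrices=full_matrices)
        u, s, vh = np.linalg.svd(a.conj().T, full_matrices=full_matrices)
        return vh.conj().T, s, u.conj().T

The second shows the kind of post-processing the `gesvdj` path needs: the kernel returns `V` rather than `V^H`, and the batched kernel always computes full matrices, so the lowering rule must transpose and, when `full_matrices=False`, slice:

    def postprocess_gesvdj(u, s, v, full_matrices=True):
        # Sketch only: convert gesvdj's V into the V^H that callers expect,
        # then slice down to economy-size factors when requested.
        vh = np.swapaxes(v.conj(), -1, -2)
        if not full_matrices:
            k = s.shape[-1]
            u = u[..., :, :k]
            vh = vh[..., :k, :]
        return u, s, vh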

Another note: cuSolver has a 64-bit implementation of the SVD (`cusolverDnXgesvd`), and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCm support, and I have tested it manually. This was a feature request from https://github.com/jax-ml/jax/issues/23413.

PiperOrigin-RevId: 676839182
Dan Foreman-Mackey, 2024-09-20 07:34:05 -07:00 (committed by jax authors)
parent 7f3a90c63b
commit afaa3bf43c
6 changed files with 551 additions and 21 deletions


@@ -20,8 +20,8 @@ limitations under the License.
#include "nanobind/nanobind.h"
#include "nanobind/stl/pair.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/solver_handle_pool.h"
#include "jaxlib/gpu/solver_kernels.h"
@@ -481,6 +481,11 @@ nb::dict Registrations() {
  dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);
  dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi);
  dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi);
  dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi);
#ifdef JAX_GPU_CUDA
  dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi);
#endif  // JAX_GPU_CUDA
  return dict;
}


@@ -232,6 +232,91 @@ JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk);
JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk);
#undef JAX_GPU_DEFINE_SYRK
// Singular Value Decomposition: gesvd
#define JAX_GPU_DEFINE_GESVD(Type, Name) \
template <> \
absl::StatusOr<int> GesvdBufferSize<Type>(gpusolverDnHandle_t handle, \
signed char job, int m, int n) { \
int lwork; \
JAX_RETURN_IF_ERROR( \
JAX_AS_STATUS(Name##_bufferSize(handle, job, job, m, n, &lwork))); \
return lwork; \
} \
\
template <> \
absl::Status Gesvd<Type>(gpusolverDnHandle_t handle, signed char job, int m, \
int n, Type *a, RealType<Type>::value *s, Type *u, \
Type *vt, Type *workspace, int lwork, int *info) { \
return JAX_AS_STATUS(Name(handle, job, job, m, n, a, m, s, u, m, vt, n, \
workspace, lwork, /*rwork=*/nullptr, info)); \
}
JAX_GPU_DEFINE_GESVD(float, gpusolverDnSgesvd);
JAX_GPU_DEFINE_GESVD(double, gpusolverDnDgesvd);
JAX_GPU_DEFINE_GESVD(gpuComplex, gpusolverDnCgesvd);
JAX_GPU_DEFINE_GESVD(gpuDoubleComplex, gpusolverDnZgesvd);
#undef JAX_GPU_DEFINE_GESVD
#ifdef JAX_GPU_CUDA
#define JAX_GPU_DEFINE_GESVDJ(Type, Name) \
template <> \
absl::StatusOr<int> GesvdjBufferSize<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \
int n, gpuGesvdjInfo_t params) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \
handle, job, econ, m, n, /*a=*/nullptr, /*lda=*/m, /*s=*/nullptr, \
/*u=*/nullptr, /*ldu=*/m, /*v=*/nullptr, /*ldv=*/n, &lwork, params))); \
return lwork; \
} \
\
template <> \
absl::Status Gesvdj<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \
int n, Type *a, RealType<Type>::value *s, Type *u, Type *v, \
Type *workspace, int lwork, int *info, gpuGesvdjInfo_t params) { \
return JAX_AS_STATUS(Name(handle, job, econ, m, n, a, m, s, u, m, v, n, \
workspace, lwork, info, params)); \
}
JAX_GPU_DEFINE_GESVDJ(float, gpusolverDnSgesvdj);
JAX_GPU_DEFINE_GESVDJ(double, gpusolverDnDgesvdj);
JAX_GPU_DEFINE_GESVDJ(gpuComplex, gpusolverDnCgesvdj);
JAX_GPU_DEFINE_GESVDJ(gpuDoubleComplex, gpusolverDnZgesvdj);
#undef JAX_GPU_DEFINE_GESVDJ
#define JAX_GPU_DEFINE_GESVDJ_BATCHED(Type, Name) \
template <> \
absl::StatusOr<int> GesvdjBatchedBufferSize<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \
gpuGesvdjInfo_t params, int batch) { \
int lwork; \
JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \
Name##_bufferSize(handle, job, m, n, /*a=*/nullptr, /*lda=*/m, \
/*s=*/nullptr, /*u=*/nullptr, /*ldu=*/m, \
/*v=*/nullptr, /*ldv=*/n, &lwork, params, batch))); \
return lwork; \
} \
\
template <> \
absl::Status GesvdjBatched<Type>( \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \
Type *a, RealType<Type>::value *s, Type *u, Type *v, Type *workspace, \
int lwork, int *info, gpuGesvdjInfo_t params, int batch) { \
return JAX_AS_STATUS(Name(handle, job, m, n, a, m, s, u, m, v, n, \
workspace, lwork, info, params, batch)); \
}
JAX_GPU_DEFINE_GESVDJ_BATCHED(float, gpusolverDnSgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(double, gpusolverDnDgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched);
#undef JAX_GPU_DEFINE_GESVDJ_BATCHED
#endif // JAX_GPU_CUDA
} // namespace solver
} // namespace JAX_GPU_NAMESPACE
} // namespace jax


@@ -165,6 +165,49 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd);
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk);
#undef JAX_GPU_SOLVER_Syrk_ARGS
// Singular Value Decomposition: gesvd
#define JAX_GPU_SOLVER_GesvdBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, signed char job, int m, int n
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdBufferSize);
#undef JAX_GPU_SOLVER_GesvdBufferSize_ARGS
#define JAX_GPU_SOLVER_Gesvd_ARGS(Type, Real) \
gpusolverDnHandle_t handle, signed char job, int m, int n, Type *a, Real *s, \
Type *u, Type *vt, Type *workspace, int lwork, int *info
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvd);
#undef JAX_GPU_SOLVER_Gesvd_ARGS
#ifdef JAX_GPU_CUDA
#define JAX_GPU_SOLVER_GesvdjBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \
gesvdjInfo_t params
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBufferSize);
#undef JAX_GPU_SOLVER_GesvdjBufferSize_ARGS
#define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \
Type *a, Real *s, Type *u, Type *v, Type *workspace, \
int lwork, int *info, gesvdjInfo_t params
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj);
#undef JAX_GPU_SOLVER_Gesvdj_ARGS
#define JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS(Type, ...) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \
gpuGesvdjInfo_t params, int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBatchedBufferSize);
#undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS
#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \
gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \
Real *s, Type *u, Type *v, Type *workspace, int lwork, \
int *info, gpuGesvdjInfo_t params, int batch
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched);
#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS
#endif // JAX_GPU_CUDA
#undef JAX_GPU_SOLVER_EXPAND_DEFINITION
} // namespace solver


@@ -33,6 +33,14 @@ limitations under the License.
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
#if JAX_GPU_64_BIT
#include <cstddef>
#endif
#ifdef JAX_GPU_CUDA
#include <limits>
#endif
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
  FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
@@ -56,26 +64,32 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
  return static_cast<T*>(maybe_workspace.value());
}
#define SOLVER_DISPATCH_IMPL(impl, ...)             \
  switch (dataType) {                               \
    case ffi::F32:                                  \
      return impl<float>(__VA_ARGS__);              \
    case ffi::F64:                                  \
      return impl<double>(__VA_ARGS__);             \
    case ffi::C64:                                  \
      return impl<gpuComplex>(__VA_ARGS__);         \
    case ffi::C128:                                 \
      return impl<gpuDoubleComplex>(__VA_ARGS__);   \
    default:                                        \
      break;                                        \
  }

#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...)          \
  switch (dataType) {                                 \
    case ffi::F32:                                    \
      return impl<float>(__VA_ARGS__);                \
    case ffi::F64:                                    \
      return impl<double>(__VA_ARGS__);               \
    case ffi::C64:                                    \
      return impl<gpublasComplex>(__VA_ARGS__);       \
    case ffi::C128:                                   \
      return impl<gpublasDoubleComplex>(__VA_ARGS__); \
    default:                                          \
      break;                                          \
  }
// LU decomposition: getrf
@@ -445,8 +459,8 @@ ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream,
}

ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
                         SyevdAlgorithm algorithm, bool lower,
                         ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
                         ffi::Result<ffi::AnyBuffer> w,
                         ffi::Result<ffi::Buffer<ffi::S32>> info) {
  auto dataType = a.element_type();
@@ -561,6 +575,345 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch,
                              .Ret<ffi::AnyBuffer>() // c_out
);
// Singular Value Decomposition: gesvd
#if JAX_GPU_64_BIT
ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream,
ffi::ScratchAllocator& scratch, bool full_matrices,
bool compute_uv, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::AnyBuffer> s,
ffi::Result<ffi::AnyBuffer> u,
ffi::Result<ffi::AnyBuffer> vt,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
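  // LAPACK-style job codes: 'A' computes all columns of U and rows of V^H,
  // 'S' only the leading min(m, n) singular vectors, and 'N' none at all.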
signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
auto dataType = a.element_type();
gpuDataType aType, sType;
switch (dataType) {
case ffi::F32:
aType = GPU_R_32F;
sType = GPU_R_32F;
break;
case ffi::F64:
aType = GPU_R_64F;
sType = GPU_R_64F;
break;
case ffi::C64:
aType = GPU_C_32F;
sType = GPU_R_32F;
break;
case ffi::C128:
aType = GPU_C_64F;
sType = GPU_R_64F;
break;
default:
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType)));
}
gpusolverDnParams_t params;
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(&params));
std::unique_ptr<gpusolverDnParams, void (*)(gpusolverDnParams_t)>
params_cleanup(
params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });
size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd_bufferSize(
handle.get(), params, job, job, m, n, aType, /*a=*/nullptr, m, sType,
/*s=*/nullptr, aType, /*u=*/nullptr, m, aType, /*vt=*/nullptr, n, aType,
&workspaceInBytesOnDevice, &workspaceInBytesOnHost));
auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
if (!maybe_workspace.has_value()) {
return ffi::Error(ffi::ErrorCode::kResourceExhausted,
"Unable to allocate device workspace for gesvd");
}
auto workspaceOnDevice = maybe_workspace.value();
auto workspaceOnHost =
std::unique_ptr<char[]>(new char[workspaceInBytesOnHost]);
const char* a_data = static_cast<const char*>(a.untyped_data());
char* out_data = static_cast<char*>(out->untyped_data());
char* s_data = static_cast<char*>(s->untyped_data());
char* u_data = static_cast<char*>(u->untyped_data());
char* vt_data = static_cast<char*>(vt->untyped_data());
int* info_data = info->typed_data();
if (a_data != out_data) {
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
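  // Per-matrix strides in bytes: the 64-bit API has no batched gesvd, so we
  // advance the output pointers manually and call the solver once per matrix.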
size_t out_step = m * n * ffi::ByteWidth(dataType);
size_t s_step = n * ffi::ByteWidth(ffi::ToReal(dataType));
size_t u_step = 0;
size_t vt_step = 0;
if (compute_uv) {
u_step = m * (full_matrices ? m : n) * ffi::ByteWidth(dataType);
vt_step = n * n * ffi::ByteWidth(dataType);
}
for (auto i = 0; i < batch; ++i) {
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd(
handle.get(), params, job, job, m, n, aType, out_data, m, sType, s_data,
aType, u_data, m, aType, vt_data, n, aType, workspaceOnDevice,
workspaceInBytesOnDevice, workspaceOnHost.get(), workspaceInBytesOnHost,
info_data));
out_data += out_step;
s_data += s_step;
u_data += u_step;
vt_data += vt_step;
++info_data;
}
return ffi::Error::Success();
}
#else
template <typename T>
ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
bool full_matrices, bool compute_uv, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::AnyBuffer> s,
ffi::Result<ffi::AnyBuffer> u,
ffi::Result<ffi::AnyBuffer> vt,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
FFI_ASSIGN_OR_RETURN(int lwork,
solver::GesvdBufferSize<T>(handle.get(), job, m, n));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "gesvd"));
auto a_data = static_cast<T*>(a.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
  auto s_data =
      static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
auto u_data = compute_uv ? static_cast<T*>(u->untyped_data()) : nullptr;
auto vt_data = compute_uv ? static_cast<T*>(vt->untyped_data()) : nullptr;
auto info_data = info->typed_data();
if (a_data != out_data) {
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
}
int out_step = m * n;
int u_step = compute_uv ? m * (full_matrices ? m : n) : 0;
int vt_step = compute_uv ? n * n : 0;
for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(
solver::Gesvd<T>(handle.get(), job, m, n, out_data, s_data, u_data,
vt_data, workspace, lwork, info_data));
out_data += out_step;
    s_data += n;  // min(m, n) == n because dispatch guarantees m >= n.
u_data += u_step;
vt_data += vt_step;
++info_data;
}
return ffi::Error::Success();
}
#endif // JAX_GPU_64_BIT
ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
bool full_matrices, bool compute_uv, bool transposed,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::AnyBuffer> s,
ffi::Result<ffi::AnyBuffer> u,
ffi::Result<ffi::AnyBuffer> vt,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
auto dataType = a.element_type();
if (out->element_type() != dataType ||
s->element_type() != ffi::ToReal(dataType) ||
u->element_type() != dataType || vt->element_type() != dataType) {
return ffi::Error::InvalidArgument(
"The inputs and outputs to gesvd must have the same element type");
}
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(a.dimensions()));
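  // The lowering rule sets transposed=true when the logical matrix has more
  // columns than rows; the buffer then holds the data in C (row-major) layout,
  // which the column-major solver reads as the n x m transpose.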
int64_t m = transposed ? cols : rows;
int64_t n = transposed ? rows : cols;
if (n > m) {
return ffi::Error::InvalidArgument(
"The GPU implementation of gesvd requires that the input matrix be m x "
"n with m >= n");
}
FFI_RETURN_IF_ERROR(
CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvd"));
FFI_RETURN_IF_ERROR(CheckShape(s->dimensions(), {batch, n}, "s", "gesvd"));
if (compute_uv) {
if (full_matrices) {
FFI_RETURN_IF_ERROR(
CheckShape(u->dimensions(), {batch, m, m}, "u", "gesvd"));
} else {
if (transposed) {
FFI_RETURN_IF_ERROR(
CheckShape(u->dimensions(), {batch, n, m}, "u", "gesvd"));
} else {
FFI_RETURN_IF_ERROR(
CheckShape(u->dimensions(), {batch, m, n}, "u", "gesvd"));
}
}
FFI_RETURN_IF_ERROR(
CheckShape(vt->dimensions(), {batch, n, n}, "vt", "gesvd"));
}
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd"));
#if JAX_GPU_64_BIT
return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv, a,
out, s, u, vt, info);
#else
SOLVER_DISPATCH_IMPL(GesvdImpl, batch, m, n, stream, scratch, full_matrices,
compute_uv, a, out, s, u, vt, info);
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType)));
#endif
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdFfi, GesvdDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Ctx<ffi::ScratchAllocator>()
.Attr<bool>("full_matrices")
.Attr<bool>("compute_uv")
.Attr<bool>("transposed")
.Arg<ffi::AnyBuffer>() // a
.Ret<ffi::AnyBuffer>() // out
.Ret<ffi::AnyBuffer>() // s
.Ret<ffi::AnyBuffer>() // u
.Ret<ffi::AnyBuffer>() // vt
.Ret<ffi::Buffer<ffi::S32>>() // info
);
#ifdef JAX_GPU_CUDA
template <typename T>
ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
bool full_matrices, bool compute_uv, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::AnyBuffer> s,
ffi::Result<ffi::AnyBuffer> u,
ffi::Result<ffi::AnyBuffer> v,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
gpusolverEigMode_t job =
compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR;
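  // econ != 0 requests economy-size factors (min(m, n) columns of U and V).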
int econ = full_matrices ? 0 : 1;
gpuGesvdjInfo_t params;
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateGesvdjInfo(&params));
std::unique_ptr<gpuGesvdjInfo, void (*)(gpuGesvdjInfo_t)> params_cleanup(
params, [](gpuGesvdjInfo_t p) { gpusolverDnDestroyGesvdjInfo(p); });
auto a_data = static_cast<T*>(a.untyped_data());
auto out_data = static_cast<T*>(out->untyped_data());
  auto s_data =
      static_cast<typename solver::RealType<T>::value*>(s->untyped_data());
auto u_data = static_cast<T*>(u->untyped_data());
auto v_data = static_cast<T*>(v->untyped_data());
auto info_data = info->typed_data();
if (a_data != out_data) {
JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}
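  // cuSolver's gesvdjBatched only supports matrices with m, n <= 32 and has no
  // economy mode, so loop over the batch with the single-matrix gesvdj in those
  // cases (including trivial batches and batches too large to index as int).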
if (batch <= 1 || batch > std::numeric_limits<int>::max() || m > 32 ||
n > 32 || econ) {
FFI_ASSIGN_OR_RETURN(int lwork, solver::GesvdjBufferSize<T>(
handle.get(), job, econ, m, n, params));
FFI_ASSIGN_OR_RETURN(auto workspace,
AllocateWorkspace<T>(scratch, lwork, "gesvdj"));
int k = std::min(m, n);
int out_step = m * n;
int u_step = m * (full_matrices ? m : k);
int v_step = n * (full_matrices ? n : k);
for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(solver::Gesvdj<T>(
handle.get(), job, econ, m, n, out_data, s_data, u_data, v_data,
workspace, lwork, info_data, params));
out_data += out_step;
s_data += k;
u_data += u_step;
v_data += v_step;
++info_data;
}
} else {
FFI_ASSIGN_OR_RETURN(int lwork, solver::GesvdjBatchedBufferSize<T>(
handle.get(), job, m, n, params,
static_cast<int>(batch)));
FFI_ASSIGN_OR_RETURN(
auto workspace, AllocateWorkspace<T>(scratch, lwork, "gesvdj_batched"));
FFI_RETURN_IF_ERROR_STATUS(solver::GesvdjBatched<T>(
handle.get(), job, m, n, out_data, s_data, u_data, v_data, workspace,
lwork, info_data, params, static_cast<int>(batch)));
}
return ffi::Error::Success();
}
ffi::Error GesvdjDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
bool full_matrices, bool compute_uv, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
ffi::Result<ffi::AnyBuffer> s,
ffi::Result<ffi::AnyBuffer> u,
ffi::Result<ffi::AnyBuffer> v,
ffi::Result<ffi::Buffer<ffi::S32>> info) {
auto dataType = a.element_type();
if (out->element_type() != dataType ||
s->element_type() != ffi::ToReal(dataType) ||
u->element_type() != dataType || v->element_type() != dataType) {
return ffi::Error::InvalidArgument(
"The inputs and outputs to gesvdj must have the same element type");
}
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(a.dimensions()));
int64_t size = std::min(rows, cols);
FFI_RETURN_IF_ERROR(
CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvdj"));
FFI_RETURN_IF_ERROR(
CheckShape(s->dimensions(), {batch, size}, "s", "gesvdj"));
// U and V must always be allocated even if compute_uv is false.
if (full_matrices) {
FFI_RETURN_IF_ERROR(
CheckShape(u->dimensions(), {batch, rows, rows}, "u", "gesvdj"));
FFI_RETURN_IF_ERROR(
CheckShape(v->dimensions(), {batch, cols, cols}, "v", "gesvdj"));
} else {
FFI_RETURN_IF_ERROR(
CheckShape(u->dimensions(), {batch, rows, size}, "u", "gesvdj"));
FFI_RETURN_IF_ERROR(
CheckShape(v->dimensions(), {batch, cols, size}, "v", "gesvdj"));
}
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvdj"));
SOLVER_DISPATCH_IMPL(GesvdjImpl, batch, rows, cols, stream, scratch,
full_matrices, compute_uv, a, out, s, u, v, info);
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in gesvdj", absl::FormatStreamed(dataType)));
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Ctx<ffi::ScratchAllocator>()
.Attr<bool>("full_matrices")
.Attr<bool>("compute_uv")
.Arg<ffi::AnyBuffer>() // a
.Ret<ffi::AnyBuffer>() // out
.Ret<ffi::AnyBuffer>() // s
.Ret<ffi::AnyBuffer>() // u
.Ret<ffi::AnyBuffer>() // v
.Ret<ffi::Buffer<ffi::S32>>() // info
);
#endif // JAX_GPU_CUDA
#undef SOLVER_DISPATCH_IMPL
#undef SOLVER_BLAS_DISPATCH_IMPL


@@ -35,6 +35,11 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi);
#ifdef JAX_GPU_CUDA
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi);
#endif // JAX_GPU_CUDA
} // namespace JAX_GPU_NAMESPACE
} // namespace jax


@@ -78,6 +78,8 @@ typedef cusolverStatus_t gpusolverStatus_t;
typedef cusolverEigMode_t gpusolverEigMode_t;
typedef syevjInfo gpuSyevjInfo;
typedef syevjInfo_t gpuSyevjInfo_t;
typedef gesvdjInfo gpuGesvdjInfo;
typedef gesvdjInfo_t gpuGesvdjInfo_t;
typedef cusparseIndexType_t gpusparseIndexType_t;
typedef cusparseHandle_t gpusparseHandle_t;
typedef cusparseOperation_t gpusparseOperation_t;
@@ -120,6 +122,8 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpusolverDnSetStream cusolverDnSetStream
#define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo
#define gpusolverDnDestroySyevjInfo cusolverDnDestroySyevjInfo
#define gpusolverDnCreateGesvdjInfo cusolverDnCreateGesvdjInfo
#define gpusolverDnDestroyGesvdjInfo cusolverDnDestroyGesvdjInfo
#define gpusolverDnSgeqrf cusolverDnSgeqrf
#define gpusolverDnDgeqrf cusolverDnDgeqrf
#define gpusolverDnCgeqrf cusolverDnCgeqrf
@@ -184,6 +188,22 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
  cusolverDnCgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
  cusolverDnZgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnSgesvdj cusolverDnSgesvdj
#define gpusolverDnDgesvdj cusolverDnDgesvdj
#define gpusolverDnCgesvdj cusolverDnCgesvdj
#define gpusolverDnZgesvdj cusolverDnZgesvdj
#define gpusolverDnSgesvdj_bufferSize cusolverDnSgesvdj_bufferSize
#define gpusolverDnDgesvdj_bufferSize cusolverDnDgesvdj_bufferSize
#define gpusolverDnCgesvdj_bufferSize cusolverDnCgesvdj_bufferSize
#define gpusolverDnZgesvdj_bufferSize cusolverDnZgesvdj_bufferSize
#define gpusolverDnSgesvdjBatched cusolverDnSgesvdjBatched
#define gpusolverDnDgesvdjBatched cusolverDnDgesvdjBatched
#define gpusolverDnCgesvdjBatched cusolverDnCgesvdjBatched
#define gpusolverDnZgesvdjBatched cusolverDnZgesvdjBatched
#define gpusolverDnSgesvdjBatched_bufferSize cusolverDnSgesvdjBatched_bufferSize
#define gpusolverDnDgesvdjBatched_bufferSize cusolverDnDgesvdjBatched_bufferSize
#define gpusolverDnCgesvdjBatched_bufferSize cusolverDnCgesvdjBatched_bufferSize
#define gpusolverDnZgesvdjBatched_bufferSize cusolverDnZgesvdjBatched_bufferSize
#define gpusolverDnSsytrd_bufferSize cusolverDnSsytrd_bufferSize
#define gpusolverDnDsytrd_bufferSize cusolverDnDsytrd_bufferSize
#define gpusolverDnChetrd_bufferSize cusolverDnChetrd_bufferSize
@@ -196,6 +216,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER
#define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR
#define GPUSOLVER_EIG_MODE_NOVECTOR CUSOLVER_EIG_MODE_NOVECTOR
#define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS
#define GPUBLAS_OP_N CUBLAS_OP_N
@@ -311,6 +332,22 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpuGetDeviceProperties cudaGetDeviceProperties
#define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel
#define JAX_GPU_64_BIT 1
#define GPU_R_32F CUDA_R_32F
#define GPU_R_64F CUDA_R_64F
#define GPU_C_32F CUDA_C_32F
#define GPU_C_64F CUDA_C_64F
typedef cudaDataType gpuDataType;
typedef cusolverDnParams gpusolverDnParams;
typedef cusolverDnParams_t gpusolverDnParams_t;
#define gpusolverDnCreateParams cusolverDnCreateParams
#define gpusolverDnDestroyParams cusolverDnDestroyParams
#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize
#define gpusolverDnXgesvd cusolverDnXgesvd
namespace jax::JAX_GPU_NAMESPACE {
namespace {
constexpr uint32_t kNumThreadsPerWarp = 32;
@@ -331,6 +368,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32;
#define JAX_GPU_PREFIX "hip"
#define JAX_GPU_HAVE_SPARSE 1
#define JAX_GPU_64_BIT 0
#define JAX_GPU_HAVE_FP8 0
typedef hipFloatComplex gpuComplex;
@@ -472,6 +510,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER
#define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR
#define GPUSOLVER_EIG_MODE_NOVECTOR HIPSOLVER_EIG_MODE_NOVECTOR
#define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS
#define GPUBLAS_OP_N HIPBLAS_OP_N