Implement SVD algorithm based on QR for CPU targets

In a recent jax release the SvdAlgorithm parameter has been added
to the jax.lax.linalg.svd function. Currently, for CPU targets
still only the divide and conquer algorithm from LAPACK is
supported (gesdd).

This commits adds the functionality to select the QR based
algorithm on CPU as well. Mainly it addes the wrapper code
to call the gesvd function of LAPACK using the FFI interface.

Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
This commit is contained in:
Jan Naumann 2024-11-22 11:32:38 +01:00 committed by Jan Luca Naumann
parent a75bea51de
commit e03fe3a06d
7 changed files with 265 additions and 4 deletions

View File

@ -2069,10 +2069,13 @@ def _svd_cpu_gpu_lowering(
compute_uv=compute_uv,
)
if target_name_prefix == "cpu":
if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT:
if algorithm is None or algorithm == SvdAlgorithm.DEFAULT:
target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype)
elif algorithm == SvdAlgorithm.QR:
target_name = lapack.prepare_lapack_call("gesvd_ffi", operand_aval.dtype)
else:
raise NotImplementedError(
"The SVD algorithm parameter is not implemented on CPU.")
target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype)
"The SVD Jacobi algorithm is not implemented on CPU.")
mode = _svd_computation_attr(compute_uv, full_matrices)
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
if compute_uv:

View File

@ -271,7 +271,8 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
only the singular values ``s``.
overwrite_a: unused by JAX
check_finite: unused by JAX
lapack_driver: unused by JAX
lapack_driver: unused by JAX. If you want to select a non-default SVD driver, please
check :func:`jax.lax.linalg.svd` which provides such functionality.
Returns:
A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``.

View File

@ -145,6 +145,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgesvd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgesvd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgesvd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgesvd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_ssyevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dsyevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cheevd_ffi);

View File

@ -111,6 +111,10 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<svd::SVDType<DataType::F64>>(lapack_ptr("dgesdd"));
AssignKernelFn<svd::SVDType<DataType::C64>>(lapack_ptr("cgesdd"));
AssignKernelFn<svd::SVDType<DataType::C128>>(lapack_ptr("zgesdd"));
AssignKernelFn<svd::SVDQRType<DataType::F32>>(lapack_ptr("sgesvd"));
AssignKernelFn<svd::SVDQRType<DataType::F64>>(lapack_ptr("dgesvd"));
AssignKernelFn<svd::SVDQRType<DataType::C64>>(lapack_ptr("cgesvd"));
AssignKernelFn<svd::SVDQRType<DataType::C128>>(lapack_ptr("zgesvd"));
AssignKernelFn<RealSyevd<float>>(lapack_ptr("ssyevd"));
AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
@ -274,6 +278,10 @@ nb::dict Registrations() {
dict["lapack_dgesdd_ffi"] = EncapsulateFunction(lapack_dgesdd_ffi);
dict["lapack_cgesdd_ffi"] = EncapsulateFunction(lapack_cgesdd_ffi);
dict["lapack_zgesdd_ffi"] = EncapsulateFunction(lapack_zgesdd_ffi);
dict["lapack_sgesvd_ffi"] = EncapsulateFunction(lapack_sgesvd_ffi);
dict["lapack_dgesvd_ffi"] = EncapsulateFunction(lapack_dgesvd_ffi);
dict["lapack_cgesvd_ffi"] = EncapsulateFunction(lapack_cgesvd_ffi);
dict["lapack_zgesvd_ffi"] = EncapsulateFunction(lapack_zgesvd_ffi);
dict["lapack_ssyevd_ffi"] = EncapsulateFunction(lapack_ssyevd_ffi);
dict["lapack_dsyevd_ffi"] = EncapsulateFunction(lapack_dsyevd_ffi);
dict["lapack_cheevd_ffi"] = EncapsulateFunction(lapack_cheevd_ffi);

View File

@ -893,6 +893,107 @@ static int64_t SvdGetWorkspaceSize(lapack_int x_rows, lapack_int x_cols,
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
}
template <ffi::DataType dtype>
static ffi::Error SvdQRKernel(
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<ffi::ToReal(dtype)> singular_values,
ffi::ResultBuffer<dtype> u, ffi::ResultBuffer<dtype> vt,
ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode) {
if (mode == svd::ComputationMode::kComputeVtOverwriteXPartialU) [[unlikely]] {
return ffi::Error(
XLA_FFI_Error_Code_UNIMPLEMENTED,
"SVD: Current implementation does not support this computation mode");
}
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
auto* x_out_data = x_out->typed_data();
auto* singular_values_data = singular_values->typed_data();
auto* u_data = u->typed_data();
auto* vt_data = vt->typed_data();
auto* info_data = info->typed_data();
// Prepare LAPACK workspaces.
FFI_ASSIGN_OR_RETURN(
const auto work_size,
svd::SVDQRType<dtype>::GetWorkspaceSize(x_rows, x_cols, mode));
auto work_data = AllocateScratchMemory<dtype>(work_size);
using RealType = typename svd::SVDType<dtype>::RealType;
std::unique_ptr<RealType[]> rwork;
if constexpr (ffi::IsComplexType<dtype>()) {
FFI_ASSIGN_OR_RETURN(const auto rwork_size,
svd::GetRealWorkspaceSizeQR(x_rows, x_cols));
rwork = AllocateScratchMemory<ffi::ToReal(dtype)>(rwork_size);
}
CopyIfDiffBuffer(x, x_out);
FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
auto mode_v = static_cast<char>(mode);
auto workspace_dim_v = work_size;
auto x_leading_dim_v = x_rows_v;
auto u_leading_dim_v = x_rows_v;
auto u_dims = u->dimensions().last(2);
auto vt_dims = vt->dimensions().last(2);
FFI_ASSIGN_OR_RETURN(auto vt_leading_dim_v,
MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
const int64_t x_out_step{x_rows * x_cols};
const int64_t singular_values_step{singular_values->dimensions().back()};
const int64_t u_step{u_dims.front() * u_dims.back()};
const int64_t vt_step{vt_leading_dim_v * vt_dims.back()};
for (int64_t i = 0; i < batch_count; ++i) {
if constexpr (ffi::IsComplexType<dtype>()) {
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data,
&x_leading_dim_v, singular_values_data, u_data,
&u_leading_dim_v, vt_data, &vt_leading_dim_v,
work_data.get(), &workspace_dim_v, rwork.get(),
info_data);
} else {
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data,
&x_leading_dim_v, singular_values_data, u_data,
&u_leading_dim_v, vt_data, &vt_leading_dim_v,
work_data.get(), &workspace_dim_v, info_data);
}
x_out_data += x_out_step;
singular_values_data += singular_values_step;
u_data += u_step;
vt_data += vt_step;
++info_data;
}
return ffi::Error::Success();
}
template <ffi::DataType dtype>
static absl::StatusOr<lapack_int> SvdQRGetWorkspaceSize(lapack_int x_rows,
lapack_int x_cols,
svd::ComputationMode mode) {
ffi::NativeType<dtype> optimal_size = {};
lapack_int info = 0;
lapack_int workspace_query = -1;
auto mode_v = static_cast<char>(mode);
auto x_leading_dim_v = x_rows;
auto u_leading_dim_v = x_rows;
auto vt_leading_dim_v = mode == svd::ComputationMode::kComputeFullUVt
? x_cols
: std::min(x_rows, x_cols);
if constexpr (ffi::IsComplexType<dtype>()) {
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows, &x_cols, nullptr,
&x_leading_dim_v, nullptr, nullptr,
&u_leading_dim_v, nullptr, &vt_leading_dim_v,
&optimal_size, &workspace_query, nullptr, &info);
} else {
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows, &x_cols, nullptr,
&x_leading_dim_v, nullptr, nullptr,
&u_leading_dim_v, nullptr, &vt_leading_dim_v,
&optimal_size, &workspace_query, &info);
}
return info == 0 ? MaybeCastNoOverflow<lapack_int>(std::real(optimal_size)) : -1;
}
} // namespace internal
template <ffi::DataType dtype>
@ -928,6 +1029,39 @@ SingularValueDecompositionComplex<dtype>::GetWorkspaceSize(
return internal::SvdGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
}
template <ffi::DataType dtype>
ffi::Error SingularValueDecompositionQR<dtype>::Kernel(
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<dtype> singular_values, ffi::ResultBuffer<dtype> u,
ffi::ResultBuffer<dtype> vt, ffi::ResultBuffer<LapackIntDtype> info,
svd::ComputationMode mode) {
return internal::SvdQRKernel<dtype>(x, x_out, singular_values, u, vt, info,
mode);
}
template <ffi::DataType dtype>
ffi::Error SingularValueDecompositionQRComplex<dtype>::Kernel(
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<ffi::ToReal(dtype)> singular_values,
ffi::ResultBuffer<dtype> u, ffi::ResultBuffer<dtype> vt,
ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode) {
return internal::SvdQRKernel<dtype>(x, x_out, singular_values, u, vt, info,
mode);
}
template <ffi::DataType dtype>
absl::StatusOr<lapack_int> SingularValueDecompositionQR<dtype>::GetWorkspaceSize(
lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) {
return internal::SvdQRGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
}
template <ffi::DataType dtype>
absl::StatusOr<lapack_int>
SingularValueDecompositionQRComplex<dtype>::GetWorkspaceSize(
lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) {
return internal::SvdQRGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
}
absl::StatusOr<lapack_int> svd::GetRealWorkspaceSize(
int64_t x_rows, int64_t x_cols, svd::ComputationMode mode) {
const auto min_dim = std::min(x_rows, x_cols);
@ -940,6 +1074,10 @@ absl::StatusOr<lapack_int> svd::GetRealWorkspaceSize(
2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim));
}
absl::StatusOr<lapack_int> svd::GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols) {
return CastNoOverflow<lapack_int>(5 * std::min(x_rows, x_cols));
}
absl::StatusOr<lapack_int> svd::GetIntWorkspaceSize(int64_t x_rows,
int64_t x_cols) {
return CastNoOverflow<lapack_int>(8 * std::min(x_rows, x_cols));
@ -950,6 +1088,11 @@ template struct SingularValueDecomposition<ffi::DataType::F64>;
template struct SingularValueDecompositionComplex<ffi::DataType::C64>;
template struct SingularValueDecompositionComplex<ffi::DataType::C128>;
template struct SingularValueDecompositionQR<ffi::DataType::F32>;
template struct SingularValueDecompositionQR<ffi::DataType::F64>;
template struct SingularValueDecompositionQRComplex<ffi::DataType::C64>;
template struct SingularValueDecompositionQRComplex<ffi::DataType::C128>;
//== Eigenvalues and eigenvectors ==//
// lapack syevd/heevd
@ -2179,6 +2322,30 @@ template struct TridiagonalSolver<ffi::DataType::C128>;
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Attr<svd::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_GESVD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, SingularValueDecompositionQR<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*s*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*u*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*vt*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Attr<svd::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_GESVD_COMPLEX(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, SingularValueDecompositionQRComplex<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*u*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*vt*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Attr<svd::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_SYEVD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, EigenvalueDecompositionSymmetric<data_type>::Kernel, \
@ -2332,6 +2499,11 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_GESVD(lapack_sgesvd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GESVD(lapack_dgesvd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GESVD_COMPLEX(lapack_cgesvd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GESVD_COMPLEX(lapack_zgesvd_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_SYEVD(lapack_ssyevd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_SYEVD(lapack_dsyevd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_HEEVD(lapack_cheevd_ffi, ::xla::ffi::DataType::C64);
@ -2370,6 +2542,8 @@ JAX_CPU_DEFINE_GTSV(lapack_zgtsv_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_POTRF
#undef JAX_CPU_DEFINE_GESDD
#undef JAX_CPU_DEFINE_GESDD_COMPLEX
#undef JAX_CPU_DEFINE_GESVD
#undef JAX_CPU_DEFINE_GESVD_COMPLEX
#undef JAX_CPU_DEFINE_SYEVD
#undef JAX_CPU_DEFINE_HEEVD
#undef JAX_CPU_DEFINE_GEEV

View File

@ -387,6 +387,55 @@ struct SingularValueDecompositionComplex {
svd::ComputationMode mode);
};
template <::xla::ffi::DataType dtype>
struct SingularValueDecompositionQR {
static_assert(!::xla::ffi::IsComplexType<dtype>(),
"There exists a separate implementation for Complex types");
using ValueType = ::xla::ffi::NativeType<dtype>;
using RealType = ValueType;
using FnType = void(char* jobu, char* jobvt, lapack_int* m, lapack_int* n,
ValueType* a, lapack_int* lda, ValueType* s, ValueType* u,
lapack_int* ldu, ValueType* vt, lapack_int* ldvt,
ValueType* work, lapack_int* lwork, lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> singular_values,
::xla::ffi::ResultBuffer<dtype> u, ::xla::ffi::ResultBuffer<dtype> vt,
::xla::ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode);
static absl::StatusOr<lapack_int> GetWorkspaceSize(lapack_int x_rows,
lapack_int x_cols,
svd::ComputationMode mode);
};
template <::xla::ffi::DataType dtype>
struct SingularValueDecompositionQRComplex {
static_assert(::xla::ffi::IsComplexType<dtype>());
using ValueType = ::xla::ffi::NativeType<dtype>;
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
using FnType = void(char* jobu, char* jobvt, lapack_int* m, lapack_int* n,
ValueType* a, lapack_int* lda, RealType* s, ValueType* u,
lapack_int* ldu, ValueType* vt, lapack_int* ldvt,
ValueType* work, lapack_int* lwork, RealType* rwork,
lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> singular_values,
::xla::ffi::ResultBuffer<dtype> u, ::xla::ffi::ResultBuffer<dtype> vt,
::xla::ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode);
static absl::StatusOr<lapack_int> GetWorkspaceSize(lapack_int x_rows,
lapack_int x_cols,
svd::ComputationMode mode);
};
namespace svd {
template <::xla::ffi::DataType dtype>
@ -394,9 +443,15 @@ using SVDType = std::conditional_t<::xla::ffi::IsComplexType<dtype>(),
SingularValueDecompositionComplex<dtype>,
SingularValueDecomposition<dtype>>;
template <::xla::ffi::DataType dtype>
using SVDQRType = std::conditional_t<::xla::ffi::IsComplexType<dtype>(),
SingularValueDecompositionQRComplex<dtype>,
SingularValueDecompositionQR<dtype>>;
absl::StatusOr<lapack_int> GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols);
absl::StatusOr<lapack_int> GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols,
ComputationMode mode);
absl::StatusOr<lapack_int> GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols);
} // namespace svd
@ -817,6 +872,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesvd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesvd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesvd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesvd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssyevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsyevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cheevd_ffi);

View File

@ -61,6 +61,11 @@ jax::SingularValueDecomposition<ffi::DataType::F64>::FnType dgesdd_;
jax::SingularValueDecompositionComplex<ffi::DataType::C64>::FnType cgesdd_;
jax::SingularValueDecompositionComplex<ffi::DataType::C128>::FnType zgesdd_;
jax::SingularValueDecompositionQR<ffi::DataType::F32>::FnType sgesvd_;
jax::SingularValueDecompositionQR<ffi::DataType::F64>::FnType dgesvd_;
jax::SingularValueDecompositionQRComplex<ffi::DataType::C64>::FnType cgesvd_;
jax::SingularValueDecompositionQRComplex<ffi::DataType::C128>::FnType zgesvd_;
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F32>::FnType ssyevd_;
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F64>::FnType dsyevd_;
jax::EigenvalueDecompositionHermitian<ffi::DataType::C64>::FnType cheevd_;
@ -367,6 +372,13 @@ static auto init = []() -> int {
AssignKernelFn<SingularValueDecompositionComplex<ffi::DataType::C128>>(
zgesdd_);
AssignKernelFn<SingularValueDecompositionQR<ffi::DataType::F32>>(sgesvd_);
AssignKernelFn<SingularValueDecompositionQR<ffi::DataType::F64>>(dgesvd_);
AssignKernelFn<SingularValueDecompositionQRComplex<ffi::DataType::C64>>(
cgesvd_);
AssignKernelFn<SingularValueDecompositionQRComplex<ffi::DataType::C128>>(
zgesvd_);
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F32>>(ssyevd_);
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F64>>(dsyevd_);
AssignKernelFn<EigenvalueDecompositionHermitian<ffi::DataType::C64>>(cheevd_);