From e03fe3a06d1567bf738f859894a0fd98e6be4d6d Mon Sep 17 00:00:00 2001 From: Jan Naumann Date: Fri, 22 Nov 2024 11:32:38 +0100 Subject: [PATCH] Implement SVD algorithm based on QR for CPU targets In a recent jax release the SvdAlgorithm parameter has been added to the jax.lax.linalg.svd function. Currently, for CPU targets still only the divide and conquer algorithm from LAPACK is supported (gesdd). This commits adds the functionality to select the QR based algorithm on CPU as well. Mainly it addes the wrapper code to call the gesvd function of LAPACK using the FFI interface. Signed-off-by: Jan Naumann --- jax/_src/lax/linalg.py | 9 +- jax/_src/scipy/linalg.py | 3 +- jaxlib/cpu/cpu_kernels.cc | 4 + jaxlib/cpu/lapack.cc | 8 + jaxlib/cpu/lapack_kernels.cc | 174 ++++++++++++++++++++++ jaxlib/cpu/lapack_kernels.h | 59 ++++++++ jaxlib/cpu/lapack_kernels_using_lapack.cc | 12 ++ 7 files changed, 265 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 6f71fc579..bcce16d6a 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2069,10 +2069,13 @@ def _svd_cpu_gpu_lowering( compute_uv=compute_uv, ) if target_name_prefix == "cpu": - if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT: + if algorithm is None or algorithm == SvdAlgorithm.DEFAULT: + target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype) + elif algorithm == SvdAlgorithm.QR: + target_name = lapack.prepare_lapack_call("gesvd_ffi", operand_aval.dtype) + else: raise NotImplementedError( - "The SVD algorithm parameter is not implemented on CPU.") - target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype) + "The SVD Jacobi algorithm is not implemented on CPU.") mode = _svd_computation_attr(compute_uv, full_matrices) info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) if compute_uv: diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index d40608027..9917cbaa0 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -271,7 +271,8 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, only the singular values ``s``. overwrite_a: unused by JAX check_finite: unused by JAX - lapack_driver: unused by JAX + lapack_driver: unused by JAX. If you want to select a non-default SVD driver, please + check :func:`jax.lax.linalg.svd` which provides such functionality. Returns: A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``. diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 7b8e6d728..6ed42496f 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -145,6 +145,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgesdd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgesdd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgesdd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgesdd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_sgesvd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgesvd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgesvd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgesvd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_ssyevd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dsyevd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cheevd_ffi); diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 1ede3c578..c10401977 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -111,6 +111,10 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("dgesdd")); AssignKernelFn>(lapack_ptr("cgesdd")); AssignKernelFn>(lapack_ptr("zgesdd")); + AssignKernelFn>(lapack_ptr("sgesvd")); + AssignKernelFn>(lapack_ptr("dgesvd")); + AssignKernelFn>(lapack_ptr("cgesvd")); + AssignKernelFn>(lapack_ptr("zgesvd")); AssignKernelFn>(lapack_ptr("ssyevd")); AssignKernelFn>(lapack_ptr("dsyevd")); @@ -274,6 +278,10 @@ nb::dict Registrations() { dict["lapack_dgesdd_ffi"] = EncapsulateFunction(lapack_dgesdd_ffi); dict["lapack_cgesdd_ffi"] = EncapsulateFunction(lapack_cgesdd_ffi); dict["lapack_zgesdd_ffi"] = EncapsulateFunction(lapack_zgesdd_ffi); + dict["lapack_sgesvd_ffi"] = EncapsulateFunction(lapack_sgesvd_ffi); + dict["lapack_dgesvd_ffi"] = EncapsulateFunction(lapack_dgesvd_ffi); + dict["lapack_cgesvd_ffi"] = EncapsulateFunction(lapack_cgesvd_ffi); + dict["lapack_zgesvd_ffi"] = EncapsulateFunction(lapack_zgesvd_ffi); dict["lapack_ssyevd_ffi"] = EncapsulateFunction(lapack_ssyevd_ffi); dict["lapack_dsyevd_ffi"] = EncapsulateFunction(lapack_dsyevd_ffi); dict["lapack_cheevd_ffi"] = EncapsulateFunction(lapack_cheevd_ffi); diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index c56c1a98c..894ab13bb 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -893,6 +893,107 @@ static int64_t SvdGetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, return info == 0 ? static_cast(std::real(optimal_size)) : -1; } +template +static ffi::Error SvdQRKernel( + ffi::Buffer x, ffi::ResultBuffer x_out, + ffi::ResultBuffer singular_values, + ffi::ResultBuffer u, ffi::ResultBuffer vt, + ffi::ResultBuffer info, svd::ComputationMode mode) { + if (mode == svd::ComputationMode::kComputeVtOverwriteXPartialU) [[unlikely]] { + return ffi::Error( + XLA_FFI_Error_Code_UNIMPLEMENTED, + "SVD: Current implementation does not support this computation mode"); + } + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); + auto* x_out_data = x_out->typed_data(); + auto* singular_values_data = singular_values->typed_data(); + auto* u_data = u->typed_data(); + auto* vt_data = vt->typed_data(); + auto* info_data = info->typed_data(); + + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN( + const auto work_size, + svd::SVDQRType::GetWorkspaceSize(x_rows, x_cols, mode)); + auto work_data = AllocateScratchMemory(work_size); + using RealType = typename svd::SVDType::RealType; + std::unique_ptr rwork; + if constexpr (ffi::IsComplexType()) { + FFI_ASSIGN_OR_RETURN(const auto rwork_size, + svd::GetRealWorkspaceSizeQR(x_rows, x_cols)); + rwork = AllocateScratchMemory(rwork_size); + } + + CopyIfDiffBuffer(x, x_out); + + FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + auto mode_v = static_cast(mode); + auto workspace_dim_v = work_size; + auto x_leading_dim_v = x_rows_v; + auto u_leading_dim_v = x_rows_v; + + auto u_dims = u->dimensions().last(2); + auto vt_dims = vt->dimensions().last(2); + FFI_ASSIGN_OR_RETURN(auto vt_leading_dim_v, + MaybeCastNoOverflow(vt_dims.front())); + + const int64_t x_out_step{x_rows * x_cols}; + const int64_t singular_values_step{singular_values->dimensions().back()}; + const int64_t u_step{u_dims.front() * u_dims.back()}; + const int64_t vt_step{vt_leading_dim_v * vt_dims.back()}; + + for (int64_t i = 0; i < batch_count; ++i) { + if constexpr (ffi::IsComplexType()) { + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, + &x_leading_dim_v, singular_values_data, u_data, + &u_leading_dim_v, vt_data, &vt_leading_dim_v, + work_data.get(), &workspace_dim_v, rwork.get(), + info_data); + } else { + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, + &x_leading_dim_v, singular_values_data, u_data, + &u_leading_dim_v, vt_data, &vt_leading_dim_v, + work_data.get(), &workspace_dim_v, info_data); + } + x_out_data += x_out_step; + singular_values_data += singular_values_step; + u_data += u_step; + vt_data += vt_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template +static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + svd::ComputationMode mode) { + ffi::NativeType optimal_size = {}; + lapack_int info = 0; + lapack_int workspace_query = -1; + + auto mode_v = static_cast(mode); + auto x_leading_dim_v = x_rows; + auto u_leading_dim_v = x_rows; + auto vt_leading_dim_v = mode == svd::ComputationMode::kComputeFullUVt + ? x_cols + : std::min(x_rows, x_cols); + if constexpr (ffi::IsComplexType()) { + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows, &x_cols, nullptr, + &x_leading_dim_v, nullptr, nullptr, + &u_leading_dim_v, nullptr, &vt_leading_dim_v, + &optimal_size, &workspace_query, nullptr, &info); + } else { + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows, &x_cols, nullptr, + &x_leading_dim_v, nullptr, nullptr, + &u_leading_dim_v, nullptr, &vt_leading_dim_v, + &optimal_size, &workspace_query, &info); + } + return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) : -1; +} + } // namespace internal template @@ -928,6 +1029,39 @@ SingularValueDecompositionComplex::GetWorkspaceSize( return internal::SvdGetWorkspaceSize(x_rows, x_cols, mode); } +template +ffi::Error SingularValueDecompositionQR::Kernel( + ffi::Buffer x, ffi::ResultBuffer x_out, + ffi::ResultBuffer singular_values, ffi::ResultBuffer u, + ffi::ResultBuffer vt, ffi::ResultBuffer info, + svd::ComputationMode mode) { + return internal::SvdQRKernel(x, x_out, singular_values, u, vt, info, + mode); +} + +template +ffi::Error SingularValueDecompositionQRComplex::Kernel( + ffi::Buffer x, ffi::ResultBuffer x_out, + ffi::ResultBuffer singular_values, + ffi::ResultBuffer u, ffi::ResultBuffer vt, + ffi::ResultBuffer info, svd::ComputationMode mode) { + return internal::SvdQRKernel(x, x_out, singular_values, u, vt, info, + mode); +} + +template +absl::StatusOr SingularValueDecompositionQR::GetWorkspaceSize( + lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { + return internal::SvdQRGetWorkspaceSize(x_rows, x_cols, mode); +} + +template +absl::StatusOr +SingularValueDecompositionQRComplex::GetWorkspaceSize( + lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { + return internal::SvdQRGetWorkspaceSize(x_rows, x_cols, mode); +} + absl::StatusOr svd::GetRealWorkspaceSize( int64_t x_rows, int64_t x_cols, svd::ComputationMode mode) { const auto min_dim = std::min(x_rows, x_cols); @@ -940,6 +1074,10 @@ absl::StatusOr svd::GetRealWorkspaceSize( 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim)); } +absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols) { + return CastNoOverflow(5 * std::min(x_rows, x_cols)); +} + absl::StatusOr svd::GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols) { return CastNoOverflow(8 * std::min(x_rows, x_cols)); @@ -950,6 +1088,11 @@ template struct SingularValueDecomposition; template struct SingularValueDecompositionComplex; template struct SingularValueDecompositionComplex; +template struct SingularValueDecompositionQR; +template struct SingularValueDecompositionQR; +template struct SingularValueDecompositionQRComplex; +template struct SingularValueDecompositionQRComplex; + //== Eigenvalues and eigenvectors ==// // lapack syevd/heevd @@ -2179,6 +2322,30 @@ template struct TridiagonalSolver; .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) +#define JAX_CPU_DEFINE_GESVD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecompositionQR::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*s*/) \ + .Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Attr("mode")) + +#define JAX_CPU_DEFINE_GESVD_COMPLEX(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecompositionQRComplex::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \ + .Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Attr("mode")) + #define JAX_CPU_DEFINE_SYEVD(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ name, EigenvalueDecompositionSymmetric::Kernel, \ @@ -2332,6 +2499,11 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128); +JAX_CPU_DEFINE_GESVD(lapack_sgesvd_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GESVD(lapack_dgesvd_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GESVD_COMPLEX(lapack_cgesvd_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GESVD_COMPLEX(lapack_zgesvd_ffi, ::xla::ffi::DataType::C128); + JAX_CPU_DEFINE_SYEVD(lapack_ssyevd_ffi, ::xla::ffi::DataType::F32); JAX_CPU_DEFINE_SYEVD(lapack_dsyevd_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_HEEVD(lapack_cheevd_ffi, ::xla::ffi::DataType::C64); @@ -2370,6 +2542,8 @@ JAX_CPU_DEFINE_GTSV(lapack_zgtsv_ffi, ::xla::ffi::DataType::C128); #undef JAX_CPU_DEFINE_POTRF #undef JAX_CPU_DEFINE_GESDD #undef JAX_CPU_DEFINE_GESDD_COMPLEX +#undef JAX_CPU_DEFINE_GESVD +#undef JAX_CPU_DEFINE_GESVD_COMPLEX #undef JAX_CPU_DEFINE_SYEVD #undef JAX_CPU_DEFINE_HEEVD #undef JAX_CPU_DEFINE_GEEV diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 422bb9dd3..d94b5af61 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -387,6 +387,55 @@ struct SingularValueDecompositionComplex { svd::ComputationMode mode); }; +template <::xla::ffi::DataType dtype> +struct SingularValueDecompositionQR { + static_assert(!::xla::ffi::IsComplexType(), + "There exists a separate implementation for Complex types"); + using ValueType = ::xla::ffi::NativeType; + using RealType = ValueType; + using FnType = void(char* jobu, char* jobvt, lapack_int* m, lapack_int* n, + ValueType* a, lapack_int* lda, ValueType* s, ValueType* u, + lapack_int* ldu, ValueType* vt, lapack_int* ldvt, + ValueType* work, lapack_int* lwork, lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer singular_values, + ::xla::ffi::ResultBuffer u, ::xla::ffi::ResultBuffer vt, + ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); + + static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + svd::ComputationMode mode); +}; + +template <::xla::ffi::DataType dtype> +struct SingularValueDecompositionQRComplex { + static_assert(::xla::ffi::IsComplexType()); + + using ValueType = ::xla::ffi::NativeType; + using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>; + using FnType = void(char* jobu, char* jobvt, lapack_int* m, lapack_int* n, + ValueType* a, lapack_int* lda, RealType* s, ValueType* u, + lapack_int* ldu, ValueType* vt, lapack_int* ldvt, + ValueType* work, lapack_int* lwork, RealType* rwork, + lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> singular_values, + ::xla::ffi::ResultBuffer u, ::xla::ffi::ResultBuffer vt, + ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); + + static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + svd::ComputationMode mode); +}; + namespace svd { template <::xla::ffi::DataType dtype> @@ -394,9 +443,15 @@ using SVDType = std::conditional_t<::xla::ffi::IsComplexType(), SingularValueDecompositionComplex, SingularValueDecomposition>; +template <::xla::ffi::DataType dtype> +using SVDQRType = std::conditional_t<::xla::ffi::IsComplexType(), + SingularValueDecompositionQRComplex, + SingularValueDecompositionQR>; + absl::StatusOr GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); absl::StatusOr GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, ComputationMode mode); +absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols); } // namespace svd @@ -817,6 +872,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesvd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesvd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesvd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesvd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssyevd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsyevd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cheevd_ffi); diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 1b03c6ad0..3c8ddf11c 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -61,6 +61,11 @@ jax::SingularValueDecomposition::FnType dgesdd_; jax::SingularValueDecompositionComplex::FnType cgesdd_; jax::SingularValueDecompositionComplex::FnType zgesdd_; +jax::SingularValueDecompositionQR::FnType sgesvd_; +jax::SingularValueDecompositionQR::FnType dgesvd_; +jax::SingularValueDecompositionQRComplex::FnType cgesvd_; +jax::SingularValueDecompositionQRComplex::FnType zgesvd_; + jax::EigenvalueDecompositionSymmetric::FnType ssyevd_; jax::EigenvalueDecompositionSymmetric::FnType dsyevd_; jax::EigenvalueDecompositionHermitian::FnType cheevd_; @@ -367,6 +372,13 @@ static auto init = []() -> int { AssignKernelFn>( zgesdd_); + AssignKernelFn>(sgesvd_); + AssignKernelFn>(dgesvd_); + AssignKernelFn>( + cgesvd_); + AssignKernelFn>( + zgesvd_); + AssignKernelFn>(ssyevd_); AssignKernelFn>(dsyevd_); AssignKernelFn>(cheevd_);