mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Implement SVD algorithm based on QR for CPU targets
In a recent jax release the SvdAlgorithm parameter has been added to the jax.lax.linalg.svd function. Currently, for CPU targets still only the divide and conquer algorithm from LAPACK is supported (gesdd). This commits adds the functionality to select the QR based algorithm on CPU as well. Mainly it addes the wrapper code to call the gesvd function of LAPACK using the FFI interface. Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
This commit is contained in:
parent
a75bea51de
commit
e03fe3a06d
@ -2069,10 +2069,13 @@ def _svd_cpu_gpu_lowering(
|
||||
compute_uv=compute_uv,
|
||||
)
|
||||
if target_name_prefix == "cpu":
|
||||
if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT:
|
||||
if algorithm is None or algorithm == SvdAlgorithm.DEFAULT:
|
||||
target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype)
|
||||
elif algorithm == SvdAlgorithm.QR:
|
||||
target_name = lapack.prepare_lapack_call("gesvd_ffi", operand_aval.dtype)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"The SVD algorithm parameter is not implemented on CPU.")
|
||||
target_name = lapack.prepare_lapack_call("gesdd_ffi", operand_aval.dtype)
|
||||
"The SVD Jacobi algorithm is not implemented on CPU.")
|
||||
mode = _svd_computation_attr(compute_uv, full_matrices)
|
||||
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
|
||||
if compute_uv:
|
||||
|
@ -271,7 +271,8 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
|
||||
only the singular values ``s``.
|
||||
overwrite_a: unused by JAX
|
||||
check_finite: unused by JAX
|
||||
lapack_driver: unused by JAX
|
||||
lapack_driver: unused by JAX. If you want to select a non-default SVD driver, please
|
||||
check :func:`jax.lax.linalg.svd` which provides such functionality.
|
||||
|
||||
Returns:
|
||||
A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``.
|
||||
|
@ -145,6 +145,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sgesvd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgesvd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgesvd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgesvd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_ssyevd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dsyevd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cheevd_ffi);
|
||||
|
@ -111,6 +111,10 @@ void GetLapackKernelsFromScipy() {
|
||||
AssignKernelFn<svd::SVDType<DataType::F64>>(lapack_ptr("dgesdd"));
|
||||
AssignKernelFn<svd::SVDType<DataType::C64>>(lapack_ptr("cgesdd"));
|
||||
AssignKernelFn<svd::SVDType<DataType::C128>>(lapack_ptr("zgesdd"));
|
||||
AssignKernelFn<svd::SVDQRType<DataType::F32>>(lapack_ptr("sgesvd"));
|
||||
AssignKernelFn<svd::SVDQRType<DataType::F64>>(lapack_ptr("dgesvd"));
|
||||
AssignKernelFn<svd::SVDQRType<DataType::C64>>(lapack_ptr("cgesvd"));
|
||||
AssignKernelFn<svd::SVDQRType<DataType::C128>>(lapack_ptr("zgesvd"));
|
||||
|
||||
AssignKernelFn<RealSyevd<float>>(lapack_ptr("ssyevd"));
|
||||
AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
|
||||
@ -274,6 +278,10 @@ nb::dict Registrations() {
|
||||
dict["lapack_dgesdd_ffi"] = EncapsulateFunction(lapack_dgesdd_ffi);
|
||||
dict["lapack_cgesdd_ffi"] = EncapsulateFunction(lapack_cgesdd_ffi);
|
||||
dict["lapack_zgesdd_ffi"] = EncapsulateFunction(lapack_zgesdd_ffi);
|
||||
dict["lapack_sgesvd_ffi"] = EncapsulateFunction(lapack_sgesvd_ffi);
|
||||
dict["lapack_dgesvd_ffi"] = EncapsulateFunction(lapack_dgesvd_ffi);
|
||||
dict["lapack_cgesvd_ffi"] = EncapsulateFunction(lapack_cgesvd_ffi);
|
||||
dict["lapack_zgesvd_ffi"] = EncapsulateFunction(lapack_zgesvd_ffi);
|
||||
dict["lapack_ssyevd_ffi"] = EncapsulateFunction(lapack_ssyevd_ffi);
|
||||
dict["lapack_dsyevd_ffi"] = EncapsulateFunction(lapack_dsyevd_ffi);
|
||||
dict["lapack_cheevd_ffi"] = EncapsulateFunction(lapack_cheevd_ffi);
|
||||
|
@ -893,6 +893,107 @@ static int64_t SvdGetWorkspaceSize(lapack_int x_rows, lapack_int x_cols,
|
||||
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
static ffi::Error SvdQRKernel(
|
||||
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> singular_values,
|
||||
ffi::ResultBuffer<dtype> u, ffi::ResultBuffer<dtype> vt,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode) {
|
||||
if (mode == svd::ComputationMode::kComputeVtOverwriteXPartialU) [[unlikely]] {
|
||||
return ffi::Error(
|
||||
XLA_FFI_Error_Code_UNIMPLEMENTED,
|
||||
"SVD: Current implementation does not support this computation mode");
|
||||
}
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* singular_values_data = singular_values->typed_data();
|
||||
auto* u_data = u->typed_data();
|
||||
auto* vt_data = vt->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
|
||||
// Prepare LAPACK workspaces.
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
const auto work_size,
|
||||
svd::SVDQRType<dtype>::GetWorkspaceSize(x_rows, x_cols, mode));
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size);
|
||||
using RealType = typename svd::SVDType<dtype>::RealType;
|
||||
std::unique_ptr<RealType[]> rwork;
|
||||
if constexpr (ffi::IsComplexType<dtype>()) {
|
||||
FFI_ASSIGN_OR_RETURN(const auto rwork_size,
|
||||
svd::GetRealWorkspaceSizeQR(x_rows, x_cols));
|
||||
rwork = AllocateScratchMemory<ffi::ToReal(dtype)>(rwork_size);
|
||||
}
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto workspace_dim_v = work_size;
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
auto u_leading_dim_v = x_rows_v;
|
||||
|
||||
auto u_dims = u->dimensions().last(2);
|
||||
auto vt_dims = vt->dimensions().last(2);
|
||||
FFI_ASSIGN_OR_RETURN(auto vt_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
const int64_t singular_values_step{singular_values->dimensions().back()};
|
||||
const int64_t u_step{u_dims.front() * u_dims.back()};
|
||||
const int64_t vt_step{vt_leading_dim_v * vt_dims.back()};
|
||||
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
if constexpr (ffi::IsComplexType<dtype>()) {
|
||||
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data,
|
||||
&x_leading_dim_v, singular_values_data, u_data,
|
||||
&u_leading_dim_v, vt_data, &vt_leading_dim_v,
|
||||
work_data.get(), &workspace_dim_v, rwork.get(),
|
||||
info_data);
|
||||
} else {
|
||||
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data,
|
||||
&x_leading_dim_v, singular_values_data, u_data,
|
||||
&u_leading_dim_v, vt_data, &vt_leading_dim_v,
|
||||
work_data.get(), &workspace_dim_v, info_data);
|
||||
}
|
||||
x_out_data += x_out_step;
|
||||
singular_values_data += singular_values_step;
|
||||
u_data += u_step;
|
||||
vt_data += vt_step;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
static absl::StatusOr<lapack_int> SvdQRGetWorkspaceSize(lapack_int x_rows,
|
||||
lapack_int x_cols,
|
||||
svd::ComputationMode mode) {
|
||||
ffi::NativeType<dtype> optimal_size = {};
|
||||
lapack_int info = 0;
|
||||
lapack_int workspace_query = -1;
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto x_leading_dim_v = x_rows;
|
||||
auto u_leading_dim_v = x_rows;
|
||||
auto vt_leading_dim_v = mode == svd::ComputationMode::kComputeFullUVt
|
||||
? x_cols
|
||||
: std::min(x_rows, x_cols);
|
||||
if constexpr (ffi::IsComplexType<dtype>()) {
|
||||
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows, &x_cols, nullptr,
|
||||
&x_leading_dim_v, nullptr, nullptr,
|
||||
&u_leading_dim_v, nullptr, &vt_leading_dim_v,
|
||||
&optimal_size, &workspace_query, nullptr, &info);
|
||||
} else {
|
||||
svd::SVDQRType<dtype>::fn(&mode_v, &mode_v, &x_rows, &x_cols, nullptr,
|
||||
&x_leading_dim_v, nullptr, nullptr,
|
||||
&u_leading_dim_v, nullptr, &vt_leading_dim_v,
|
||||
&optimal_size, &workspace_query, &info);
|
||||
}
|
||||
return info == 0 ? MaybeCastNoOverflow<lapack_int>(std::real(optimal_size)) : -1;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
@ -928,6 +1029,39 @@ SingularValueDecompositionComplex<dtype>::GetWorkspaceSize(
|
||||
return internal::SvdGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error SingularValueDecompositionQR<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<dtype> singular_values, ffi::ResultBuffer<dtype> u,
|
||||
ffi::ResultBuffer<dtype> vt, ffi::ResultBuffer<LapackIntDtype> info,
|
||||
svd::ComputationMode mode) {
|
||||
return internal::SvdQRKernel<dtype>(x, x_out, singular_values, u, vt, info,
|
||||
mode);
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error SingularValueDecompositionQRComplex<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> singular_values,
|
||||
ffi::ResultBuffer<dtype> u, ffi::ResultBuffer<dtype> vt,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode) {
|
||||
return internal::SvdQRKernel<dtype>(x, x_out, singular_values, u, vt, info,
|
||||
mode);
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
absl::StatusOr<lapack_int> SingularValueDecompositionQR<dtype>::GetWorkspaceSize(
|
||||
lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) {
|
||||
return internal::SvdQRGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
absl::StatusOr<lapack_int>
|
||||
SingularValueDecompositionQRComplex<dtype>::GetWorkspaceSize(
|
||||
lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) {
|
||||
return internal::SvdQRGetWorkspaceSize<dtype>(x_rows, x_cols, mode);
|
||||
}
|
||||
|
||||
absl::StatusOr<lapack_int> svd::GetRealWorkspaceSize(
|
||||
int64_t x_rows, int64_t x_cols, svd::ComputationMode mode) {
|
||||
const auto min_dim = std::min(x_rows, x_cols);
|
||||
@ -940,6 +1074,10 @@ absl::StatusOr<lapack_int> svd::GetRealWorkspaceSize(
|
||||
2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim));
|
||||
}
|
||||
|
||||
absl::StatusOr<lapack_int> svd::GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols) {
|
||||
return CastNoOverflow<lapack_int>(5 * std::min(x_rows, x_cols));
|
||||
}
|
||||
|
||||
absl::StatusOr<lapack_int> svd::GetIntWorkspaceSize(int64_t x_rows,
|
||||
int64_t x_cols) {
|
||||
return CastNoOverflow<lapack_int>(8 * std::min(x_rows, x_cols));
|
||||
@ -950,6 +1088,11 @@ template struct SingularValueDecomposition<ffi::DataType::F64>;
|
||||
template struct SingularValueDecompositionComplex<ffi::DataType::C64>;
|
||||
template struct SingularValueDecompositionComplex<ffi::DataType::C128>;
|
||||
|
||||
template struct SingularValueDecompositionQR<ffi::DataType::F32>;
|
||||
template struct SingularValueDecompositionQR<ffi::DataType::F64>;
|
||||
template struct SingularValueDecompositionQRComplex<ffi::DataType::C64>;
|
||||
template struct SingularValueDecompositionQRComplex<ffi::DataType::C128>;
|
||||
|
||||
//== Eigenvalues and eigenvectors ==//
|
||||
|
||||
// lapack syevd/heevd
|
||||
@ -2179,6 +2322,30 @@ template struct TridiagonalSolver<ffi::DataType::C128>;
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Attr<svd::ComputationMode>("mode"))
|
||||
|
||||
#define JAX_CPU_DEFINE_GESVD(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, SingularValueDecompositionQR<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*s*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*u*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*vt*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Attr<svd::ComputationMode>("mode"))
|
||||
|
||||
#define JAX_CPU_DEFINE_GESVD_COMPLEX(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, SingularValueDecompositionQRComplex<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*u*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*vt*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Attr<svd::ComputationMode>("mode"))
|
||||
|
||||
#define JAX_CPU_DEFINE_SYEVD(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, EigenvalueDecompositionSymmetric<data_type>::Kernel, \
|
||||
@ -2332,6 +2499,11 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_GESVD(lapack_sgesvd_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_GESVD(lapack_dgesvd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GESVD_COMPLEX(lapack_cgesvd_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GESVD_COMPLEX(lapack_zgesvd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_SYEVD(lapack_ssyevd_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_SYEVD(lapack_dsyevd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_HEEVD(lapack_cheevd_ffi, ::xla::ffi::DataType::C64);
|
||||
@ -2370,6 +2542,8 @@ JAX_CPU_DEFINE_GTSV(lapack_zgtsv_ffi, ::xla::ffi::DataType::C128);
|
||||
#undef JAX_CPU_DEFINE_POTRF
|
||||
#undef JAX_CPU_DEFINE_GESDD
|
||||
#undef JAX_CPU_DEFINE_GESDD_COMPLEX
|
||||
#undef JAX_CPU_DEFINE_GESVD
|
||||
#undef JAX_CPU_DEFINE_GESVD_COMPLEX
|
||||
#undef JAX_CPU_DEFINE_SYEVD
|
||||
#undef JAX_CPU_DEFINE_HEEVD
|
||||
#undef JAX_CPU_DEFINE_GEEV
|
||||
|
@ -387,6 +387,55 @@ struct SingularValueDecompositionComplex {
|
||||
svd::ComputationMode mode);
|
||||
};
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct SingularValueDecompositionQR {
|
||||
static_assert(!::xla::ffi::IsComplexType<dtype>(),
|
||||
"There exists a separate implementation for Complex types");
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using RealType = ValueType;
|
||||
using FnType = void(char* jobu, char* jobvt, lapack_int* m, lapack_int* n,
|
||||
ValueType* a, lapack_int* lda, ValueType* s, ValueType* u,
|
||||
lapack_int* ldu, ValueType* vt, lapack_int* ldvt,
|
||||
ValueType* work, lapack_int* lwork, lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<dtype> singular_values,
|
||||
::xla::ffi::ResultBuffer<dtype> u, ::xla::ffi::ResultBuffer<dtype> vt,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode);
|
||||
|
||||
static absl::StatusOr<lapack_int> GetWorkspaceSize(lapack_int x_rows,
|
||||
lapack_int x_cols,
|
||||
svd::ComputationMode mode);
|
||||
};
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct SingularValueDecompositionQRComplex {
|
||||
static_assert(::xla::ffi::IsComplexType<dtype>());
|
||||
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
|
||||
using FnType = void(char* jobu, char* jobvt, lapack_int* m, lapack_int* n,
|
||||
ValueType* a, lapack_int* lda, RealType* s, ValueType* u,
|
||||
lapack_int* ldu, ValueType* vt, lapack_int* ldvt,
|
||||
ValueType* work, lapack_int* lwork, RealType* rwork,
|
||||
lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> singular_values,
|
||||
::xla::ffi::ResultBuffer<dtype> u, ::xla::ffi::ResultBuffer<dtype> vt,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info, svd::ComputationMode mode);
|
||||
|
||||
static absl::StatusOr<lapack_int> GetWorkspaceSize(lapack_int x_rows,
|
||||
lapack_int x_cols,
|
||||
svd::ComputationMode mode);
|
||||
};
|
||||
|
||||
namespace svd {
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
@ -394,9 +443,15 @@ using SVDType = std::conditional_t<::xla::ffi::IsComplexType<dtype>(),
|
||||
SingularValueDecompositionComplex<dtype>,
|
||||
SingularValueDecomposition<dtype>>;
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
using SVDQRType = std::conditional_t<::xla::ffi::IsComplexType<dtype>(),
|
||||
SingularValueDecompositionQRComplex<dtype>,
|
||||
SingularValueDecompositionQR<dtype>>;
|
||||
|
||||
absl::StatusOr<lapack_int> GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols);
|
||||
absl::StatusOr<lapack_int> GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols,
|
||||
ComputationMode mode);
|
||||
absl::StatusOr<lapack_int> GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols);
|
||||
|
||||
} // namespace svd
|
||||
|
||||
@ -817,6 +872,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesvd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesvd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesvd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesvd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssyevd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsyevd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cheevd_ffi);
|
||||
|
@ -61,6 +61,11 @@ jax::SingularValueDecomposition<ffi::DataType::F64>::FnType dgesdd_;
|
||||
jax::SingularValueDecompositionComplex<ffi::DataType::C64>::FnType cgesdd_;
|
||||
jax::SingularValueDecompositionComplex<ffi::DataType::C128>::FnType zgesdd_;
|
||||
|
||||
jax::SingularValueDecompositionQR<ffi::DataType::F32>::FnType sgesvd_;
|
||||
jax::SingularValueDecompositionQR<ffi::DataType::F64>::FnType dgesvd_;
|
||||
jax::SingularValueDecompositionQRComplex<ffi::DataType::C64>::FnType cgesvd_;
|
||||
jax::SingularValueDecompositionQRComplex<ffi::DataType::C128>::FnType zgesvd_;
|
||||
|
||||
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F32>::FnType ssyevd_;
|
||||
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F64>::FnType dsyevd_;
|
||||
jax::EigenvalueDecompositionHermitian<ffi::DataType::C64>::FnType cheevd_;
|
||||
@ -367,6 +372,13 @@ static auto init = []() -> int {
|
||||
AssignKernelFn<SingularValueDecompositionComplex<ffi::DataType::C128>>(
|
||||
zgesdd_);
|
||||
|
||||
AssignKernelFn<SingularValueDecompositionQR<ffi::DataType::F32>>(sgesvd_);
|
||||
AssignKernelFn<SingularValueDecompositionQR<ffi::DataType::F64>>(dgesvd_);
|
||||
AssignKernelFn<SingularValueDecompositionQRComplex<ffi::DataType::C64>>(
|
||||
cgesvd_);
|
||||
AssignKernelFn<SingularValueDecompositionQRComplex<ffi::DataType::C128>>(
|
||||
zgesvd_);
|
||||
|
||||
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F32>>(ssyevd_);
|
||||
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F64>>(dsyevd_);
|
||||
AssignKernelFn<EigenvalueDecompositionHermitian<ffi::DataType::C64>>(cheevd_);
|
||||
|
Loading…
x
Reference in New Issue
Block a user