Determine LAPACK workspace at runtime in eigenvalue kernels

PiperOrigin-RevId: 666285759
Author: Paweł Paruzel, 2024-08-22 04:09:02 -07:00 (committed by jax authors)
commit 4786930a4c, parent a72d46c549
4 changed files with 100 additions and 136 deletions


@@ -49,11 +49,7 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr
# FFI Kernel LAPACK Workspace Size Queries
def heevd_rwork_size_ffi(n: int) -> int: ...
def heevd_work_size_ffi(n: int) -> int: ...
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def syevd_iwork_size_ffi(n: int) -> int: ...
def syevd_work_size_ffi(n: int) -> int: ...


@@ -37,14 +37,6 @@ svd::ComputationMode GetSvdComputationMode(bool job_opt_compute_uv,
return svd::ComputationMode::kComputeFullUVt;
}
// Due to enforced kComputeEigenvectors, this assumes a larger workspace size.
// Could be improved to more accurately estimate the expected size based on the
// eig::ComputationMode value.
template <lapack_int (&f)(int64_t, eig::ComputationMode)>
inline constexpr auto BoundWithEigvecs = +[](lapack_int n) {
return f(n, eig::ComputationMode::kComputeEigenvectors);
};
void GetLapackKernelsFromScipy() {
static bool initialized = false; // Protected by GIL
if (initialized) return;
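The BoundWithEigvecs adapter removed above always sized the workspace for kComputeEigenvectors, as its comment notes; with the sizes now computed inside the kernels from the actual eig::ComputationMode, the no-eigenvectors case no longer pays that conservative bound. As a quick arithmetic check using the dsyevd formulas introduced later in this change:

// Illustrative only: dsyevd workspace entries at n = 1000 for the two modes,
// per the formulas in eig::GetWorkspaceSize below.
static_assert(2 * 1000 + 1 == 2001);                       // kNoEigenvectors
static_assert(1 + 6 * 1000 + 2 * 1000 * 1000 == 2006001);  // kComputeEigenvectors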
@@ -348,14 +340,6 @@ NB_MODULE(_lapack, m) {
m.def("lapack_zungqr_workspace_ffi",
&OrthogonalQr<DataType::C128>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"), nb::arg("k"));
m.def("syevd_work_size_ffi", BoundWithEigvecs<eig::GetWorkspaceSize>,
nb::arg("n"));
m.def("syevd_iwork_size_ffi", BoundWithEigvecs<eig::GetIntWorkspaceSize>,
nb::arg("n"));
m.def("heevd_work_size_ffi", BoundWithEigvecs<eig::GetComplexWorkspaceSize>,
nb::arg("n"));
m.def("heevd_rwork_size_ffi", BoundWithEigvecs<eig::GetRealWorkspaceSize>,
nb::arg("n"));
}
} // namespace


@@ -954,21 +954,24 @@ template struct ComplexHeevd<std::complex<double>>;
// FFI Kernel
lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) {
absl::StatusOr<lapack_int> eig::GetWorkspaceSize(int64_t x_cols,
ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return CastNoOverflow<lapack_int>(2 * x_cols + 1);
return MaybeCastNoOverflow<lapack_int>(2 * x_cols + 1);
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(1 + 6 * x_cols + 2 * x_cols * x_cols);
return MaybeCastNoOverflow<lapack_int>(1 + 6 * x_cols +
2 * x_cols * x_cols);
}
}
lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) {
absl::StatusOr<lapack_int> eig::GetIntWorkspaceSize(int64_t x_cols,
ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return 1;
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(3 + 5 * x_cols);
return MaybeCastNoOverflow<lapack_int>(3 + 5 * x_cols);
}
}
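Every size these helpers compute is now funneled through MaybeCastNoOverflow, so an n large enough to overflow lapack_int surfaces as a status error instead of a silently truncated LWORK. The helper itself is not shown in this diff; a minimal sketch of the overflow-checked cast it is assumed to perform (hypothetical, the in-tree version may differ in naming and error text):

#include <cstdint>
#include <limits>

#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"

// Hypothetical sketch only; the real MaybeCastNoOverflow is defined elsewhere
// in lapack_kernels.
template <typename T>
absl::StatusOr<T> MaybeCastNoOverflowSketch(std::int64_t value) {
  if (value > static_cast<std::int64_t>(std::numeric_limits<T>::max())) {
    return absl::InvalidArgumentError(
        absl::StrFormat("value %d does not fit into the LAPACK int type", value));
  }
  return static_cast<T>(value);
}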
@@ -976,34 +979,34 @@ template <ffi::DataType dtype>
ffi::Error EigenvalueDecompositionSymmetric<dtype>::Kernel(
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> eigenvalues,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
ffi::ResultBuffer<LapackIntDtype> info, eig::ComputationMode mode) {
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
auto* x_out_data = x_out->typed_data();
auto* eigenvalues_data = eigenvalues->typed_data();
auto* info_data = info->typed_data();
auto* work_data = work->typed_data();
auto* iwork_data = iwork->typed_data();
CopyIfDiffBuffer(x, x_out);
auto mode_v = static_cast<char>(mode);
auto uplo_v = static_cast<char>(uplo);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
work->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
iwork->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
MaybeCastNoOverflow<lapack_int>(x_cols));
// Prepare LAPACK workspaces.
FFI_ASSIGN_OR_RETURN(lapack_int work_size_v,
eig::GetWorkspaceSize(x_cols, mode));
FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v,
eig::GetIntWorkspaceSize(x_cols, mode));
auto work_data = AllocateScratchMemory<dtype>(work_size_v);
auto iwork_data = AllocateScratchMemory<LapackIntDtype>(iwork_size_v);
const int64_t x_out_step{x_cols * x_cols};
const int64_t eigenvalues_step{x_cols};
for (int64_t i = 0; i < batch_count; ++i) {
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
eigenvalues_data, work_data, &workspace_dim_v, iwork_data,
&iworkspace_dim_v, info_data);
eigenvalues_data, work_data.get(), &work_size_v, iwork_data.get(),
&iwork_size_v, info_data);
x_out_data += x_out_step;
eigenvalues_data += eigenvalues_step;
++info_data;
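The kernel no longer receives work and iwork as XLA result buffers; it sizes them with the eig:: helpers and allocates them locally through AllocateScratchMemory. That allocator is defined outside this diff; a rough sketch of what it is assumed to do (hypothetical, the in-tree helper may differ in details):

#include <cstdint>
#include <memory>

#include "xla/ffi/api/ffi.h"

// Hypothetical sketch only: allocate a scratch array whose element type
// matches the FFI dtype; unique_ptr ownership frees it when the kernel
// returns.
template <::xla::ffi::DataType dtype>
auto AllocateScratchMemorySketch(std::int64_t size) {
  using T = ::xla::ffi::NativeType<dtype>;
  return std::unique_ptr<T[]>(new T[size]);
}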
@@ -1013,21 +1016,24 @@ ffi::Error EigenvalueDecompositionSymmetric<dtype>::Kernel(
namespace eig {
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) {
absl::StatusOr<lapack_int> GetComplexWorkspaceSize(int64_t x_cols,
ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return CastNoOverflow<lapack_int>(x_cols + 1);
return MaybeCastNoOverflow<lapack_int>(x_cols + 1);
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(2 * x_cols + x_cols * x_cols);
return MaybeCastNoOverflow<lapack_int>(2 * x_cols + x_cols * x_cols);
}
}
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) {
absl::StatusOr<lapack_int> GetRealWorkspaceSize(int64_t x_cols,
ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return CastNoOverflow<lapack_int>(std::max(x_cols, int64_t{1}));
return MaybeCastNoOverflow<lapack_int>(std::max(x_cols, int64_t{1}));
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(1 + 5 * x_cols + 2 * x_cols * x_cols);
return MaybeCastNoOverflow<lapack_int>(1 + 5 * x_cols +
2 * x_cols * x_cols);
}
}
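As a concrete check of the Hermitian formulas: with eigenvectors requested and n = 4, they give the documented *heevd minimums of LWORK = 2n + n², LRWORK = 1 + 5n + 2n², and, together with GetIntWorkspaceSize above, LIWORK = 3 + 5n:

// Illustrative only: zheevd minimum workspace sizes at n = 4 with
// eigenvectors requested, per the formulas above.
static_assert(2 * 4 + 4 * 4 == 24);          // complex work  (LWORK)
static_assert(1 + 5 * 4 + 2 * 4 * 4 == 53);  // real work     (LRWORK)
static_assert(3 + 5 * 4 == 23);              // integer work  (LIWORK)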
@@ -1038,37 +1044,37 @@ ffi::Error EigenvalueDecompositionHermitian<dtype>::Kernel(
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<ffi::ToReal(dtype)> eigenvalues,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork,
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
ffi::ResultBuffer<LapackIntDtype> info, eig::ComputationMode mode) {
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
auto* x_out_data = x_out->typed_data();
auto* eigenvalues_data = eigenvalues->typed_data();
auto* info_data = info->typed_data();
auto* work_data = work->typed_data();
auto* iwork_data = iwork->typed_data();
CopyIfDiffBuffer(x, x_out);
auto mode_v = static_cast<char>(mode);
auto uplo_v = static_cast<char>(uplo);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
work->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
rwork->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
iwork->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
MaybeCastNoOverflow<lapack_int>(x_cols));
// Prepare LAPACK workspaces.
FFI_ASSIGN_OR_RETURN(lapack_int work_size_v,
eig::GetComplexWorkspaceSize(x_cols, mode));
FFI_ASSIGN_OR_RETURN(lapack_int rwork_size_v,
eig::GetRealWorkspaceSize(x_cols, mode));
FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v,
eig::GetIntWorkspaceSize(x_cols, mode));
auto work_data = AllocateScratchMemory<dtype>(work_size_v);
auto iwork_data = AllocateScratchMemory<LapackIntDtype>(iwork_size_v);
auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(rwork_size_v);
const int64_t x_out_step{x_cols * x_cols};
const int64_t eigenvalues_step{x_cols};
for (int64_t i = 0; i < batch_count; ++i) {
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(),
&rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data);
eigenvalues_data, work_data.get(), &work_size_v, rwork_data.get(),
&rwork_size_v, iwork_data.get(), &iwork_size_v, info_data);
x_out_data += x_out_step;
eigenvalues_data += eigenvalues_step;
++info_data;
@@ -1265,16 +1271,11 @@ ffi::Error EigenvalueDecomposition<dtype>::Kernel(
ffi::ResultBuffer<dtype> eigvals_imag,
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_left,
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_right,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_left,
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_right) {
ffi::ResultBuffer<LapackIntDtype> info) {
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
const auto* x_data = x.typed_data();
auto* x_work_data = x_work->typed_data();
auto* work_eigvecs_left_data = work_eigvecs_left->typed_data();
auto* work_eigvecs_right_data = work_eigvecs_right->typed_data();
auto* eigvecs_left_data = eigvecs_left->typed_data();
auto* eigvecs_right_data = eigvecs_right->typed_data();
auto* eigvals_real_data = eigvals_real->typed_data();
@@ -1284,43 +1285,45 @@ ffi::Error EigenvalueDecomposition<dtype>::Kernel(
auto compute_left_v = static_cast<char>(compute_left);
auto compute_right_v = static_cast<char>(compute_right);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
// Prepare LAPACK workspaces.
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
FFI_ASSIGN_OR_RETURN(auto work_size_v,
MaybeCastNoOverflow<lapack_int>(work_size));
// TODO(phawkins): preallocate workspace using XLA.
auto work = std::make_unique<ValueType[]>(work_size);
auto* work_data = work.get();
auto work_data = AllocateScratchMemory<dtype>(work_size);
const int64_t x_size{x_cols * x_cols};
auto x_copy = AllocateScratchMemory<dtype>(x_size);
auto work_eigvecs_left = AllocateScratchMemory<dtype>(x_size);
auto work_eigvecs_right = AllocateScratchMemory<dtype>(x_size);
const auto is_finite = [](ValueType* data, int64_t size) {
return absl::c_all_of(absl::MakeSpan(data, size),
[](ValueType value) { return std::isfinite(value); });
};
const int64_t x_size{x_cols * x_cols};
[[maybe_unused]] const auto x_size_bytes =
static_cast<unsigned long>(x_size) * sizeof(ValueType);
[[maybe_unused]] const auto x_cols_bytes =
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
for (int64_t i = 0; i < batch_count; ++i) {
std::copy_n(x_data, x_size, x_work_data);
if (is_finite(x_work_data, x_size)) {
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data,
&x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v,
info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
std::copy_n(x_data, x_size, x_copy.get());
if (is_finite(x_copy.get(), x_size)) {
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v,
eigvals_real_data, eigvals_imag_data, work_eigvecs_left.get(),
&x_cols_v, work_eigvecs_right.get(), &x_cols_v, work_data.get(),
&work_size_v, info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data,
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left.get(),
x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right.get(),
x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
if (info_data[0] == 0) {
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data,
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left.get(),
eigvecs_left_data);
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data,
eigvecs_right_data);
UnpackEigenvectors(x_cols_v, eigvals_imag_data,
work_eigvecs_right.get(), eigvecs_right_data);
}
} else {
info_data[0] = -4;
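The real-valued path keeps ?geev's packed eigenvector layout in the local work_eigvecs_* buffers: a real eigenvalue owns a single real column, while a complex-conjugate pair shares two adjacent columns holding the real and imaginary parts. UnpackEigenvectors expands that into the complex eigvecs_* outputs; a sketch of that LAPACK convention (hypothetical code, not the in-tree helper):

#include <complex>
#include <cstdint>

// Hypothetical sketch of ?geev's packed-eigenvector convention; the actual
// UnpackEigenvectors helper is defined elsewhere and may differ.
template <typename Real>
void UnpackEigenvectorsSketch(std::int64_t n, const Real* eigvals_imag,
                              const Real* packed, std::complex<Real>* out) {
  for (std::int64_t j = 0; j < n; ++j) {
    if (eigvals_imag[j] == Real{0}) {
      // Real eigenvalue: column j is already the eigenvector.
      for (std::int64_t i = 0; i < n; ++i) out[j * n + i] = packed[j * n + i];
    } else {
      // Conjugate pair: columns j and j+1 hold the real and imaginary parts.
      for (std::int64_t i = 0; i < n; ++i) {
        const Real re = packed[j * n + i];
        const Real im = packed[(j + 1) * n + i];
        out[j * n + i] = {re, im};
        out[(j + 1) * n + i] = {re, -im};
      }
      ++j;  // The pair consumed two columns.
    }
  }
}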
@@ -1341,12 +1344,10 @@ ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
eig::ComputationMode compute_right, ffi::ResultBuffer<dtype> eigvals,
ffi::ResultBuffer<dtype> eigvecs_left,
ffi::ResultBuffer<dtype> eigvecs_right,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork) {
ffi::ResultBuffer<LapackIntDtype> info) {
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
const auto* x_data = x.typed_data();
auto* x_work_data = x_work->typed_data();
auto* eigvecs_left_data = eigvecs_left->typed_data();
auto* eigvecs_right_data = eigvecs_right->typed_data();
auto* eigvals_data = eigvals->typed_data();
@@ -1355,13 +1356,14 @@ ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
auto compute_left_v = static_cast<char>(compute_left);
auto compute_right_v = static_cast<char>(compute_right);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
// Prepare LAPACK workspaces.
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
FFI_ASSIGN_OR_RETURN(auto work_size_v,
MaybeCastNoOverflow<lapack_int>(work_size));
// TODO(phawkins): preallocate workspace using XLA.
auto work = std::make_unique<ValueType[]>(work_size);
auto* work_data = work.get();
auto work_data = AllocateScratchMemory<dtype>(work_size);
const int64_t x_size{x_cols * x_cols};
auto x_copy = AllocateScratchMemory<dtype>(x_size);
auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(2 * x_cols);
const auto is_finite = [](ValueType* data, int64_t size) {
return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) {
@@ -1369,18 +1371,18 @@ ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
});
};
const int64_t x_size{x_cols * x_cols};
[[maybe_unused]] const auto x_size_bytes =
static_cast<unsigned long>(x_size) * sizeof(ValueType);
[[maybe_unused]] const auto x_cols_bytes =
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
for (int64_t i = 0; i < batch_count; ++i) {
std::copy_n(x_data, x_size, x_work_data);
if (is_finite(x_work_data, x_size)) {
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
std::copy_n(x_data, x_size, x_copy.get());
if (is_finite(x_copy.get(), x_size)) {
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v,
eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data,
&x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
&x_cols_v, work_data.get(), &work_size_v, rwork_data.get(),
info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes);
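On the complex path, cgeev/zgeev returns complex eigenvectors directly, so there is no unpacking step, and its documented real-workspace minimum is a fixed 2·n, which is why rwork_data above is sized without consulting the computation mode. The finiteness guard over the complex scratch copy conceptually checks both components of every entry; a hypothetical sketch, not necessarily the in-tree predicate:

#include <cmath>
#include <complex>
#include <cstdint>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"

// Hypothetical sketch of a finiteness check over complex scratch data.
template <typename Real>
bool AllFiniteComplexSketch(const std::complex<Real>* data, std::int64_t size) {
  return absl::c_all_of(absl::MakeSpan(data, size),
                        [](const std::complex<Real>& z) {
                          return std::isfinite(z.real()) &&
                                 std::isfinite(z.imag());
                        });
}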
@@ -1766,8 +1768,6 @@ template struct Sytrd<std::complex<double>>;
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigenvalues*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
.Attr<eig::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_HEEVD(name, data_type) \
@@ -1780,9 +1780,6 @@ template struct Sytrd<std::complex<double>>;
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
/*eigenvalues*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
.Attr<eig::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_GEEV(name, data_type) \
@@ -1798,12 +1795,7 @@ template struct Sytrd<std::complex<double>>;
/*eigvecs_left*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \
/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
/*work_eigvecs_left*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
/*work_eigvecs_right*/))
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
#define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
@@ -1815,9 +1807,7 @@ template struct Sytrd<std::complex<double>>;
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_left*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/))
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
// FFI Handlers


@@ -395,12 +395,16 @@ struct ComplexHeevd {
namespace eig {
// Eigenvalue Decomposition
lapack_int GetWorkspaceSize(int64_t x_cols, ComputationMode mode);
lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode);
absl::StatusOr<lapack_int> GetWorkspaceSize(int64_t x_cols,
ComputationMode mode);
absl::StatusOr<lapack_int> GetIntWorkspaceSize(int64_t x_cols,
ComputationMode mode);
// Hermitian Eigenvalue Decomposition
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode);
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode);
absl::StatusOr<lapack_int> GetComplexWorkspaceSize(int64_t x_cols,
ComputationMode mode);
absl::StatusOr<lapack_int> GetRealWorkspaceSize(int64_t x_cols,
ComputationMode mode);
} // namespace eig
@@ -417,13 +421,11 @@ struct EigenvalueDecompositionSymmetric {
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
static ::xla::ffi::Error Kernel(::xla::ffi::Buffer<dtype> x,
MatrixParams::UpLo uplo,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> eigenvalues,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> work,
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
eig::ComputationMode mode);
};
@@ -445,9 +447,6 @@ struct EigenvalueDecompositionHermitian {
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> work,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork,
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
eig::ComputationMode mode);
};
@@ -496,10 +495,7 @@ struct EigenvalueDecomposition {
::xla::ffi::ResultBuffer<dtype> eigvals_imag,
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left,
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> x_work,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right);
::xla::ffi::ResultBuffer<LapackIntDtype> info);
static int64_t GetWorkspaceSize(lapack_int x_cols,
eig::ComputationMode compute_left,
@@ -526,9 +522,7 @@ struct EigenvalueDecompositionComplex {
::xla::ffi::ResultBuffer<dtype> eigvals,
::xla::ffi::ResultBuffer<dtype> eigvecs_left,
::xla::ffi::ResultBuffer<dtype> eigvecs_right,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> x_work,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork);
::xla::ffi::ResultBuffer<LapackIntDtype> info);
static int64_t GetWorkspaceSize(lapack_int x_cols,
eig::ComputationMode compute_left,