diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index f2a4d9430..4275d8e48 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -49,11 +49,7 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr # FFI Kernel LAPACK Workspace Size Queries -def heevd_rwork_size_ffi(n: int) -> int: ... -def heevd_work_size_ffi(n: int) -> int: ... def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def syevd_iwork_size_ffi(n: int) -> int: ... -def syevd_work_size_ffi(n: int) -> int: ... diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 83ed7610c..354a1cf9a 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -37,14 +37,6 @@ svd::ComputationMode GetSvdComputationMode(bool job_opt_compute_uv, return svd::ComputationMode::kComputeFullUVt; } -// Due to enforced kComputeEigenvectors, this assumes a larger workspace size. -// Could be improved to more accurately estimate the expected size based on the -// eig::ComputationMode value. -template -inline constexpr auto BoundWithEigvecs = +[](lapack_int n) { - return f(n, eig::ComputationMode::kComputeEigenvectors); -}; - void GetLapackKernelsFromScipy() { static bool initialized = false; // Protected by GIL if (initialized) return; @@ -348,14 +340,6 @@ NB_MODULE(_lapack, m) { m.def("lapack_zungqr_workspace_ffi", &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("syevd_work_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("syevd_iwork_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("heevd_work_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("heevd_rwork_size_ffi", BoundWithEigvecs, - nb::arg("n")); } } // namespace diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index c1475d1f2..fd0a12ef2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -954,21 +954,24 @@ template struct ComplexHeevd>; // FFI Kernel -lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr eig::GetWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(2 * x_cols + 1); + return MaybeCastNoOverflow(2 * x_cols + 1); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(1 + 6 * x_cols + 2 * x_cols * x_cols); + return MaybeCastNoOverflow(1 + 6 * x_cols + + 2 * x_cols * x_cols); } } -lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr eig::GetIntWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: return 1; case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(3 + 5 * x_cols); + return MaybeCastNoOverflow(3 + 5 * x_cols); } } @@ -976,34 +979,34 @@ template ffi::Error EigenvalueDecompositionSymmetric::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, - ffi::ResultBuffer info, ffi::ResultBuffer work, - ffi::ResultBuffer iwork, eig::ComputationMode mode) { + ffi::ResultBuffer info, eig::ComputationMode mode) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); - auto* iwork_data = iwork->typed_data(); CopyIfDiffBuffer(x, x_out); auto mode_v = static_cast(mode); auto uplo_v = static_cast(uplo); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( - iwork->dimensions().back())); FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, MaybeCastNoOverflow(x_cols)); + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN(lapack_int work_size_v, + eig::GetWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v, + eig::GetIntWorkspaceSize(x_cols, mode)); + auto work_data = AllocateScratchMemory(work_size_v); + auto iwork_data = AllocateScratchMemory(iwork_size_v); const int64_t x_out_step{x_cols * x_cols}; const int64_t eigenvalues_step{x_cols}; for (int64_t i = 0; i < batch_count; ++i) { fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, - eigenvalues_data, work_data, &workspace_dim_v, iwork_data, - &iworkspace_dim_v, info_data); + eigenvalues_data, work_data.get(), &work_size_v, iwork_data.get(), + &iwork_size_v, info_data); x_out_data += x_out_step; eigenvalues_data += eigenvalues_step; ++info_data; @@ -1013,21 +1016,24 @@ ffi::Error EigenvalueDecompositionSymmetric::Kernel( namespace eig { -lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr GetComplexWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(x_cols + 1); + return MaybeCastNoOverflow(x_cols + 1); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(2 * x_cols + x_cols * x_cols); + return MaybeCastNoOverflow(2 * x_cols + x_cols * x_cols); } } -lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr GetRealWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(std::max(x_cols, int64_t{1})); + return MaybeCastNoOverflow(std::max(x_cols, int64_t{1})); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(1 + 5 * x_cols + 2 * x_cols * x_cols); + return MaybeCastNoOverflow(1 + 5 * x_cols + + 2 * x_cols * x_cols); } } @@ -1038,37 +1044,37 @@ ffi::Error EigenvalueDecompositionHermitian::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, - ffi::ResultBuffer info, ffi::ResultBuffer work, - ffi::ResultBuffer rwork, - ffi::ResultBuffer iwork, eig::ComputationMode mode) { + ffi::ResultBuffer info, eig::ComputationMode mode) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); - auto* iwork_data = iwork->typed_data(); CopyIfDiffBuffer(x, x_out); auto mode_v = static_cast(mode); auto uplo_v = static_cast(uplo); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow( - rwork->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( - iwork->dimensions().back())); FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, MaybeCastNoOverflow(x_cols)); + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN(lapack_int work_size_v, + eig::GetComplexWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int rwork_size_v, + eig::GetRealWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v, + eig::GetIntWorkspaceSize(x_cols, mode)); + auto work_data = AllocateScratchMemory(work_size_v); + auto iwork_data = AllocateScratchMemory(iwork_size_v); + auto rwork_data = AllocateScratchMemory(rwork_size_v); const int64_t x_out_step{x_cols * x_cols}; const int64_t eigenvalues_step{x_cols}; for (int64_t i = 0; i < batch_count; ++i) { fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, - eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(), - &rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data); + eigenvalues_data, work_data.get(), &work_size_v, rwork_data.get(), + &rwork_size_v, iwork_data.get(), &iwork_size_v, info_data); x_out_data += x_out_step; eigenvalues_data += eigenvalues_step; ++info_data; @@ -1265,16 +1271,11 @@ ffi::Error EigenvalueDecomposition::Kernel( ffi::ResultBuffer eigvals_imag, ffi::ResultBuffer eigvecs_left, ffi::ResultBuffer eigvecs_right, - ffi::ResultBuffer info, ffi::ResultBuffer x_work, - ffi::ResultBuffer work_eigvecs_left, - ffi::ResultBuffer work_eigvecs_right) { + ffi::ResultBuffer info) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); - auto* x_work_data = x_work->typed_data(); - auto* work_eigvecs_left_data = work_eigvecs_left->typed_data(); - auto* work_eigvecs_right_data = work_eigvecs_right->typed_data(); auto* eigvecs_left_data = eigvecs_left->typed_data(); auto* eigvecs_right_data = eigvecs_right->typed_data(); auto* eigvals_real_data = eigvals_real->typed_data(); @@ -1284,43 +1285,45 @@ ffi::Error EigenvalueDecomposition::Kernel( auto compute_left_v = static_cast(compute_left); auto compute_right_v = static_cast(compute_right); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - + // Prepare LAPACK workspaces. int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); FFI_ASSIGN_OR_RETURN(auto work_size_v, MaybeCastNoOverflow(work_size)); - // TODO(phawkins): preallocate workspace using XLA. - auto work = std::make_unique(work_size); - auto* work_data = work.get(); + auto work_data = AllocateScratchMemory(work_size); + const int64_t x_size{x_cols * x_cols}; + auto x_copy = AllocateScratchMemory(x_size); + auto work_eigvecs_left = AllocateScratchMemory(x_size); + auto work_eigvecs_right = AllocateScratchMemory(x_size); const auto is_finite = [](ValueType* data, int64_t size) { return absl::c_all_of(absl::MakeSpan(data, size), [](ValueType value) { return std::isfinite(value); }); }; - const int64_t x_size{x_cols * x_cols}; [[maybe_unused]] const auto x_size_bytes = static_cast(x_size) * sizeof(ValueType); [[maybe_unused]] const auto x_cols_bytes = static_cast(x_cols) * sizeof(ValueType); for (int64_t i = 0; i < batch_count; ++i) { - std::copy_n(x_data, x_size, x_work_data); - if (is_finite(x_work_data, x_size)) { - fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, - eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data, - &x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v, - info_data); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + std::copy_n(x_data, x_size, x_copy.get()); + if (is_finite(x_copy.get(), x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v, + eigvals_real_data, eigvals_imag_data, work_eigvecs_left.get(), + &x_cols_v, work_eigvecs_right.get(), &x_cols_v, work_data.get(), + &work_size_v, info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data, + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left.get(), + x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); if (info_data[0] == 0) { - UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data, + UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left.get(), eigvecs_left_data); - UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data, - eigvecs_right_data); + UnpackEigenvectors(x_cols_v, eigvals_imag_data, + work_eigvecs_right.get(), eigvecs_right_data); } } else { info_data[0] = -4; @@ -1341,12 +1344,10 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( eig::ComputationMode compute_right, ffi::ResultBuffer eigvals, ffi::ResultBuffer eigvecs_left, ffi::ResultBuffer eigvecs_right, - ffi::ResultBuffer info, ffi::ResultBuffer x_work, - ffi::ResultBuffer rwork) { + ffi::ResultBuffer info) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); - auto* x_work_data = x_work->typed_data(); auto* eigvecs_left_data = eigvecs_left->typed_data(); auto* eigvecs_right_data = eigvecs_right->typed_data(); auto* eigvals_data = eigvals->typed_data(); @@ -1355,13 +1356,14 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( auto compute_left_v = static_cast(compute_left); auto compute_right_v = static_cast(compute_right); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - + // Prepare LAPACK workspaces. int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); FFI_ASSIGN_OR_RETURN(auto work_size_v, MaybeCastNoOverflow(work_size)); - // TODO(phawkins): preallocate workspace using XLA. - auto work = std::make_unique(work_size); - auto* work_data = work.get(); + auto work_data = AllocateScratchMemory(work_size); + const int64_t x_size{x_cols * x_cols}; + auto x_copy = AllocateScratchMemory(x_size); + auto rwork_data = AllocateScratchMemory(2 * x_cols); const auto is_finite = [](ValueType* data, int64_t size) { return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { @@ -1369,18 +1371,18 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( }); }; - const int64_t x_size{x_cols * x_cols}; [[maybe_unused]] const auto x_size_bytes = static_cast(x_size) * sizeof(ValueType); [[maybe_unused]] const auto x_cols_bytes = static_cast(x_cols) * sizeof(ValueType); for (int64_t i = 0; i < batch_count; ++i) { - std::copy_n(x_data, x_size, x_work_data); - if (is_finite(x_work_data, x_size)) { - fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, + std::copy_n(x_data, x_size, x_copy.get()); + if (is_finite(x_copy.get(), x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v, eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data, - &x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + &x_cols_v, work_data.get(), &work_size_v, rwork_data.get(), + info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes); @@ -1766,23 +1768,18 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*x_out*/) \ .Ret<::xla::ffi::Buffer>(/*eigenvalues*/) \ .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ .Attr("mode")) -#define JAX_CPU_DEFINE_HEEVD(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, EigenvalueDecompositionHermitian::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Attr("uplo") \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*eigenvalues*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ +#define JAX_CPU_DEFINE_HEEVD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, EigenvalueDecompositionHermitian::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*eigenvalues*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) #define JAX_CPU_DEFINE_GEEV(name, data_type) \ @@ -1798,12 +1795,7 @@ template struct Sytrd>; /*eigvecs_left*/) \ .Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \ /*eigvecs_right*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*x_work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*work_eigvecs_left*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*work_eigvecs_right*/)) + .Ret<::xla::ffi::Buffer>(/*info*/)) #define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ @@ -1815,9 +1807,7 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*eigvals*/) \ .Ret<::xla::ffi::Buffer>(/*eigvecs_left*/) \ .Ret<::xla::ffi::Buffer>(/*eigvecs_right*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*x_work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/)) + .Ret<::xla::ffi::Buffer>(/*info*/)) // FFI Handlers diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 20823e785..4d021b688 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -395,12 +395,16 @@ struct ComplexHeevd { namespace eig { // Eigenvalue Decomposition -lapack_int GetWorkspaceSize(int64_t x_cols, ComputationMode mode); -lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode); +absl::StatusOr GetWorkspaceSize(int64_t x_cols, + ComputationMode mode); +absl::StatusOr GetIntWorkspaceSize(int64_t x_cols, + ComputationMode mode); // Hermitian Eigenvalue Decomposition -lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode); -lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode); +absl::StatusOr GetComplexWorkspaceSize(int64_t x_cols, + ComputationMode mode); +absl::StatusOr GetRealWorkspaceSize(int64_t x_cols, + ComputationMode mode); } // namespace eig @@ -417,14 +421,12 @@ struct EigenvalueDecompositionSymmetric { inline static FnType* fn = nullptr; - static ::xla::ffi::Error Kernel( - ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, - ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer eigenvalues, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work, - ::xla::ffi::ResultBuffer iwork, - eig::ComputationMode mode); + static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, + MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer eigenvalues, + ::xla::ffi::ResultBuffer info, + eig::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -445,9 +447,6 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues, ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork, - ::xla::ffi::ResultBuffer iwork, eig::ComputationMode mode); }; @@ -496,10 +495,7 @@ struct EigenvalueDecomposition { ::xla::ffi::ResultBuffer eigvals_imag, ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left, ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer x_work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right); + ::xla::ffi::ResultBuffer info); static int64_t GetWorkspaceSize(lapack_int x_cols, eig::ComputationMode compute_left, @@ -526,9 +522,7 @@ struct EigenvalueDecompositionComplex { ::xla::ffi::ResultBuffer eigvals, ::xla::ffi::ResultBuffer eigvecs_left, ::xla::ffi::ResultBuffer eigvecs_right, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer x_work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork); + ::xla::ffi::ResultBuffer info); static int64_t GetWorkspaceSize(lapack_int x_cols, eig::ComputationMode compute_left,