From b2a469b361b4e70ea23cddbe75753ae9984f7472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Mon, 5 Aug 2024 03:17:26 -0700 Subject: [PATCH] Port Eigenvalue Decompositions to XLA's FFI This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 659492696 --- jaxlib/cpu/BUILD | 2 + jaxlib/cpu/_lapack/__init__.pyi | 5 + jaxlib/cpu/_lapack/eig.pyi | 21 ++ jaxlib/cpu/cpu_kernels.cc | 8 + jaxlib/cpu/lapack.cc | 44 ++- jaxlib/cpu/lapack_kernels.cc | 413 +++++++++++++++++++++- jaxlib/cpu/lapack_kernels.h | 141 ++++++++ jaxlib/cpu/lapack_kernels_using_lapack.cc | 65 +++- 8 files changed, 676 insertions(+), 23 deletions(-) create mode 100644 jaxlib/cpu/_lapack/eig.pyi diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index a89d9ef95..d97a11e4f 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -41,6 +41,7 @@ cc_library( "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/types:span", ], ) @@ -64,6 +65,7 @@ pybind_extension( pytype_srcs = [ "_lapack/__init__.pyi", "_lapack/svd.pyi", + "_lapack/eig.pyi", ], deps = [ ":lapack_kernels", diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 32de9efd8..35c46fcee 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import eig as eig from . import svd as svd @@ -53,6 +54,8 @@ def cgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... def gesdd_iwork_size_ffi(m: int, n: int) -> int: ... def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... +def heevd_rwork_size_ffi(n: int) -> int: ... +def heevd_work_size_ffi(n: int) -> int: ... def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ... @@ -62,4 +65,6 @@ def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... +def syevd_iwork_size_ffi(n: int) -> int: ... +def syevd_work_size_ffi(n: int) -> int: ... def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... diff --git a/jaxlib/cpu/_lapack/eig.pyi b/jaxlib/cpu/_lapack/eig.pyi new file mode 100644 index 000000000..338c15402 --- /dev/null +++ b/jaxlib/cpu/_lapack/eig.pyi @@ -0,0 +1,21 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from typing import ClassVar + + +class ComputationMode(enum.Enum): + kComputeEigenvectors: ClassVar[ComputationMode] = ... + kNoEigenvectors: ClassVar[ComputationMode] = ... diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index d6136a202..93717ea9b 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -141,6 +141,14 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgesdd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgesdd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgesdd_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgesdd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_ssyevd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dsyevd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cheevd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zheevd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi); #undef JAX_CPU_REGISTER_HANDLER diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index bbc7307c0..3e59a4a02 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include +#include #include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" @@ -52,6 +52,14 @@ lapack_int GesddGetRealWorkspaceSize(lapack_int m, lapack_int n, return svd::GetRealWorkspaceSize(m, n, mode); } +// Due to enforced kComputeEigenvectors, this assumes a larger workspace size. +// Could be improved to more accurately estimate the expected size based on the +// eig::ComputationMode value. +template +inline constexpr auto BoundWithEigvecs = +[](lapack_int n) { + return f(n, eig::ComputationMode::kComputeEigenvectors); +}; + void GetLapackKernelsFromScipy() { static bool initialized = false; // Protected by GIL if (initialized) return; @@ -128,11 +136,25 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("dsyevd")); AssignKernelFn>>(lapack_ptr("cheevd")); AssignKernelFn>>(lapack_ptr("zheevd")); + AssignKernelFn>( + lapack_ptr("ssyevd")); + AssignKernelFn>( + lapack_ptr("dsyevd")); + AssignKernelFn>( + lapack_ptr("cheevd")); + AssignKernelFn>( + lapack_ptr("zheevd")); AssignKernelFn>(lapack_ptr("sgeev")); AssignKernelFn>(lapack_ptr("dgeev")); AssignKernelFn>>(lapack_ptr("cgeev")); AssignKernelFn>>(lapack_ptr("zgeev")); + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>( + lapack_ptr("cgeev")); + AssignKernelFn>( + lapack_ptr("zgeev")); AssignKernelFn>(lapack_ptr("sgees")); AssignKernelFn>(lapack_ptr("dgees")); @@ -246,6 +268,14 @@ nb::dict Registrations() { dict["lapack_dgesdd_ffi"] = EncapsulateFunction(lapack_dgesdd_ffi); dict["lapack_cgesdd_ffi"] = EncapsulateFunction(lapack_cgesdd_ffi); dict["lapack_zgesdd_ffi"] = EncapsulateFunction(lapack_zgesdd_ffi); + dict["lapack_ssyevd_ffi"] = EncapsulateFunction(lapack_ssyevd_ffi); + dict["lapack_dsyevd_ffi"] = EncapsulateFunction(lapack_dsyevd_ffi); + dict["lapack_cheevd_ffi"] = EncapsulateFunction(lapack_cheevd_ffi); + dict["lapack_zheevd_ffi"] = EncapsulateFunction(lapack_zheevd_ffi); + dict["lapack_sgeev_ffi"] = EncapsulateFunction(lapack_sgeev_ffi); + dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi); + dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi); + dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi); return dict; } @@ -256,12 +286,16 @@ NB_MODULE(_lapack, m) { m.def("registrations", &Registrations); // Submodules auto svd = m.def_submodule("svd"); + auto eig = m.def_submodule("eig"); // Enums nb::enum_(svd, "ComputationMode") // kComputeVtOverwriteXPartialU is not implemented .value("kComputeFullUVt", svd::ComputationMode::kComputeFullUVt) .value("kComputeMinUVt", svd::ComputationMode::kComputeMinUVt) .value("kNoComputeUVt", svd::ComputationMode::kNoComputeUVt); + nb::enum_(eig, "ComputationMode") + .value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors) + .value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors); // Old-style LAPACK Workspace Size Queries m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), @@ -353,6 +387,14 @@ NB_MODULE(_lapack, m) { nb::arg("m"), nb::arg("n"), nb::arg("mode")); m.def("zgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, nb::arg("m"), nb::arg("n"), nb::arg("mode")); + m.def("syevd_work_size_ffi", BoundWithEigvecs, + nb::arg("n")); + m.def("syevd_iwork_size_ffi", BoundWithEigvecs, + nb::arg("n")); + m.def("heevd_work_size_ffi", BoundWithEigvecs, + nb::arg("n")); + m.def("heevd_rwork_size_ffi", BoundWithEigvecs, + nb::arg("n")); } } // namespace diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index d1971dc2f..c3a32c481 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -21,12 +21,15 @@ limitations under the License. #include #include #include +#include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" +#include "absl/types/span.h" #include "jaxlib/ffi_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" @@ -59,6 +62,7 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); #undef REGISTER_CHAR_ENUM_ATTR_DECODING @@ -942,27 +946,155 @@ template struct RealSyevd; template struct ComplexHeevd>; template struct ComplexHeevd>; +// FFI Kernel + +lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { + switch (mode) { + case ComputationMode::kNoEigenvectors: + return CastNoOverflow(2 * x_cols + 1); + case ComputationMode::kComputeEigenvectors: + return CastNoOverflow(1 + 6 * x_cols + 2 * x_cols * x_cols); + } +} + +lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) { + switch (mode) { + case ComputationMode::kNoEigenvectors: + return 1; + case ComputationMode::kComputeEigenvectors: + return CastNoOverflow(3 + 5 * x_cols); + } +} + +template +ffi::Error EigenvalueDecompositionSymmetric::Kernel( + ffi::Buffer x, MatrixParams::UpLo uplo, + ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer info, ffi::ResultBuffer work, + ffi::ResultBuffer iwork, eig::ComputationMode mode) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + auto* x_out_data = x_out->typed_data(); + auto* eigenvalues_data = eigenvalues->typed_data(); + auto* info_data = info->typed_data(); + auto* work_data = work->typed_data(); + auto* iwork_data = iwork->typed_data(); + + CopyIfDiffBuffer(x, x_out); + + auto mode_v = static_cast(mode); + auto uplo_v = static_cast(uplo); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( + work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( + iwork->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, + MaybeCastNoOverflow(x_cols)); + + const int64_t x_out_step{x_cols * x_cols}; + const int64_t eigenvalues_step{x_cols}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, + eigenvalues_data, work_data, &workspace_dim_v, iwork_data, + &iworkspace_dim_v, info_data); + x_out_data += x_out_step; + eigenvalues_data += eigenvalues_step; + ++info_data; + } + return ffi::Error::Success(); +} + +namespace eig { + +lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) { + switch (mode) { + case ComputationMode::kNoEigenvectors: + return CastNoOverflow(x_cols + 1); + case ComputationMode::kComputeEigenvectors: + return CastNoOverflow(2 * x_cols + x_cols * x_cols); + } +} + +lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) { + switch (mode) { + case ComputationMode::kNoEigenvectors: + return CastNoOverflow(std::max(x_cols, int64_t{1})); + case ComputationMode::kComputeEigenvectors: + return CastNoOverflow(1 + 5 * x_cols + 2 * x_cols * x_cols); + } +} + +} // namespace eig + +template +ffi::Error EigenvalueDecompositionHermitian::Kernel( + ffi::Buffer x, MatrixParams::UpLo uplo, + ffi::ResultBuffer x_out, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer info, ffi::ResultBuffer work, + ffi::ResultBuffer rwork, + ffi::ResultBuffer iwork, eig::ComputationMode mode) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + auto* x_out_data = x_out->typed_data(); + auto* eigenvalues_data = eigenvalues->typed_data(); + auto* info_data = info->typed_data(); + auto* work_data = work->typed_data(); + auto* iwork_data = iwork->typed_data(); + + CopyIfDiffBuffer(x, x_out); + + auto mode_v = static_cast(mode); + auto uplo_v = static_cast(uplo); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( + work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow( + rwork->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( + iwork->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, + MaybeCastNoOverflow(x_cols)); + + const int64_t x_out_step{x_cols * x_cols}; + const int64_t eigenvalues_step{x_cols}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, + eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(), + &rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data); + x_out_data += x_out_step; + eigenvalues_data += eigenvalues_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template struct EigenvalueDecompositionSymmetric; +template struct EigenvalueDecompositionSymmetric; +template struct EigenvalueDecompositionHermitian; +template struct EigenvalueDecompositionHermitian; + // LAPACK uses a packed representation to represent a mixture of real // eigenvectors and complex conjugate pairs. This helper unpacks the // representation into regular complex matrices. template -static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed, - std::complex* unpacked) { - T re, im; - int j; - j = 0; - while (j < n) { - if (im_eigenvalues[j] == 0. || std::isnan(im_eigenvalues[j])) { - for (int k = 0; k < n; ++k) { - unpacked[j * n + k] = {packed[j * n + k], 0.}; +static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag, + const T* packed, std::complex* unpacked) { + for (int j = 0; j < n;) { + if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { + // Real values in each row without imaginary part + // Second row of the imaginary part is not provided + for (int i = 0; i < n; ++i) { + unpacked[j * n + i] = {packed[j * n + i], 0.}; } ++j; } else { - for (int k = 0; k < n; ++k) { - re = packed[j * n + k]; - im = packed[(j + 1) * n + k]; - unpacked[j * n + k] = {re, im}; - unpacked[(j + 1) * n + k] = {re, -im}; + // Complex values where the real part is in the jth row + // and the imaginary part is in the next row (j + 1) + for (int i = 0; i < n; ++i) { + const T real_part = packed[j * n + i]; + const T imag_part = packed[(j + 1) * n + i]; + unpacked[j * n + i] = {real_part, imag_part}; + unpacked[(j + 1) * n + i] = {real_part, -imag_part}; } j += 2; } @@ -1116,6 +1248,183 @@ template struct RealGeev; template struct ComplexGeev>; template struct ComplexGeev>; +// FFI Kernel + +template +ffi::Error EigenvalueDecomposition::Kernel( + ffi::Buffer x, eig::ComputationMode compute_left, + eig::ComputationMode compute_right, ffi::ResultBuffer eigvals_real, + ffi::ResultBuffer eigvals_imag, + ffi::ResultBuffer eigvecs_left, + ffi::ResultBuffer eigvecs_right, + ffi::ResultBuffer info, ffi::ResultBuffer x_work, + ffi::ResultBuffer work_eigvecs_left, + ffi::ResultBuffer work_eigvecs_right) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + + const auto* x_data = x.typed_data(); + auto* x_work_data = x_work->typed_data(); + auto* work_eigvecs_left_data = work_eigvecs_left->typed_data(); + auto* work_eigvecs_right_data = work_eigvecs_right->typed_data(); + auto* eigvecs_left_data = eigvecs_left->typed_data(); + auto* eigvecs_right_data = eigvecs_right->typed_data(); + auto* eigvals_real_data = eigvals_real->typed_data(); + auto* eigvals_imag_data = eigvals_imag->typed_data(); + auto* info_data = info->typed_data(); + + auto compute_left_v = static_cast(compute_left); + auto compute_right_v = static_cast(compute_right); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + + int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + // TODO(phawkins): preallocate workspace using XLA. + auto work = std::make_unique(work_size); + auto* work_data = work.get(); + + const auto is_finite = [](ValueType* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), + [](ValueType value) { return std::isfinite(value); }); + }; + + const int64_t x_size{x_cols * x_cols}; + [[maybe_unused]] const auto x_size_bytes = + static_cast(x_size) * sizeof(ValueType); + [[maybe_unused]] const auto x_cols_bytes = + static_cast(x_cols) * sizeof(ValueType); + for (int64_t i = 0; i < batch_count; ++i) { + std::copy_n(x_data, x_size, x_work_data); + if (is_finite(x_work_data, x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, + eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data, + &x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v, + info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data, + x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); + if (info_data[0] == 0) { + UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data, + eigvecs_left_data); + UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data, + eigvecs_right_data); + } + } else { + info_data[0] = -4; + } + x_data += x_size; + eigvals_real_data += x_cols; + eigvals_imag_data += x_cols; + eigvecs_left_data += x_size; + eigvecs_right_data += x_size; + ++info_data; + } + return ffi::Error::Success(); +} + +template +ffi::Error EigenvalueDecompositionComplex::Kernel( + ffi::Buffer x, eig::ComputationMode compute_left, + eig::ComputationMode compute_right, ffi::ResultBuffer eigvals, + ffi::ResultBuffer eigvecs_left, + ffi::ResultBuffer eigvecs_right, + ffi::ResultBuffer info, ffi::ResultBuffer x_work, + ffi::ResultBuffer rwork) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + const auto* x_data = x.typed_data(); + auto* x_work_data = x_work->typed_data(); + auto* eigvecs_left_data = eigvecs_left->typed_data(); + auto* eigvecs_right_data = eigvecs_right->typed_data(); + auto* eigvals_data = eigvals->typed_data(); + auto* info_data = info->typed_data(); + + auto compute_left_v = static_cast(compute_left); + auto compute_right_v = static_cast(compute_right); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + + int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + // TODO(phawkins): preallocate workspace using XLA. + auto work = std::make_unique(work_size); + auto* work_data = work.get(); + + const auto is_finite = [](ValueType* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { + return std::isfinite(z.real()) && std::isfinite(z.imag()); + }); + }; + + const int64_t x_size{x_cols * x_cols}; + [[maybe_unused]] const auto x_size_bytes = + static_cast(x_size) * sizeof(ValueType); + [[maybe_unused]] const auto x_cols_bytes = + static_cast(x_cols) * sizeof(ValueType); + for (int64_t i = 0; i < batch_count; ++i) { + std::copy_n(x_data, x_size, x_work_data); + if (is_finite(x_work_data, x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, + eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data, + &x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); + } else { + info_data[0] = -4; + } + x_data += x_size; + eigvals_data += x_cols; + eigvecs_left_data += x_size; + eigvecs_right_data += x_size; + ++info_data; + } + return ffi::Error::Success(); +} + +template +int64_t EigenvalueDecomposition::GetWorkspaceSize( + lapack_int x_cols, eig::ComputationMode compute_left, + eig::ComputationMode compute_right) { + ValueType optimal_size = {}; + lapack_int workspace_query = -1; + lapack_int info = 0; + + auto compute_left_v = static_cast(compute_left); + auto compute_right_v = static_cast(compute_right); + fn(&compute_left_v, &compute_right_v, &x_cols, nullptr, &x_cols, nullptr, + nullptr, nullptr, &x_cols, nullptr, &x_cols, &optimal_size, + &workspace_query, &info); + return info == 0 ? static_cast(std::real(optimal_size)) : -1; +}; + +template +int64_t EigenvalueDecompositionComplex::GetWorkspaceSize( + lapack_int x_cols, eig::ComputationMode compute_left, + eig::ComputationMode compute_right) { + ValueType optimal_size = {}; + lapack_int workspace_query = -1; + lapack_int info = 0; + // NULL rwork crashes, LAPACK unnecessarily writes x_cols into rwork + RealType rwork[1]; + auto compute_left_v = static_cast(compute_left); + auto compute_right_v = static_cast(compute_right); + fn(&compute_left_v, &compute_right_v, &x_cols, nullptr, &x_cols, nullptr, + nullptr, &x_cols, nullptr, &x_cols, &optimal_size, &workspace_query, rwork, + &info); + return info == 0 ? static_cast(std::real(optimal_size)) : -1; +}; + +template struct EigenvalueDecomposition; +template struct EigenvalueDecomposition; +template struct EigenvalueDecompositionComplex; +template struct EigenvalueDecompositionComplex; + //== Schur Decomposition ==// // lapack gees @@ -1445,6 +1754,68 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*work*/) \ .Attr("mode")) +#define JAX_CPU_DEFINE_SYEVD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, EigenvalueDecompositionSymmetric::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*eigenvalues*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer>(/*work*/) \ + .Ret<::xla::ffi::Buffer>(/*iwork*/) \ + .Attr("mode")) + +#define JAX_CPU_DEFINE_HEEVD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, EigenvalueDecompositionHermitian::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*eigenvalues*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer>(/*work*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ + .Ret<::xla::ffi::Buffer>(/*iwork*/) \ + .Attr("mode")) + +#define JAX_CPU_DEFINE_GEEV(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, EigenvalueDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("compute_left") \ + .Attr("compute_right") \ + .Ret<::xla::ffi::Buffer>(/*eigvals_real*/) \ + .Ret<::xla::ffi::Buffer>(/*eigvals_imag*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \ + /*eigvecs_left*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \ + /*eigvecs_right*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer>(/*x_work*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*work_eigvecs_left*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*work_eigvecs_right*/)) + +#define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, EigenvalueDecompositionComplex::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("compute_left") \ + .Attr("compute_right") \ + .Ret<::xla::ffi::Buffer>(/*eigvals*/) \ + .Ret<::xla::ffi::Buffer>(/*eigvecs_left*/) \ + .Ret<::xla::ffi::Buffer>(/*eigvecs_right*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ + .Ret<::xla::ffi::Buffer>(/*x_work*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/)) + // FFI Handlers JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32); @@ -1477,6 +1848,16 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128); +JAX_CPU_DEFINE_SYEVD(lapack_ssyevd_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_SYEVD(lapack_dsyevd_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_HEEVD(lapack_cheevd_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_HEEVD(lapack_zheevd_ffi, ::xla::ffi::DataType::C128); + +JAX_CPU_DEFINE_GEEV(lapack_sgeev_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128); + #undef JAX_CPU_DEFINE_TRSM #undef JAX_CPU_DEFINE_GETRF #undef JAX_CPU_DEFINE_GEQRF @@ -1484,5 +1865,9 @@ JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128); #undef JAX_CPU_DEFINE_POTRF #undef JAX_CPU_DEFINE_GESDD #undef JAX_CPU_DEFINE_GESDD_COMPLEX +#undef JAX_CPU_DEFINE_SYEVD +#undef JAX_CPU_DEFINE_HEEVD +#undef JAX_CPU_DEFINE_GEEV +#undef JAX_CPU_DEFINE_GEEV_COMPLEX } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 811a5d4b3..5493ec8cb 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -59,6 +59,15 @@ inline bool ComputesUV(ComputationMode mode) { } // namespace svd +namespace eig { + +enum class ComputationMode : char { + kNoEigenvectors = 'N', + kComputeEigenvectors = 'V', +}; + +} + template void AssignKernelFn(void* func) { KernelType::fn = reinterpret_cast(func); @@ -85,6 +94,7 @@ DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); +DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); #undef DEFINE_CHAR_ENUM_ATTR_DECODING @@ -383,6 +393,67 @@ struct ComplexHeevd { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// FFI Kernel + +namespace eig { + +// Eigenvalue Decomposition +lapack_int GetWorkspaceSize(int64_t x_cols, ComputationMode mode); +lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode); + +// Hermitian Eigenvalue Decomposition +lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode); +lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode); + +} // namespace eig + +template <::xla::ffi::DataType dtype> +struct EigenvalueDecompositionSymmetric { + static_assert(!::xla::ffi::IsComplexType(), + "There exists a separate implementation for Complex types"); + + using ValueType = ::xla::ffi::NativeType; + using FnType = void(char* jobz, char* uplo, lapack_int* n, ValueType* a, + lapack_int* lda, ValueType* w, ValueType* work, + lapack_int* lwork, lapack_int* iwork, lapack_int* liwork, + lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer eigenvalues, + ::xla::ffi::ResultBuffer info, + ::xla::ffi::ResultBuffer work, + ::xla::ffi::ResultBuffer iwork, + eig::ComputationMode mode); +}; + +template <::xla::ffi::DataType dtype> +struct EigenvalueDecompositionHermitian { + static_assert(::xla::ffi::IsComplexType()); + + using ValueType = ::xla::ffi::NativeType; + using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>; + using FnType = void(char* jobz, char* uplo, lapack_int* n, ValueType* a, + lapack_int* lda, RealType* w, ValueType* work, + lapack_int* lwork, RealType* rwork, lapack_int* lrwork, + lapack_int* iwork, lapack_int* liwork, lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues, + ::xla::ffi::ResultBuffer info, + ::xla::ffi::ResultBuffer work, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork, + ::xla::ffi::ResultBuffer iwork, + eig::ComputationMode mode); +}; + // lapack geev template @@ -405,6 +476,68 @@ struct ComplexGeev { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct EigenvalueDecomposition { + static_assert(!::xla::ffi::IsComplexType(), + "There exists a separate implementation for Complex types"); + + using ValueType = ::xla::ffi::NativeType; + using FnType = void(char* jobvl, char* jobvr, lapack_int* n, ValueType* a, + lapack_int* lda, ValueType* wr, ValueType* wi, + ValueType* vl, lapack_int* ldvl, ValueType* vr, + lapack_int* ldvr, ValueType* work, lapack_int* lwork, + lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, eig::ComputationMode compute_left, + eig::ComputationMode compute_right, + ::xla::ffi::ResultBuffer eigvals_real, + ::xla::ffi::ResultBuffer eigvals_imag, + ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left, + ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right, + ::xla::ffi::ResultBuffer info, + ::xla::ffi::ResultBuffer x_work, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right); + + static int64_t GetWorkspaceSize(lapack_int x_cols, + eig::ComputationMode compute_left, + eig::ComputationMode compute_right); +}; + +template <::xla::ffi::DataType dtype> +struct EigenvalueDecompositionComplex { + static_assert(::xla::ffi::IsComplexType()); + + using ValueType = ::xla::ffi::NativeType; + using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>; + using FnType = void(char* jobvl, char* jobvr, lapack_int* n, ValueType* a, + lapack_int* lda, ValueType* w, ValueType* vl, + lapack_int* ldvl, ValueType* vr, lapack_int* ldvr, + ValueType* work, lapack_int* lwork, RealType* rwork, + lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, eig::ComputationMode compute_left, + eig::ComputationMode compute_right, + ::xla::ffi::ResultBuffer eigvals, + ::xla::ffi::ResultBuffer eigvecs_left, + ::xla::ffi::ResultBuffer eigvecs_right, + ::xla::ffi::ResultBuffer info, + ::xla::ffi::ResultBuffer x_work, + ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork); + + static int64_t GetWorkspaceSize(lapack_int x_cols, + eig::ComputationMode compute_left, + eig::ComputationMode compute_right); +}; + //== Schur Decomposition ==// // lapack gees @@ -500,6 +633,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssyevd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsyevd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cheevd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zheevd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi); } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index d17925e13..2a2597629 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -56,15 +56,15 @@ jax::SingularValueDecomposition::FnType dgesdd_; jax::SingularValueDecompositionComplex::FnType cgesdd_; jax::SingularValueDecompositionComplex::FnType zgesdd_; -jax::RealSyevd::FnType ssyevd_; -jax::RealSyevd::FnType dsyevd_; -jax::ComplexHeevd>::FnType cheevd_; -jax::ComplexHeevd>::FnType zheevd_; +jax::EigenvalueDecompositionSymmetric::FnType ssyevd_; +jax::EigenvalueDecompositionSymmetric::FnType dsyevd_; +jax::EigenvalueDecompositionHermitian::FnType cheevd_; +jax::EigenvalueDecompositionHermitian::FnType zheevd_; -jax::RealGeev::FnType sgeev_; -jax::RealGeev::FnType dgeev_; -jax::ComplexGeev>::FnType cgeev_; -jax::ComplexGeev>::FnType zgeev_; +jax::EigenvalueDecomposition::FnType sgeev_; +jax::EigenvalueDecomposition::FnType dgeev_; +jax::EigenvalueDecompositionComplex::FnType cgeev_; +jax::EigenvalueDecompositionComplex::FnType zgeev_; jax::RealGees::FnType sgees_; jax::RealGees::FnType dgees_; @@ -173,6 +173,44 @@ static_assert( jax::SingularValueDecompositionComplex::FnType, jax::ComplexGesdd>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v< + jax::EigenvalueDecompositionSymmetric::FnType, + jax::RealSyevd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v< + jax::EigenvalueDecompositionSymmetric::FnType, + jax::RealSyevd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v< + jax::EigenvalueDecompositionHermitian::FnType, + jax::ComplexHeevd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v< + jax::EigenvalueDecompositionHermitian::FnType, + jax::ComplexHeevd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::RealGeev::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::RealGeev::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v< + jax::EigenvalueDecompositionComplex::FnType, + jax::ComplexGeev>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v< + jax::EigenvalueDecompositionComplex::FnType, + jax::ComplexGeev>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); #undef JAX_KERNEL_FNTYPE_MISMATCH_MSG @@ -266,6 +304,17 @@ static auto init = []() -> int { AssignKernelFn>( zgesdd_); + AssignKernelFn>(ssyevd_); + AssignKernelFn>(dsyevd_); + AssignKernelFn>(cheevd_); + AssignKernelFn>( + zheevd_); + + AssignKernelFn>(sgeev_); + AssignKernelFn>(dgeev_); + AssignKernelFn>(cgeev_); + AssignKernelFn>(zgeev_); + return 0; }();