Port Eigenvalue Decompositions to XLA's FFI

This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 659492696
This commit is contained in:
Paweł Paruzel 2024-08-05 03:17:26 -07:00 committed by jax authors
parent 9b35b760ce
commit b2a469b361
8 changed files with 676 additions and 23 deletions

View File

@ -41,6 +41,7 @@ cc_library(
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/types:span",
],
)
@ -64,6 +65,7 @@ pybind_extension(
pytype_srcs = [
"_lapack/__init__.pyi",
"_lapack/svd.pyi",
"_lapack/eig.pyi",
],
deps = [
":lapack_kernels",

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import eig as eig
from . import svd as svd
@ -53,6 +54,8 @@ def cgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def gesdd_iwork_size_ffi(m: int, n: int) -> int: ...
def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def heevd_rwork_size_ffi(n: int) -> int: ...
def heevd_work_size_ffi(n: int) -> int: ...
def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ...
@ -62,4 +65,6 @@ def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def syevd_iwork_size_ffi(n: int) -> int: ...
def syevd_work_size_ffi(n: int) -> int: ...
def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...

View File

@ -0,0 +1,21 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from typing import ClassVar
class ComputationMode(enum.Enum):
kComputeEigenvectors: ClassVar[ComputationMode] = ...
kNoEigenvectors: ClassVar[ComputationMode] = ...

View File

@ -141,6 +141,14 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgesdd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_ssyevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dsyevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cheevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zheevd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi);
#undef JAX_CPU_REGISTER_HANDLER

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <complex>
#include <cstdint>
#include "nanobind/nanobind.h"
#include "jaxlib/cpu/lapack_kernels.h"
@ -52,6 +52,14 @@ lapack_int GesddGetRealWorkspaceSize(lapack_int m, lapack_int n,
return svd::GetRealWorkspaceSize(m, n, mode);
}
// Due to enforced kComputeEigenvectors, this assumes a larger workspace size.
// Could be improved to more accurately estimate the expected size based on the
// eig::ComputationMode value.
template <lapack_int (&f)(int64_t, eig::ComputationMode)>
inline constexpr auto BoundWithEigvecs = +[](lapack_int n) {
return f(n, eig::ComputationMode::kComputeEigenvectors);
};
void GetLapackKernelsFromScipy() {
static bool initialized = false; // Protected by GIL
if (initialized) return;
@ -128,11 +136,25 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
AssignKernelFn<ComplexHeevd<std::complex<float>>>(lapack_ptr("cheevd"));
AssignKernelFn<ComplexHeevd<std::complex<double>>>(lapack_ptr("zheevd"));
AssignKernelFn<EigenvalueDecompositionSymmetric<DataType::F32>>(
lapack_ptr("ssyevd"));
AssignKernelFn<EigenvalueDecompositionSymmetric<DataType::F64>>(
lapack_ptr("dsyevd"));
AssignKernelFn<EigenvalueDecompositionHermitian<DataType::C64>>(
lapack_ptr("cheevd"));
AssignKernelFn<EigenvalueDecompositionHermitian<DataType::C128>>(
lapack_ptr("zheevd"));
AssignKernelFn<RealGeev<float>>(lapack_ptr("sgeev"));
AssignKernelFn<RealGeev<double>>(lapack_ptr("dgeev"));
AssignKernelFn<ComplexGeev<std::complex<float>>>(lapack_ptr("cgeev"));
AssignKernelFn<ComplexGeev<std::complex<double>>>(lapack_ptr("zgeev"));
AssignKernelFn<EigenvalueDecomposition<DataType::F32>>(lapack_ptr("sgeev"));
AssignKernelFn<EigenvalueDecomposition<DataType::F64>>(lapack_ptr("dgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<DataType::C64>>(
lapack_ptr("cgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<DataType::C128>>(
lapack_ptr("zgeev"));
AssignKernelFn<RealGees<float>>(lapack_ptr("sgees"));
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
@ -246,6 +268,14 @@ nb::dict Registrations() {
dict["lapack_dgesdd_ffi"] = EncapsulateFunction(lapack_dgesdd_ffi);
dict["lapack_cgesdd_ffi"] = EncapsulateFunction(lapack_cgesdd_ffi);
dict["lapack_zgesdd_ffi"] = EncapsulateFunction(lapack_zgesdd_ffi);
dict["lapack_ssyevd_ffi"] = EncapsulateFunction(lapack_ssyevd_ffi);
dict["lapack_dsyevd_ffi"] = EncapsulateFunction(lapack_dsyevd_ffi);
dict["lapack_cheevd_ffi"] = EncapsulateFunction(lapack_cheevd_ffi);
dict["lapack_zheevd_ffi"] = EncapsulateFunction(lapack_zheevd_ffi);
dict["lapack_sgeev_ffi"] = EncapsulateFunction(lapack_sgeev_ffi);
dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi);
dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi);
dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi);
return dict;
}
@ -256,12 +286,16 @@ NB_MODULE(_lapack, m) {
m.def("registrations", &Registrations);
// Submodules
auto svd = m.def_submodule("svd");
auto eig = m.def_submodule("eig");
// Enums
nb::enum_<svd::ComputationMode>(svd, "ComputationMode")
// kComputeVtOverwriteXPartialU is not implemented
.value("kComputeFullUVt", svd::ComputationMode::kComputeFullUVt)
.value("kComputeMinUVt", svd::ComputationMode::kComputeMinUVt)
.value("kNoComputeUVt", svd::ComputationMode::kNoComputeUVt);
nb::enum_<eig::ComputationMode>(eig, "ComputationMode")
.value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors)
.value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors);
// Old-style LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
@ -353,6 +387,14 @@ NB_MODULE(_lapack, m) {
nb::arg("m"), nb::arg("n"), nb::arg("mode"));
m.def("zgesdd_work_size_ffi", &svd::SVDType<DataType::C128>::GetWorkspaceSize,
nb::arg("m"), nb::arg("n"), nb::arg("mode"));
m.def("syevd_work_size_ffi", BoundWithEigvecs<eig::GetWorkspaceSize>,
nb::arg("n"));
m.def("syevd_iwork_size_ffi", BoundWithEigvecs<eig::GetIntWorkspaceSize>,
nb::arg("n"));
m.def("heevd_work_size_ffi", BoundWithEigvecs<eig::GetComplexWorkspaceSize>,
nb::arg("n"));
m.def("heevd_rwork_size_ffi", BoundWithEigvecs<eig::GetRealWorkspaceSize>,
nb::arg("n"));
}
} // namespace

View File

@ -21,12 +21,15 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/types/span.h"
#include "jaxlib/ffi_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
@ -59,6 +62,7 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
#undef REGISTER_CHAR_ENUM_ATTR_DECODING
@ -942,27 +946,155 @@ template struct RealSyevd<double>;
template struct ComplexHeevd<std::complex<float>>;
template struct ComplexHeevd<std::complex<double>>;
// FFI Kernel
lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return CastNoOverflow<lapack_int>(2 * x_cols + 1);
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(1 + 6 * x_cols + 2 * x_cols * x_cols);
}
}
lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return 1;
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(3 + 5 * x_cols);
}
}
template <ffi::DataType dtype>
ffi::Error EigenvalueDecompositionSymmetric<dtype>::Kernel(
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> eigenvalues,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
auto* x_out_data = x_out->typed_data();
auto* eigenvalues_data = eigenvalues->typed_data();
auto* info_data = info->typed_data();
auto* work_data = work->typed_data();
auto* iwork_data = iwork->typed_data();
CopyIfDiffBuffer(x, x_out);
auto mode_v = static_cast<char>(mode);
auto uplo_v = static_cast<char>(uplo);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
work->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
iwork->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
MaybeCastNoOverflow<lapack_int>(x_cols));
const int64_t x_out_step{x_cols * x_cols};
const int64_t eigenvalues_step{x_cols};
for (int64_t i = 0; i < batch_count; ++i) {
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
eigenvalues_data, work_data, &workspace_dim_v, iwork_data,
&iworkspace_dim_v, info_data);
x_out_data += x_out_step;
eigenvalues_data += eigenvalues_step;
++info_data;
}
return ffi::Error::Success();
}
namespace eig {
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return CastNoOverflow<lapack_int>(x_cols + 1);
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(2 * x_cols + x_cols * x_cols);
}
}
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) {
switch (mode) {
case ComputationMode::kNoEigenvectors:
return CastNoOverflow<lapack_int>(std::max(x_cols, int64_t{1}));
case ComputationMode::kComputeEigenvectors:
return CastNoOverflow<lapack_int>(1 + 5 * x_cols + 2 * x_cols * x_cols);
}
}
} // namespace eig
template <ffi::DataType dtype>
ffi::Error EigenvalueDecompositionHermitian<dtype>::Kernel(
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<ffi::ToReal(dtype)> eigenvalues,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork,
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
auto* x_out_data = x_out->typed_data();
auto* eigenvalues_data = eigenvalues->typed_data();
auto* info_data = info->typed_data();
auto* work_data = work->typed_data();
auto* iwork_data = iwork->typed_data();
CopyIfDiffBuffer(x, x_out);
auto mode_v = static_cast<char>(mode);
auto uplo_v = static_cast<char>(uplo);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
work->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
rwork->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
iwork->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
MaybeCastNoOverflow<lapack_int>(x_cols));
const int64_t x_out_step{x_cols * x_cols};
const int64_t eigenvalues_step{x_cols};
for (int64_t i = 0; i < batch_count; ++i) {
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(),
&rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data);
x_out_data += x_out_step;
eigenvalues_data += eigenvalues_step;
++info_data;
}
return ffi::Error::Success();
}
template struct EigenvalueDecompositionSymmetric<ffi::DataType::F32>;
template struct EigenvalueDecompositionSymmetric<ffi::DataType::F64>;
template struct EigenvalueDecompositionHermitian<ffi::DataType::C64>;
template struct EigenvalueDecompositionHermitian<ffi::DataType::C128>;
// LAPACK uses a packed representation to represent a mixture of real
// eigenvectors and complex conjugate pairs. This helper unpacks the
// representation into regular complex matrices.
template <typename T>
static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed,
std::complex<T>* unpacked) {
T re, im;
int j;
j = 0;
while (j < n) {
if (im_eigenvalues[j] == 0. || std::isnan(im_eigenvalues[j])) {
for (int k = 0; k < n; ++k) {
unpacked[j * n + k] = {packed[j * n + k], 0.};
static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag,
const T* packed, std::complex<T>* unpacked) {
for (int j = 0; j < n;) {
if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
// Real values in each row without imaginary part
// Second row of the imaginary part is not provided
for (int i = 0; i < n; ++i) {
unpacked[j * n + i] = {packed[j * n + i], 0.};
}
++j;
} else {
for (int k = 0; k < n; ++k) {
re = packed[j * n + k];
im = packed[(j + 1) * n + k];
unpacked[j * n + k] = {re, im};
unpacked[(j + 1) * n + k] = {re, -im};
// Complex values where the real part is in the jth row
// and the imaginary part is in the next row (j + 1)
for (int i = 0; i < n; ++i) {
const T real_part = packed[j * n + i];
const T imag_part = packed[(j + 1) * n + i];
unpacked[j * n + i] = {real_part, imag_part};
unpacked[(j + 1) * n + i] = {real_part, -imag_part};
}
j += 2;
}
@ -1116,6 +1248,183 @@ template struct RealGeev<double>;
template struct ComplexGeev<std::complex<float>>;
template struct ComplexGeev<std::complex<double>>;
// FFI Kernel
template <ffi::DataType dtype>
ffi::Error EigenvalueDecomposition<dtype>::Kernel(
ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
eig::ComputationMode compute_right, ffi::ResultBuffer<dtype> eigvals_real,
ffi::ResultBuffer<dtype> eigvals_imag,
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_left,
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_right,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_left,
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_right) {
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
const auto* x_data = x.typed_data();
auto* x_work_data = x_work->typed_data();
auto* work_eigvecs_left_data = work_eigvecs_left->typed_data();
auto* work_eigvecs_right_data = work_eigvecs_right->typed_data();
auto* eigvecs_left_data = eigvecs_left->typed_data();
auto* eigvecs_right_data = eigvecs_right->typed_data();
auto* eigvals_real_data = eigvals_real->typed_data();
auto* eigvals_imag_data = eigvals_imag->typed_data();
auto* info_data = info->typed_data();
auto compute_left_v = static_cast<char>(compute_left);
auto compute_right_v = static_cast<char>(compute_right);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
FFI_ASSIGN_OR_RETURN(auto work_size_v,
MaybeCastNoOverflow<lapack_int>(work_size));
// TODO(phawkins): preallocate workspace using XLA.
auto work = std::make_unique<ValueType[]>(work_size);
auto* work_data = work.get();
const auto is_finite = [](ValueType* data, int64_t size) {
return absl::c_all_of(absl::MakeSpan(data, size),
[](ValueType value) { return std::isfinite(value); });
};
const int64_t x_size{x_cols * x_cols};
[[maybe_unused]] const auto x_size_bytes =
static_cast<unsigned long>(x_size) * sizeof(ValueType);
[[maybe_unused]] const auto x_cols_bytes =
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
for (int64_t i = 0; i < batch_count; ++i) {
std::copy_n(x_data, x_size, x_work_data);
if (is_finite(x_work_data, x_size)) {
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data,
&x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v,
info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data,
x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
if (info_data[0] == 0) {
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data,
eigvecs_left_data);
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data,
eigvecs_right_data);
}
} else {
info_data[0] = -4;
}
x_data += x_size;
eigvals_real_data += x_cols;
eigvals_imag_data += x_cols;
eigvecs_left_data += x_size;
eigvecs_right_data += x_size;
++info_data;
}
return ffi::Error::Success();
}
template <ffi::DataType dtype>
ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
eig::ComputationMode compute_right, ffi::ResultBuffer<dtype> eigvals,
ffi::ResultBuffer<dtype> eigvecs_left,
ffi::ResultBuffer<dtype> eigvecs_right,
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork) {
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
const auto* x_data = x.typed_data();
auto* x_work_data = x_work->typed_data();
auto* eigvecs_left_data = eigvecs_left->typed_data();
auto* eigvecs_right_data = eigvecs_right->typed_data();
auto* eigvals_data = eigvals->typed_data();
auto* info_data = info->typed_data();
auto compute_left_v = static_cast<char>(compute_left);
auto compute_right_v = static_cast<char>(compute_right);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
FFI_ASSIGN_OR_RETURN(auto work_size_v,
MaybeCastNoOverflow<lapack_int>(work_size));
// TODO(phawkins): preallocate workspace using XLA.
auto work = std::make_unique<ValueType[]>(work_size);
auto* work_data = work.get();
const auto is_finite = [](ValueType* data, int64_t size) {
return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) {
return std::isfinite(z.real()) && std::isfinite(z.imag());
});
};
const int64_t x_size{x_cols * x_cols};
[[maybe_unused]] const auto x_size_bytes =
static_cast<unsigned long>(x_size) * sizeof(ValueType);
[[maybe_unused]] const auto x_cols_bytes =
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
for (int64_t i = 0; i < batch_count; ++i) {
std::copy_n(x_data, x_size, x_work_data);
if (is_finite(x_work_data, x_size)) {
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data,
&x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
} else {
info_data[0] = -4;
}
x_data += x_size;
eigvals_data += x_cols;
eigvecs_left_data += x_size;
eigvecs_right_data += x_size;
++info_data;
}
return ffi::Error::Success();
}
template <ffi::DataType dtype>
int64_t EigenvalueDecomposition<dtype>::GetWorkspaceSize(
lapack_int x_cols, eig::ComputationMode compute_left,
eig::ComputationMode compute_right) {
ValueType optimal_size = {};
lapack_int workspace_query = -1;
lapack_int info = 0;
auto compute_left_v = static_cast<char>(compute_left);
auto compute_right_v = static_cast<char>(compute_right);
fn(&compute_left_v, &compute_right_v, &x_cols, nullptr, &x_cols, nullptr,
nullptr, nullptr, &x_cols, nullptr, &x_cols, &optimal_size,
&workspace_query, &info);
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
};
template <ffi::DataType dtype>
int64_t EigenvalueDecompositionComplex<dtype>::GetWorkspaceSize(
lapack_int x_cols, eig::ComputationMode compute_left,
eig::ComputationMode compute_right) {
ValueType optimal_size = {};
lapack_int workspace_query = -1;
lapack_int info = 0;
// NULL rwork crashes, LAPACK unnecessarily writes x_cols into rwork
RealType rwork[1];
auto compute_left_v = static_cast<char>(compute_left);
auto compute_right_v = static_cast<char>(compute_right);
fn(&compute_left_v, &compute_right_v, &x_cols, nullptr, &x_cols, nullptr,
nullptr, &x_cols, nullptr, &x_cols, &optimal_size, &workspace_query, rwork,
&info);
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
};
template struct EigenvalueDecomposition<ffi::DataType::F32>;
template struct EigenvalueDecomposition<ffi::DataType::F64>;
template struct EigenvalueDecompositionComplex<ffi::DataType::C64>;
template struct EigenvalueDecompositionComplex<ffi::DataType::C128>;
//== Schur Decomposition ==//
// lapack gees
@ -1445,6 +1754,68 @@ template struct Sytrd<std::complex<double>>;
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
.Attr<svd::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_SYEVD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, EigenvalueDecompositionSymmetric<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<MatrixParams::UpLo>("uplo") \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigenvalues*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
.Attr<eig::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_HEEVD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, EigenvalueDecompositionHermitian<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<MatrixParams::UpLo>("uplo") \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
/*eigenvalues*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
.Attr<eig::ComputationMode>("mode"))
#define JAX_CPU_DEFINE_GEEV(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, EigenvalueDecomposition<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<eig::ComputationMode>("compute_left") \
.Attr<eig::ComputationMode>("compute_right") \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_real*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_imag*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \
/*eigvecs_left*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \
/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
/*work_eigvecs_left*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
/*work_eigvecs_right*/))
#define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, EigenvalueDecompositionComplex<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<eig::ComputationMode>("compute_left") \
.Attr<eig::ComputationMode>("compute_right") \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_left*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/))
// FFI Handlers
JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
@ -1477,6 +1848,16 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_SYEVD(lapack_ssyevd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_SYEVD(lapack_dsyevd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_HEEVD(lapack_cheevd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_HEEVD(lapack_zheevd_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_GEEV(lapack_sgeev_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_TRSM
#undef JAX_CPU_DEFINE_GETRF
#undef JAX_CPU_DEFINE_GEQRF
@ -1484,5 +1865,9 @@ JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_POTRF
#undef JAX_CPU_DEFINE_GESDD
#undef JAX_CPU_DEFINE_GESDD_COMPLEX
#undef JAX_CPU_DEFINE_SYEVD
#undef JAX_CPU_DEFINE_HEEVD
#undef JAX_CPU_DEFINE_GEEV
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
} // namespace jax

View File

@ -59,6 +59,15 @@ inline bool ComputesUV(ComputationMode mode) {
} // namespace svd
namespace eig {
enum class ComputationMode : char {
kNoEigenvectors = 'N',
kComputeEigenvectors = 'V',
};
}
template <typename KernelType>
void AssignKernelFn(void* func) {
KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);
@ -85,6 +94,7 @@ DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
#undef DEFINE_CHAR_ENUM_ATTR_DECODING
@ -383,6 +393,67 @@ struct ComplexHeevd {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
// FFI Kernel
namespace eig {
// Eigenvalue Decomposition
lapack_int GetWorkspaceSize(int64_t x_cols, ComputationMode mode);
lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode);
// Hermitian Eigenvalue Decomposition
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode);
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode);
} // namespace eig
template <::xla::ffi::DataType dtype>
struct EigenvalueDecompositionSymmetric {
static_assert(!::xla::ffi::IsComplexType<dtype>(),
"There exists a separate implementation for Complex types");
using ValueType = ::xla::ffi::NativeType<dtype>;
using FnType = void(char* jobz, char* uplo, lapack_int* n, ValueType* a,
lapack_int* lda, ValueType* w, ValueType* work,
lapack_int* lwork, lapack_int* iwork, lapack_int* liwork,
lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> eigenvalues,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> work,
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
eig::ComputationMode mode);
};
template <::xla::ffi::DataType dtype>
struct EigenvalueDecompositionHermitian {
static_assert(::xla::ffi::IsComplexType<dtype>());
using ValueType = ::xla::ffi::NativeType<dtype>;
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
using FnType = void(char* jobz, char* uplo, lapack_int* n, ValueType* a,
lapack_int* lda, RealType* w, ValueType* work,
lapack_int* lwork, RealType* rwork, lapack_int* lrwork,
lapack_int* iwork, lapack_int* liwork, lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> work,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork,
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
eig::ComputationMode mode);
};
// lapack geev
template <typename T>
@ -405,6 +476,68 @@ struct ComplexGeev {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
// FFI Kernel
template <::xla::ffi::DataType dtype>
struct EigenvalueDecomposition {
static_assert(!::xla::ffi::IsComplexType<dtype>(),
"There exists a separate implementation for Complex types");
using ValueType = ::xla::ffi::NativeType<dtype>;
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, ValueType* a,
lapack_int* lda, ValueType* wr, ValueType* wi,
ValueType* vl, lapack_int* ldvl, ValueType* vr,
lapack_int* ldvr, ValueType* work, lapack_int* lwork,
lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
eig::ComputationMode compute_right,
::xla::ffi::ResultBuffer<dtype> eigvals_real,
::xla::ffi::ResultBuffer<dtype> eigvals_imag,
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left,
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> x_work,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right);
static int64_t GetWorkspaceSize(lapack_int x_cols,
eig::ComputationMode compute_left,
eig::ComputationMode compute_right);
};
template <::xla::ffi::DataType dtype>
struct EigenvalueDecompositionComplex {
static_assert(::xla::ffi::IsComplexType<dtype>());
using ValueType = ::xla::ffi::NativeType<dtype>;
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, ValueType* a,
lapack_int* lda, ValueType* w, ValueType* vl,
lapack_int* ldvl, ValueType* vr, lapack_int* ldvr,
ValueType* work, lapack_int* lwork, RealType* rwork,
lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
eig::ComputationMode compute_right,
::xla::ffi::ResultBuffer<dtype> eigvals,
::xla::ffi::ResultBuffer<dtype> eigvecs_left,
::xla::ffi::ResultBuffer<dtype> eigvecs_right,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> x_work,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork);
static int64_t GetWorkspaceSize(lapack_int x_cols,
eig::ComputationMode compute_left,
eig::ComputationMode compute_right);
};
//== Schur Decomposition ==//
// lapack gees
@ -500,6 +633,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssyevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsyevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cheevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zheevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi);
} // namespace jax

View File

@ -56,15 +56,15 @@ jax::SingularValueDecomposition<ffi::DataType::F64>::FnType dgesdd_;
jax::SingularValueDecompositionComplex<ffi::DataType::C64>::FnType cgesdd_;
jax::SingularValueDecompositionComplex<ffi::DataType::C128>::FnType zgesdd_;
jax::RealSyevd<float>::FnType ssyevd_;
jax::RealSyevd<double>::FnType dsyevd_;
jax::ComplexHeevd<std::complex<float>>::FnType cheevd_;
jax::ComplexHeevd<std::complex<double>>::FnType zheevd_;
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F32>::FnType ssyevd_;
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F64>::FnType dsyevd_;
jax::EigenvalueDecompositionHermitian<ffi::DataType::C64>::FnType cheevd_;
jax::EigenvalueDecompositionHermitian<ffi::DataType::C128>::FnType zheevd_;
jax::RealGeev<float>::FnType sgeev_;
jax::RealGeev<double>::FnType dgeev_;
jax::ComplexGeev<std::complex<float>>::FnType cgeev_;
jax::ComplexGeev<std::complex<double>>::FnType zgeev_;
jax::EigenvalueDecomposition<ffi::DataType::F32>::FnType sgeev_;
jax::EigenvalueDecomposition<ffi::DataType::F64>::FnType dgeev_;
jax::EigenvalueDecompositionComplex<ffi::DataType::C64>::FnType cgeev_;
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType zgeev_;
jax::RealGees<float>::FnType sgees_;
jax::RealGees<double>::FnType dgees_;
@ -173,6 +173,44 @@ static_assert(
jax::SingularValueDecompositionComplex<ffi::DataType::C128>::FnType,
jax::ComplexGesdd<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F32>::FnType,
jax::RealSyevd<float>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F64>::FnType,
jax::RealSyevd<double>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<
jax::EigenvalueDecompositionHermitian<ffi::DataType::C64>::FnType,
jax::ComplexHeevd<std::complex<float>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<
jax::EigenvalueDecompositionHermitian<ffi::DataType::C128>::FnType,
jax::ComplexHeevd<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::EigenvalueDecomposition<ffi::DataType::F32>::FnType,
jax::RealGeev<float>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::EigenvalueDecomposition<ffi::DataType::F64>::FnType,
jax::RealGeev<double>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<
jax::EigenvalueDecompositionComplex<ffi::DataType::C64>::FnType,
jax::ComplexGeev<std::complex<float>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType,
jax::ComplexGeev<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG
@ -266,6 +304,17 @@ static auto init = []() -> int {
AssignKernelFn<SingularValueDecompositionComplex<ffi::DataType::C128>>(
zgesdd_);
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F32>>(ssyevd_);
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F64>>(dsyevd_);
AssignKernelFn<EigenvalueDecompositionHermitian<ffi::DataType::C64>>(cheevd_);
AssignKernelFn<EigenvalueDecompositionHermitian<ffi::DataType::C128>>(
zheevd_);
AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F32>>(sgeev_);
AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F64>>(dgeev_);
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C64>>(cgeev_);
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C128>>(zgeev_);
return 0;
}();