mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Port Eigenvalue Decompositions to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 659492696
This commit is contained in:
parent
9b35b760ce
commit
b2a469b361
@ -41,6 +41,7 @@ cc_library(
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -64,6 +65,7 @@ pybind_extension(
|
||||
pytype_srcs = [
|
||||
"_lapack/__init__.pyi",
|
||||
"_lapack/svd.pyi",
|
||||
"_lapack/eig.pyi",
|
||||
],
|
||||
deps = [
|
||||
":lapack_kernels",
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import eig as eig
|
||||
from . import svd as svd
|
||||
|
||||
|
||||
@ -53,6 +54,8 @@ def cgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
def gesdd_iwork_size_ffi(m: int, n: int) -> int: ...
|
||||
def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
def heevd_rwork_size_ffi(n: int) -> int: ...
|
||||
def heevd_work_size_ffi(n: int) -> int: ...
|
||||
def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
@ -62,4 +65,6 @@ def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
def syevd_iwork_size_ffi(n: int) -> int: ...
|
||||
def syevd_work_size_ffi(n: int) -> int: ...
|
||||
def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
|
21
jaxlib/cpu/_lapack/eig.pyi
Normal file
21
jaxlib/cpu/_lapack/eig.pyi
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
# Type stub for the `ComputationMode` enum exposed by the `eig` submodule.
# The two members mirror the values registered for eig::ComputationMode in the
# extension module: compute eigenvectors, or eigenvalues only.
class ComputationMode(enum.Enum):
|
||||
kComputeEigenvectors: ClassVar[ComputationMode] = ...
|
||||
kNoEigenvectors: ClassVar[ComputationMode] = ...
|
@ -141,6 +141,14 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgesdd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_ssyevd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dsyevd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cheevd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zheevd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi);
|
||||
|
||||
#undef JAX_CPU_REGISTER_HANDLER
|
||||
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "jaxlib/cpu/lapack_kernels.h"
|
||||
@ -52,6 +52,14 @@ lapack_int GesddGetRealWorkspaceSize(lapack_int m, lapack_int n,
|
||||
return svd::GetRealWorkspaceSize(m, n, mode);
|
||||
}
|
||||
|
||||
// Due to enforced kComputeEigenvectors, this assumes a larger workspace size.
|
||||
// Could be improved to more accurately estimate the expected size based on the
|
||||
// eig::ComputationMode value.
|
||||
template <lapack_int (&f)(int64_t, eig::ComputationMode)>
|
||||
inline constexpr auto BoundWithEigvecs = +[](lapack_int n) {
|
||||
return f(n, eig::ComputationMode::kComputeEigenvectors);
|
||||
};
|
||||
|
||||
void GetLapackKernelsFromScipy() {
|
||||
static bool initialized = false; // Protected by GIL
|
||||
if (initialized) return;
|
||||
@ -128,11 +136,25 @@ void GetLapackKernelsFromScipy() {
|
||||
AssignKernelFn<RealSyevd<double>>(lapack_ptr("dsyevd"));
|
||||
AssignKernelFn<ComplexHeevd<std::complex<float>>>(lapack_ptr("cheevd"));
|
||||
AssignKernelFn<ComplexHeevd<std::complex<double>>>(lapack_ptr("zheevd"));
|
||||
AssignKernelFn<EigenvalueDecompositionSymmetric<DataType::F32>>(
|
||||
lapack_ptr("ssyevd"));
|
||||
AssignKernelFn<EigenvalueDecompositionSymmetric<DataType::F64>>(
|
||||
lapack_ptr("dsyevd"));
|
||||
AssignKernelFn<EigenvalueDecompositionHermitian<DataType::C64>>(
|
||||
lapack_ptr("cheevd"));
|
||||
AssignKernelFn<EigenvalueDecompositionHermitian<DataType::C128>>(
|
||||
lapack_ptr("zheevd"));
|
||||
|
||||
AssignKernelFn<RealGeev<float>>(lapack_ptr("sgeev"));
|
||||
AssignKernelFn<RealGeev<double>>(lapack_ptr("dgeev"));
|
||||
AssignKernelFn<ComplexGeev<std::complex<float>>>(lapack_ptr("cgeev"));
|
||||
AssignKernelFn<ComplexGeev<std::complex<double>>>(lapack_ptr("zgeev"));
|
||||
AssignKernelFn<EigenvalueDecomposition<DataType::F32>>(lapack_ptr("sgeev"));
|
||||
AssignKernelFn<EigenvalueDecomposition<DataType::F64>>(lapack_ptr("dgeev"));
|
||||
AssignKernelFn<EigenvalueDecompositionComplex<DataType::C64>>(
|
||||
lapack_ptr("cgeev"));
|
||||
AssignKernelFn<EigenvalueDecompositionComplex<DataType::C128>>(
|
||||
lapack_ptr("zgeev"));
|
||||
|
||||
AssignKernelFn<RealGees<float>>(lapack_ptr("sgees"));
|
||||
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
|
||||
@ -246,6 +268,14 @@ nb::dict Registrations() {
|
||||
dict["lapack_dgesdd_ffi"] = EncapsulateFunction(lapack_dgesdd_ffi);
|
||||
dict["lapack_cgesdd_ffi"] = EncapsulateFunction(lapack_cgesdd_ffi);
|
||||
dict["lapack_zgesdd_ffi"] = EncapsulateFunction(lapack_zgesdd_ffi);
|
||||
dict["lapack_ssyevd_ffi"] = EncapsulateFunction(lapack_ssyevd_ffi);
|
||||
dict["lapack_dsyevd_ffi"] = EncapsulateFunction(lapack_dsyevd_ffi);
|
||||
dict["lapack_cheevd_ffi"] = EncapsulateFunction(lapack_cheevd_ffi);
|
||||
dict["lapack_zheevd_ffi"] = EncapsulateFunction(lapack_zheevd_ffi);
|
||||
dict["lapack_sgeev_ffi"] = EncapsulateFunction(lapack_sgeev_ffi);
|
||||
dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi);
|
||||
dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi);
|
||||
dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi);
|
||||
|
||||
return dict;
|
||||
}
|
||||
@ -256,12 +286,16 @@ NB_MODULE(_lapack, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
// Submodules
|
||||
auto svd = m.def_submodule("svd");
|
||||
auto eig = m.def_submodule("eig");
|
||||
// Enums
|
||||
nb::enum_<svd::ComputationMode>(svd, "ComputationMode")
|
||||
// kComputeVtOverwriteXPartialU is not implemented
|
||||
.value("kComputeFullUVt", svd::ComputationMode::kComputeFullUVt)
|
||||
.value("kComputeMinUVt", svd::ComputationMode::kComputeMinUVt)
|
||||
.value("kNoComputeUVt", svd::ComputationMode::kNoComputeUVt);
|
||||
nb::enum_<eig::ComputationMode>(eig, "ComputationMode")
|
||||
.value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors)
|
||||
.value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors);
|
||||
|
||||
// Old-style LAPACK Workspace Size Queries
|
||||
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
|
||||
@ -353,6 +387,14 @@ NB_MODULE(_lapack, m) {
|
||||
nb::arg("m"), nb::arg("n"), nb::arg("mode"));
|
||||
m.def("zgesdd_work_size_ffi", &svd::SVDType<DataType::C128>::GetWorkspaceSize,
|
||||
nb::arg("m"), nb::arg("n"), nb::arg("mode"));
|
||||
m.def("syevd_work_size_ffi", BoundWithEigvecs<eig::GetWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
m.def("syevd_iwork_size_ffi", BoundWithEigvecs<eig::GetIntWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
m.def("heevd_work_size_ffi", BoundWithEigvecs<eig::GetComplexWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
m.def("heevd_rwork_size_ffi", BoundWithEigvecs<eig::GetRealWorkspaceSize>,
|
||||
nb::arg("n"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -21,12 +21,15 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/base/dynamic_annotations.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "jaxlib/ffi_helpers.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
@ -59,6 +62,7 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
|
||||
|
||||
#undef REGISTER_CHAR_ENUM_ATTR_DECODING
|
||||
|
||||
@ -942,27 +946,155 @@ template struct RealSyevd<double>;
|
||||
template struct ComplexHeevd<std::complex<float>>;
|
||||
template struct ComplexHeevd<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
// Returns the length of the `work` array to allocate for a symmetric
// eigendecomposition (LAPACK ?syevd) of an n-by-n matrix, n = `x_cols`:
//   kNoEigenvectors:      2n + 1
//   kComputeEigenvectors: 1 + 6n + 2n^2
// The result is passed through CastNoOverflow<lapack_int>.
// Throws std::invalid_argument if `mode` is not one of the enumerators.
lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) {
  switch (mode) {
    case ComputationMode::kNoEigenvectors:
      return CastNoOverflow<lapack_int>(2 * x_cols + 1);
    case ComputationMode::kComputeEigenvectors:
      return CastNoOverflow<lapack_int>(1 + 6 * x_cols + 2 * x_cols * x_cols);
  }
  // Only reachable when `mode` holds a value outside the enumerators. Without
  // this, control would flow off the end of a non-void function, which is
  // undefined behavior.
  throw std::invalid_argument("Unknown eig::ComputationMode");
}
|
||||
|
||||
// Returns the length of the integer `iwork` array to allocate for a symmetric
// or Hermitian eigendecomposition of an n-by-n matrix, n = `x_cols`:
//   kNoEigenvectors:      1
//   kComputeEigenvectors: 3 + 5n
// Throws std::invalid_argument if `mode` is not one of the enumerators.
lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) {
  switch (mode) {
    case ComputationMode::kNoEigenvectors:
      return 1;
    case ComputationMode::kComputeEigenvectors:
      return CastNoOverflow<lapack_int>(3 + 5 * x_cols);
  }
  // Only reachable when `mode` holds a value outside the enumerators. Without
  // this, control would flow off the end of a non-void function, which is
  // undefined behavior.
  throw std::invalid_argument("Unknown eig::ComputationMode");
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error EigenvalueDecompositionSymmetric<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> eigenvalues,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
|
||||
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* eigenvalues_data = eigenvalues->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
auto* work_data = work->typed_data();
|
||||
auto* iwork_data = iwork->typed_data();
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
work->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
iwork->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
const int64_t x_out_step{x_cols * x_cols};
|
||||
const int64_t eigenvalues_step{x_cols};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
|
||||
eigenvalues_data, work_data, &workspace_dim_v, iwork_data,
|
||||
&iworkspace_dim_v, info_data);
|
||||
x_out_data += x_out_step;
|
||||
eigenvalues_data += eigenvalues_step;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
namespace eig {
|
||||
|
||||
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) {
|
||||
switch (mode) {
|
||||
case ComputationMode::kNoEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(x_cols + 1);
|
||||
case ComputationMode::kComputeEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(2 * x_cols + x_cols * x_cols);
|
||||
}
|
||||
}
|
||||
|
||||
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) {
|
||||
switch (mode) {
|
||||
case ComputationMode::kNoEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(std::max(x_cols, int64_t{1}));
|
||||
case ComputationMode::kComputeEigenvectors:
|
||||
return CastNoOverflow<lapack_int>(1 + 5 * x_cols + 2 * x_cols * x_cols);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace eig
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error EigenvalueDecompositionHermitian<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> eigenvalues,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> work,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork,
|
||||
ffi::ResultBuffer<LapackIntDtype> iwork, eig::ComputationMode mode) {
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* eigenvalues_data = eigenvalues->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
auto* work_data = work->typed_data();
|
||||
auto* iwork_data = iwork->typed_data();
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
work->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
rwork->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
iwork->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
const int64_t x_out_step{x_cols * x_cols};
|
||||
const int64_t eigenvalues_step{x_cols};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v,
|
||||
eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(),
|
||||
&rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data);
|
||||
x_out_data += x_out_step;
|
||||
eigenvalues_data += eigenvalues_step;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template struct EigenvalueDecompositionSymmetric<ffi::DataType::F32>;
|
||||
template struct EigenvalueDecompositionSymmetric<ffi::DataType::F64>;
|
||||
template struct EigenvalueDecompositionHermitian<ffi::DataType::C64>;
|
||||
template struct EigenvalueDecompositionHermitian<ffi::DataType::C128>;
|
||||
|
||||
// LAPACK uses a packed representation to represent a mixture of real
|
||||
// eigenvectors and complex conjugate pairs. This helper unpacks the
|
||||
// representation into regular complex matrices.
|
||||
template <typename T>
|
||||
static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed,
|
||||
std::complex<T>* unpacked) {
|
||||
T re, im;
|
||||
int j;
|
||||
j = 0;
|
||||
while (j < n) {
|
||||
if (im_eigenvalues[j] == 0. || std::isnan(im_eigenvalues[j])) {
|
||||
for (int k = 0; k < n; ++k) {
|
||||
unpacked[j * n + k] = {packed[j * n + k], 0.};
|
||||
static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag,
|
||||
const T* packed, std::complex<T>* unpacked) {
|
||||
for (int j = 0; j < n;) {
|
||||
if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
|
||||
// Real values in each row without imaginary part
|
||||
// Second row of the imaginary part is not provided
|
||||
for (int i = 0; i < n; ++i) {
|
||||
unpacked[j * n + i] = {packed[j * n + i], 0.};
|
||||
}
|
||||
++j;
|
||||
} else {
|
||||
for (int k = 0; k < n; ++k) {
|
||||
re = packed[j * n + k];
|
||||
im = packed[(j + 1) * n + k];
|
||||
unpacked[j * n + k] = {re, im};
|
||||
unpacked[(j + 1) * n + k] = {re, -im};
|
||||
// Complex values where the real part is in the jth row
|
||||
// and the imaginary part is in the next row (j + 1)
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const T real_part = packed[j * n + i];
|
||||
const T imag_part = packed[(j + 1) * n + i];
|
||||
unpacked[j * n + i] = {real_part, imag_part};
|
||||
unpacked[(j + 1) * n + i] = {real_part, -imag_part};
|
||||
}
|
||||
j += 2;
|
||||
}
|
||||
@ -1116,6 +1248,183 @@ template struct RealGeev<double>;
|
||||
template struct ComplexGeev<std::complex<float>>;
|
||||
template struct ComplexGeev<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error EigenvalueDecomposition<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right, ffi::ResultBuffer<dtype> eigvals_real,
|
||||
ffi::ResultBuffer<dtype> eigvals_imag,
|
||||
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_left,
|
||||
ffi::ResultBuffer<ffi::ToComplex(dtype)> eigvecs_right,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_left,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> work_eigvecs_right) {
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
|
||||
const auto* x_data = x.typed_data();
|
||||
auto* x_work_data = x_work->typed_data();
|
||||
auto* work_eigvecs_left_data = work_eigvecs_left->typed_data();
|
||||
auto* work_eigvecs_right_data = work_eigvecs_right->typed_data();
|
||||
auto* eigvecs_left_data = eigvecs_left->typed_data();
|
||||
auto* eigvecs_right_data = eigvecs_right->typed_data();
|
||||
auto* eigvals_real_data = eigvals_real->typed_data();
|
||||
auto* eigvals_imag_data = eigvals_imag->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
|
||||
auto compute_left_v = static_cast<char>(compute_left);
|
||||
auto compute_right_v = static_cast<char>(compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto work_size_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
// TODO(phawkins): preallocate workspace using XLA.
|
||||
auto work = std::make_unique<ValueType[]>(work_size);
|
||||
auto* work_data = work.get();
|
||||
|
||||
const auto is_finite = [](ValueType* data, int64_t size) {
|
||||
return absl::c_all_of(absl::MakeSpan(data, size),
|
||||
[](ValueType value) { return std::isfinite(value); });
|
||||
};
|
||||
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
[[maybe_unused]] const auto x_size_bytes =
|
||||
static_cast<unsigned long>(x_size) * sizeof(ValueType);
|
||||
[[maybe_unused]] const auto x_cols_bytes =
|
||||
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
std::copy_n(x_data, x_size, x_work_data);
|
||||
if (is_finite(x_work_data, x_size)) {
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
|
||||
eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data,
|
||||
&x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v,
|
||||
info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data,
|
||||
x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
|
||||
if (info_data[0] == 0) {
|
||||
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data,
|
||||
eigvecs_left_data);
|
||||
UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data,
|
||||
eigvecs_right_data);
|
||||
}
|
||||
} else {
|
||||
info_data[0] = -4;
|
||||
}
|
||||
x_data += x_size;
|
||||
eigvals_real_data += x_cols;
|
||||
eigvals_imag_data += x_cols;
|
||||
eigvecs_left_data += x_size;
|
||||
eigvecs_right_data += x_size;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error EigenvalueDecompositionComplex<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right, ffi::ResultBuffer<dtype> eigvals,
|
||||
ffi::ResultBuffer<dtype> eigvecs_left,
|
||||
ffi::ResultBuffer<dtype> eigvecs_right,
|
||||
ffi::ResultBuffer<LapackIntDtype> info, ffi::ResultBuffer<dtype> x_work,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> rwork) {
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
const auto* x_data = x.typed_data();
|
||||
auto* x_work_data = x_work->typed_data();
|
||||
auto* eigvecs_left_data = eigvecs_left->typed_data();
|
||||
auto* eigvecs_right_data = eigvecs_right->typed_data();
|
||||
auto* eigvals_data = eigvals->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
|
||||
auto compute_left_v = static_cast<char>(compute_left);
|
||||
auto compute_right_v = static_cast<char>(compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right);
|
||||
FFI_ASSIGN_OR_RETURN(auto work_size_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
// TODO(phawkins): preallocate workspace using XLA.
|
||||
auto work = std::make_unique<ValueType[]>(work_size);
|
||||
auto* work_data = work.get();
|
||||
|
||||
const auto is_finite = [](ValueType* data, int64_t size) {
|
||||
return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) {
|
||||
return std::isfinite(z.real()) && std::isfinite(z.imag());
|
||||
});
|
||||
};
|
||||
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
[[maybe_unused]] const auto x_size_bytes =
|
||||
static_cast<unsigned long>(x_size) * sizeof(ValueType);
|
||||
[[maybe_unused]] const auto x_cols_bytes =
|
||||
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
std::copy_n(x_data, x_size, x_work_data);
|
||||
if (is_finite(x_work_data, x_size)) {
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v,
|
||||
eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data,
|
||||
&x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
|
||||
} else {
|
||||
info_data[0] = -4;
|
||||
}
|
||||
x_data += x_size;
|
||||
eigvals_data += x_cols;
|
||||
eigvecs_left_data += x_size;
|
||||
eigvecs_right_data += x_size;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
int64_t EigenvalueDecomposition<dtype>::GetWorkspaceSize(
|
||||
lapack_int x_cols, eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right) {
|
||||
ValueType optimal_size = {};
|
||||
lapack_int workspace_query = -1;
|
||||
lapack_int info = 0;
|
||||
|
||||
auto compute_left_v = static_cast<char>(compute_left);
|
||||
auto compute_right_v = static_cast<char>(compute_right);
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols, nullptr, &x_cols, nullptr,
|
||||
nullptr, nullptr, &x_cols, nullptr, &x_cols, &optimal_size,
|
||||
&workspace_query, &info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
|
||||
};
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
int64_t EigenvalueDecompositionComplex<dtype>::GetWorkspaceSize(
|
||||
lapack_int x_cols, eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right) {
|
||||
ValueType optimal_size = {};
|
||||
lapack_int workspace_query = -1;
|
||||
lapack_int info = 0;
|
||||
// NULL rwork crashes, LAPACK unnecessarily writes x_cols into rwork
|
||||
RealType rwork[1];
|
||||
auto compute_left_v = static_cast<char>(compute_left);
|
||||
auto compute_right_v = static_cast<char>(compute_right);
|
||||
fn(&compute_left_v, &compute_right_v, &x_cols, nullptr, &x_cols, nullptr,
|
||||
nullptr, &x_cols, nullptr, &x_cols, &optimal_size, &workspace_query, rwork,
|
||||
&info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
|
||||
};
|
||||
|
||||
template struct EigenvalueDecomposition<ffi::DataType::F32>;
|
||||
template struct EigenvalueDecomposition<ffi::DataType::F64>;
|
||||
template struct EigenvalueDecompositionComplex<ffi::DataType::C64>;
|
||||
template struct EigenvalueDecompositionComplex<ffi::DataType::C128>;
|
||||
|
||||
//== Schur Decomposition ==//
|
||||
|
||||
// lapack gees
|
||||
@ -1445,6 +1754,68 @@ template struct Sytrd<std::complex<double>>;
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
|
||||
.Attr<svd::ComputationMode>("mode"))
|
||||
|
||||
// Defines the XLA FFI handler symbol `name` for
// EigenvalueDecompositionSymmetric<data_type>::Kernel. The binding order must
// match the Kernel parameter order: input x, attribute "uplo", results x_out,
// eigenvalues, info, work, iwork, then attribute "mode".
#define JAX_CPU_DEFINE_SYEVD(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, EigenvalueDecompositionSymmetric<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Attr<MatrixParams::UpLo>("uplo") \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigenvalues*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
|
||||
.Attr<eig::ComputationMode>("mode"))
|
||||
|
||||
// Defines the XLA FFI handler symbol `name` for
// EigenvalueDecompositionHermitian<data_type>::Kernel. The binding order must
// match the Kernel parameter order: input x, attribute "uplo", results x_out,
// eigenvalues (real-valued), info, work, rwork (real-valued), iwork, then
// attribute "mode".
#define JAX_CPU_DEFINE_HEEVD(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, EigenvalueDecompositionHermitian<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Attr<MatrixParams::UpLo>("uplo") \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
|
||||
/*eigenvalues*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*iwork*/) \
|
||||
.Attr<eig::ComputationMode>("mode"))
|
||||
|
||||
// Defines the XLA FFI handler symbol `name` for
// EigenvalueDecomposition<data_type>::Kernel (real input). The binding order
// must match the Kernel parameter order: input x, attributes "compute_left"
// and "compute_right", results eigvals_real, eigvals_imag, complex
// eigvecs_left / eigvecs_right, info, then the scratch buffers x_work,
// work_eigvecs_left and work_eigvecs_right.
#define JAX_CPU_DEFINE_GEEV(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, EigenvalueDecomposition<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Attr<eig::ComputationMode>("compute_left") \
|
||||
.Attr<eig::ComputationMode>("compute_right") \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_real*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_imag*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \
|
||||
/*eigvecs_left*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \
|
||||
/*eigvecs_right*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
|
||||
/*work_eigvecs_left*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \
|
||||
/*work_eigvecs_right*/))
|
||||
|
||||
// Defines the XLA FFI handler symbol `name` for
// EigenvalueDecompositionComplex<data_type>::Kernel (complex input). The
// binding order must match the Kernel parameter order: input x, attributes
// "compute_left" and "compute_right", results eigvals, eigvecs_left,
// eigvecs_right, info, then the scratch buffers x_work and rwork (real-valued).
#define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, EigenvalueDecompositionComplex<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Attr<eig::ComputationMode>("compute_left") \
|
||||
.Attr<eig::ComputationMode>("compute_right") \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_left*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_work*/) \
|
||||
.Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/))
|
||||
|
||||
// FFI Handlers
|
||||
|
||||
JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
|
||||
@ -1477,6 +1848,16 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_SYEVD(lapack_ssyevd_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_SYEVD(lapack_dsyevd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_HEEVD(lapack_cheevd_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_HEEVD(lapack_zheevd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_GEEV(lapack_sgeev_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
#undef JAX_CPU_DEFINE_TRSM
|
||||
#undef JAX_CPU_DEFINE_GETRF
|
||||
#undef JAX_CPU_DEFINE_GEQRF
|
||||
@ -1484,5 +1865,9 @@ JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
|
||||
#undef JAX_CPU_DEFINE_POTRF
|
||||
#undef JAX_CPU_DEFINE_GESDD
|
||||
#undef JAX_CPU_DEFINE_GESDD_COMPLEX
|
||||
#undef JAX_CPU_DEFINE_SYEVD
|
||||
#undef JAX_CPU_DEFINE_HEEVD
|
||||
#undef JAX_CPU_DEFINE_GEEV
|
||||
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
|
||||
|
||||
} // namespace jax
|
||||
|
@ -59,6 +59,15 @@ inline bool ComputesUV(ComputationMode mode) {
|
||||
|
||||
} // namespace svd
|
||||
|
||||
namespace eig {

// Selects whether an eigensolver also computes eigenvectors. The underlying
// character of each enumerator is what the kernels hand to LAPACK as the job
// argument.
enum class ComputationMode : char {
  kNoEigenvectors = 'N',       // eigenvalues only
  kComputeEigenvectors = 'V',  // eigenvalues and eigenvectors
};

}  // namespace eig
|
||||
|
||||
template <typename KernelType>
|
||||
void AssignKernelFn(void* func) {
|
||||
KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);
|
||||
@ -85,6 +94,7 @@ DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
|
||||
|
||||
#undef DEFINE_CHAR_ENUM_ATTR_DECODING
|
||||
|
||||
@ -383,6 +393,67 @@ struct ComplexHeevd {
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
};
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
namespace eig {
|
||||
|
||||
// Eigenvalue Decomposition
|
||||
lapack_int GetWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
|
||||
// Hermitian Eigenvalue Decomposition
|
||||
lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode);
|
||||
|
||||
} // namespace eig
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct EigenvalueDecompositionSymmetric {
|
||||
static_assert(!::xla::ffi::IsComplexType<dtype>(),
|
||||
"There exists a separate implementation for Complex types");
|
||||
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using FnType = void(char* jobz, char* uplo, lapack_int* n, ValueType* a,
|
||||
lapack_int* lda, ValueType* w, ValueType* work,
|
||||
lapack_int* lwork, lapack_int* iwork, lapack_int* liwork,
|
||||
lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<dtype> eigenvalues,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> work,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
|
||||
eig::ComputationMode mode);
|
||||
};
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct EigenvalueDecompositionHermitian {
|
||||
static_assert(::xla::ffi::IsComplexType<dtype>());
|
||||
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
|
||||
using FnType = void(char* jobz, char* uplo, lapack_int* n, ValueType* a,
|
||||
lapack_int* lda, RealType* w, ValueType* work,
|
||||
lapack_int* lwork, RealType* rwork, lapack_int* lrwork,
|
||||
lapack_int* iwork, lapack_int* liwork, lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> work,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> iwork,
|
||||
eig::ComputationMode mode);
|
||||
};
|
||||
|
||||
// lapack geev
|
||||
|
||||
template <typename T>
|
||||
@ -405,6 +476,68 @@ struct ComplexGeev {
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
};
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct EigenvalueDecomposition {
|
||||
static_assert(!::xla::ffi::IsComplexType<dtype>(),
|
||||
"There exists a separate implementation for Complex types");
|
||||
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, ValueType* a,
|
||||
lapack_int* lda, ValueType* wr, ValueType* wi,
|
||||
ValueType* vl, lapack_int* ldvl, ValueType* vr,
|
||||
lapack_int* ldvr, ValueType* work, lapack_int* lwork,
|
||||
lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals_real,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals_imag,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> x_work,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_cols,
|
||||
eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right);
|
||||
};
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct EigenvalueDecompositionComplex {
|
||||
static_assert(::xla::ffi::IsComplexType<dtype>());
|
||||
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
|
||||
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, ValueType* a,
|
||||
lapack_int* lda, ValueType* w, ValueType* vl,
|
||||
lapack_int* ldvl, ValueType* vr, lapack_int* ldvr,
|
||||
ValueType* work, lapack_int* lwork, RealType* rwork,
|
||||
lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvecs_left,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvecs_right,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> x_work,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_cols,
|
||||
eig::ComputationMode compute_left,
|
||||
eig::ComputationMode compute_right);
|
||||
};
|
||||
|
||||
//== Schur Decomposition ==//
|
||||
|
||||
// lapack gees
|
||||
@ -500,6 +633,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgesdd_ffi);
|
||||
// Singular value decomposition handlers.
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgesdd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgesdd_ffi);
// Symmetric / Hermitian eigenvalue decomposition handlers.
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssyevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsyevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cheevd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zheevd_ffi);
// General eigenvalue decomposition handlers.
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi);
|
||||
} // namespace jax
|
||||
|
||||
|
@ -56,15 +56,15 @@ jax::SingularValueDecomposition<ffi::DataType::F64>::FnType dgesdd_;
|
||||
jax::SingularValueDecompositionComplex<ffi::DataType::C64>::FnType cgesdd_;
jax::SingularValueDecompositionComplex<ffi::DataType::C128>::FnType zgesdd_;

// Symmetric / Hermitian eigensolvers, declared with the FFI kernels' function
// types. The legacy RealSyevd / ComplexHeevd declarations of the same symbols
// are dropped here; the static_asserts further down in this file require both
// function types to be identical.
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F32>::FnType ssyevd_;
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F64>::FnType dsyevd_;
jax::EigenvalueDecompositionHermitian<ffi::DataType::C64>::FnType cheevd_;
jax::EigenvalueDecompositionHermitian<ffi::DataType::C128>::FnType zheevd_;

// General eigensolvers, likewise declared with the FFI kernels' function
// types in place of the legacy RealGeev / ComplexGeev declarations.
jax::EigenvalueDecomposition<ffi::DataType::F32>::FnType sgeev_;
jax::EigenvalueDecomposition<ffi::DataType::F64>::FnType dgeev_;
jax::EigenvalueDecompositionComplex<ffi::DataType::C64>::FnType cgeev_;
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType zgeev_;

jax::RealGees<float>::FnType sgees_;
jax::RealGees<double>::FnType dgees_;
@ -173,6 +173,44 @@ static_assert(
|
||||
jax::SingularValueDecompositionComplex<ffi::DataType::C128>::FnType,
|
||||
jax::ComplexGesdd<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<
|
||||
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F32>::FnType,
|
||||
jax::RealSyevd<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<
|
||||
jax::EigenvalueDecompositionSymmetric<ffi::DataType::F64>::FnType,
|
||||
jax::RealSyevd<double>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<
|
||||
jax::EigenvalueDecompositionHermitian<ffi::DataType::C64>::FnType,
|
||||
jax::ComplexHeevd<std::complex<float>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<
|
||||
jax::EigenvalueDecompositionHermitian<ffi::DataType::C128>::FnType,
|
||||
jax::ComplexHeevd<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::EigenvalueDecomposition<ffi::DataType::F32>::FnType,
|
||||
jax::RealGeev<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::EigenvalueDecomposition<ffi::DataType::F64>::FnType,
|
||||
jax::RealGeev<double>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<
|
||||
jax::EigenvalueDecompositionComplex<ffi::DataType::C64>::FnType,
|
||||
jax::ComplexGeev<std::complex<float>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<
|
||||
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType,
|
||||
jax::ComplexGeev<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
|
||||
#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG
|
||||
|
||||
@ -266,6 +304,17 @@ static auto init = []() -> int {
|
||||
AssignKernelFn<SingularValueDecompositionComplex<ffi::DataType::C128>>(
|
||||
zgesdd_);
|
||||
|
||||
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F32>>(ssyevd_);
|
||||
AssignKernelFn<EigenvalueDecompositionSymmetric<ffi::DataType::F64>>(dsyevd_);
|
||||
AssignKernelFn<EigenvalueDecompositionHermitian<ffi::DataType::C64>>(cheevd_);
|
||||
AssignKernelFn<EigenvalueDecompositionHermitian<ffi::DataType::C128>>(
|
||||
zheevd_);
|
||||
|
||||
AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F32>>(sgeev_);
|
||||
AssignKernelFn<EigenvalueDecomposition<ffi::DataType::F64>>(dgeev_);
|
||||
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C64>>(cgeev_);
|
||||
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C128>>(zgeev_);
|
||||
|
||||
return 0;
|
||||
}();
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user