Port Schur Decomposition to XLA's FFI

This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685689593
This commit is contained in:
Paweł Paruzel 2024-10-14 06:45:46 -07:00 committed by jax authors
parent ec68d420fe
commit 23fdb91252
7 changed files with 360 additions and 5 deletions

View File

@ -65,6 +65,7 @@ pybind_extension(
module_name = "_lapack",
pytype_srcs = [
"_lapack/__init__.pyi",
"_lapack/schur.pyi",
"_lapack/svd.pyi",
"_lapack/eig.pyi",
],

View File

@ -0,0 +1,26 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from typing import ClassVar
class ComputationMode(enum.Enum):
kComputeSchurVectors: ClassVar[ComputationMode]
kNoComputeSchurVectors: ClassVar[ComputationMode]
class Sort(enum.Enum):
kNoSortEigenvalues: ClassVar[Sort]
kSortEigenvalues: ClassVar[Sort]

View File

@ -153,6 +153,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_ssytrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dsytrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_chetrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zhetrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgees_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi);

View File

@ -137,6 +137,11 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
AssignKernelFn<ComplexGees<std::complex<float>>>(lapack_ptr("cgees"));
AssignKernelFn<ComplexGees<std::complex<double>>>(lapack_ptr("zgees"));
AssignKernelFn<SchurDecomposition<DataType::F32>>(lapack_ptr("sgees"));
AssignKernelFn<SchurDecomposition<DataType::F64>>(lapack_ptr("dgees"));
AssignKernelFn<SchurDecompositionComplex<DataType::C64>>(lapack_ptr("cgees"));
AssignKernelFn<SchurDecompositionComplex<DataType::C128>>(
lapack_ptr("zgees"));
AssignKernelFn<Gehrd<float>>(lapack_ptr("sgehrd"));
AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
@ -265,6 +270,10 @@ nb::dict Registrations() {
dict["lapack_dsytrd_ffi"] = EncapsulateFunction(lapack_dsytrd_ffi);
dict["lapack_chetrd_ffi"] = EncapsulateFunction(lapack_chetrd_ffi);
dict["lapack_zhetrd_ffi"] = EncapsulateFunction(lapack_zhetrd_ffi);
dict["lapack_sgees_ffi"] = EncapsulateFunction(lapack_sgees_ffi);
dict["lapack_dgees_ffi"] = EncapsulateFunction(lapack_dgees_ffi);
dict["lapack_cgees_ffi"] = EncapsulateFunction(lapack_cgees_ffi);
dict["lapack_zgees_ffi"] = EncapsulateFunction(lapack_zgees_ffi);
dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi);
dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi);
dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi);
@ -280,6 +289,7 @@ NB_MODULE(_lapack, m) {
// Submodules
auto svd = m.def_submodule("svd");
auto eig = m.def_submodule("eig");
auto schur = m.def_submodule("schur");
// Enums
nb::enum_<svd::ComputationMode>(svd, "ComputationMode")
// kComputeVtOverwriteXPartialU is not implemented
@ -289,6 +299,14 @@ NB_MODULE(_lapack, m) {
nb::enum_<eig::ComputationMode>(eig, "ComputationMode")
.value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors)
.value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors);
nb::enum_<schur::ComputationMode>(schur, "ComputationMode")
.value("kNoComputeSchurVectors",
schur::ComputationMode::kNoComputeSchurVectors)
.value("kComputeSchurVectors",
schur::ComputationMode::kComputeSchurVectors);
nb::enum_<schur::Sort>(schur, "Sort")
.value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues)
.value("kSortEigenvalues", schur::Sort::kSortEigenvalues);
// Old-style LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),

View File

@ -64,6 +64,8 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort);
#undef REGISTER_CHAR_ENUM_ATTR_DECODING
@ -1573,6 +1575,180 @@ template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>;
// FFI Kernel
template <ffi::DataType dtype>
ffi::Error SchurDecomposition<dtype>::Kernel(
ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
ffi::ResultBuffer<dtype> eigvals_real,
ffi::ResultBuffer<dtype> eigvals_imag,
// TODO(paruzelp): Sort is not implemented because select function is not
// supplied. For that reason, this parameter will always be zero!
ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
ffi::ResultBuffer<LapackIntDtype> info) {
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
if (sort != schur::Sort::kNoSortEigenvalues) {
return ffi::Error(
ffi::ErrorCode::kUnimplemented,
"Ordering eigenvalues on the diagonal is not implemented");
}
CopyIfDiffBuffer(x, x_out);
// TODO(paruzelp): `select` should be passed as an execution context
bool (*select)(ValueType, ValueType) = nullptr;
ValueType* x_out_data = x_out->typed_data();
ValueType* eigvals_real_data = eigvals_real->typed_data();
ValueType* eigvals_imag_data = eigvals_imag->typed_data();
ValueType* schur_vectors_data = schur_vectors->typed_data();
lapack_int* selected_data = selected_eigvals->typed_data();
lapack_int* info_data = info->typed_data();
auto mode_v = static_cast<char>(mode);
auto sort_v = static_cast<char>(sort);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
// Prepare LAPACK workspaces.
std::unique_ptr<bool[]> bwork =
sort != schur::Sort::kNoSortEigenvalues
? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
: nullptr;
auto work_size = GetWorkspaceSize(x_cols, mode, sort);
FFI_ASSIGN_OR_RETURN(auto work_size_v,
MaybeCastNoOverflow<lapack_int>(work_size));
auto work_data = AllocateScratchMemory<dtype>(work_size);
const int64_t x_size{x_cols * x_cols};
[[maybe_unused]] const auto x_size_bytes =
static_cast<unsigned long>(x_size) * sizeof(ValueType);
[[maybe_unused]] const auto x_cols_bytes =
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
for (int64_t i = 0; i < batch_count; ++i) {
fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
selected_data, eigvals_real_data, eigvals_imag_data, schur_vectors_data,
&x_cols_v, work_data.get(), &work_size_v, bwork.get(), info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
x_out_data += x_size;
eigvals_real_data += x_cols;
eigvals_imag_data += x_cols;
schur_vectors_data += x_size;
++selected_data;
++info_data;
}
return ffi::Error::Success();
}
template <ffi::DataType dtype>
ffi::Error SchurDecompositionComplex<dtype>::Kernel(
ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
ffi::ResultBuffer<dtype> eigvals,
// TODO(paruzelp): Sort is not implemented because select function is not
// supplied. For that reason, this parameter will always be zero!
ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
ffi::ResultBuffer<LapackIntDtype> info) {
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
if (sort != schur::Sort::kNoSortEigenvalues) {
return ffi::Error(
ffi::ErrorCode::kUnimplemented,
"Ordering eigenvalues on the diagonal is not implemented");
}
CopyIfDiffBuffer(x, x_out);
// TODO(paruzelp): `select` should be passed as an execution context
bool (*select)(ValueType) = nullptr;
ValueType* x_out_data = x_out->typed_data();
ValueType* eigvals_data = eigvals->typed_data();
ValueType* schur_vectors_data = schur_vectors->typed_data();
lapack_int* selected_data = selected_eigvals->typed_data();
lapack_int* info_data = info->typed_data();
auto mode_v = static_cast<char>(mode);
auto sort_v = static_cast<char>(sort);
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
// Prepare LAPACK workspaces.
std::unique_ptr<bool[]> bwork =
sort != schur::Sort::kNoSortEigenvalues
? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
: nullptr;
auto work_size = GetWorkspaceSize(x_cols, mode, sort);
FFI_ASSIGN_OR_RETURN(auto work_size_v,
MaybeCastNoOverflow<lapack_int>(work_size));
auto work_data = AllocateScratchMemory<dtype>(work_size);
auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(x_cols);
const int64_t x_size{x_cols * x_cols};
[[maybe_unused]] const auto x_size_bytes =
static_cast<unsigned long>(x_size) * sizeof(ValueType);
[[maybe_unused]] const auto x_cols_bytes =
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
for (int64_t i = 0; i < batch_count; ++i) {
fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
selected_data, eigvals_data, schur_vectors_data, &x_cols_v,
work_data.get(), &work_size_v, rwork_data.get(), bwork.get(), info_data);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));
x_out_data += x_size;
eigvals_data += x_cols;
schur_vectors_data += x_size;
++selected_data;
++info_data;
}
return ffi::Error::Success();
}
template <ffi::DataType dtype>
int64_t SchurDecomposition<dtype>::GetWorkspaceSize(lapack_int x_cols,
schur::ComputationMode mode,
schur::Sort sort) {
ValueType optimal_size = {};
lapack_int workspace_query = -1;
lapack_int info = 0;
auto mode_v = static_cast<char>(mode);
auto sort_v = static_cast<char>(sort);
fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
nullptr, nullptr, &x_cols, &optimal_size, &workspace_query, nullptr,
&info);
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
};
template <ffi::DataType dtype>
int64_t SchurDecompositionComplex<dtype>::GetWorkspaceSize(
lapack_int x_cols, schur::ComputationMode mode, schur::Sort sort) {
ValueType optimal_size = {};
lapack_int workspace_query = -1;
lapack_int info = 0;
auto mode_v = static_cast<char>(mode);
auto sort_v = static_cast<char>(sort);
fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
nullptr, &x_cols, &optimal_size, &workspace_query, nullptr, nullptr,
&info);
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
};
template struct SchurDecomposition<ffi::DataType::F32>;
template struct SchurDecomposition<ffi::DataType::F64>;
template struct SchurDecompositionComplex<ffi::DataType::C64>;
template struct SchurDecompositionComplex<ffi::DataType::C128>;
//== Hessenberg Decomposition ==//
// lapack gehrd
@ -1926,6 +2102,33 @@ template struct TridiagonalReduction<ffi::DataType::C128>;
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
#define JAX_CPU_DEFINE_GEES(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, SchurDecomposition<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<schur::ComputationMode>("mode") \
.Attr<schur::Sort>("sort") \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_real*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_imag*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
#define JAX_CPU_DEFINE_GEES_COMPLEX(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, SchurDecompositionComplex<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<schur::ComputationMode>("mode") \
.Attr<schur::Sort>("sort") \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
#define JAX_CPU_DEFINE_SYTRD_HETRD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, TridiagonalReduction<data_type>::Kernel, \
@ -1998,6 +2201,11 @@ JAX_CPU_DEFINE_SYTRD_HETRD(lapack_dsytrd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_chetrd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_zhetrd_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_GEES(lapack_sgees_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEES(lapack_dgees_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_cgees_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_zgees_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64);
@ -2015,6 +2223,8 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_GEEV
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
#undef JAX_CPU_DEFINE_SYTRD_HETRD
#undef JAX_CPU_DEFINE_GEES
#undef JAX_CPU_DEFINE_GEES_COMPLEX
#undef JAX_CPU_DEFINE_GEHRD
} // namespace jax

View File

@ -67,7 +67,18 @@ enum class ComputationMode : char {
kComputeEigenvectors = 'V',
};
}
} // namespace eig
namespace schur {
enum class ComputationMode : char {
kNoComputeSchurVectors = 'N',
kComputeSchurVectors = 'V',
};
enum class Sort : char { kNoSortEigenvalues = 'N', kSortEigenvalues = 'S' };
} // namespace schur
template <typename KernelType>
void AssignKernelFn(void* func) {
@ -96,6 +107,8 @@ DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort);
#undef DEFINE_CHAR_ENUM_ATTR_DECODING
@ -551,6 +564,64 @@ struct ComplexGees {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
// FFI Kernel
template <::xla::ffi::DataType dtype>
struct SchurDecomposition {
static_assert(!::xla::ffi::IsComplexType<dtype>(),
"There exists a separate implementation for Complex types");
using ValueType = ::xla::ffi::NativeType<dtype>;
using FnType = void(char* jobvs, char* sort,
bool (*select)(ValueType, ValueType), lapack_int* n,
ValueType* a, lapack_int* lda, lapack_int* sdim,
ValueType* wr, ValueType* wi, ValueType* vs,
lapack_int* ldvs, ValueType* work, lapack_int* lwork,
bool* bwork, lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, schur::ComputationMode mode,
schur::Sort sort, ::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> schur_vectors,
::xla::ffi::ResultBuffer<dtype> eigvals_real,
::xla::ffi::ResultBuffer<dtype> eigvals_imag,
::xla::ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
::xla::ffi::ResultBuffer<LapackIntDtype> info);
static int64_t GetWorkspaceSize(lapack_int x_cols,
schur::ComputationMode mode,
schur::Sort sort);
};
template <::xla::ffi::DataType dtype>
struct SchurDecompositionComplex {
static_assert(::xla::ffi::IsComplexType<dtype>());
using ValueType = ::xla::ffi::NativeType<dtype>;
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
using FnType = void(char* jobvs, char* sort, bool (*select)(ValueType),
lapack_int* n, ValueType* a, lapack_int* lda,
lapack_int* sdim, ValueType* w, ValueType* vs,
lapack_int* ldvs, ValueType* work, lapack_int* lwork,
RealType* rwork, bool* bwork, lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, schur::ComputationMode mode,
schur::Sort sort, ::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> schur_vectors,
::xla::ffi::ResultBuffer<dtype> eigvals,
::xla::ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
::xla::ffi::ResultBuffer<LapackIntDtype> info);
static int64_t GetWorkspaceSize(lapack_int x_cols,
schur::ComputationMode mode,
schur::Sort sort);
};
//== Hessenberg Decomposition ==//
//== Reduces a non-symmetric square matrix to upper Hessenberg form ==//
@ -677,6 +748,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssytrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsytrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_chetrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zhetrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgees_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgees_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgees_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgees_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi);

View File

@ -66,10 +66,10 @@ jax::EigenvalueDecomposition<ffi::DataType::F64>::FnType dgeev_;
jax::EigenvalueDecompositionComplex<ffi::DataType::C64>::FnType cgeev_;
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType zgeev_;
jax::RealGees<float>::FnType sgees_;
jax::RealGees<double>::FnType dgees_;
jax::ComplexGees<std::complex<float>>::FnType cgees_;
jax::ComplexGees<std::complex<double>>::FnType zgees_;
jax::SchurDecomposition<ffi::DataType::F32>::FnType sgees_;
jax::SchurDecomposition<ffi::DataType::F64>::FnType dgees_;
jax::SchurDecompositionComplex<ffi::DataType::C64>::FnType cgees_;
jax::SchurDecompositionComplex<ffi::DataType::C128>::FnType zgees_;
jax::HessenbergDecomposition<ffi::DataType::F32>::FnType sgehrd_;
jax::HessenbergDecomposition<ffi::DataType::F64>::FnType dgehrd_;
@ -227,6 +227,22 @@ static_assert(
std::is_same_v<jax::TridiagonalReduction<ffi::DataType::C128>::FnType,
jax::Sytrd<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::SchurDecomposition<ffi::DataType::F32>::FnType,
jax::RealGees<float>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::SchurDecomposition<ffi::DataType::F64>::FnType,
jax::RealGees<double>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::SchurDecompositionComplex<ffi::DataType::C64>::FnType,
jax::ComplexGees<std::complex<float>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::SchurDecompositionComplex<ffi::DataType::C128>::FnType,
jax::ComplexGees<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::HessenbergDecomposition<ffi::DataType::F32>::FnType,
jax::Gehrd<float>::FnType>,
@ -352,6 +368,11 @@ static auto init = []() -> int {
AssignKernelFn<TridiagonalReduction<ffi::DataType::C64>>(chetrd_);
AssignKernelFn<TridiagonalReduction<ffi::DataType::C128>>(zhetrd_);
AssignKernelFn<SchurDecomposition<ffi::DataType::F32>>(sgees_);
AssignKernelFn<SchurDecomposition<ffi::DataType::F64>>(dgees_);
AssignKernelFn<SchurDecompositionComplex<ffi::DataType::C64>>(cgees_);
AssignKernelFn<SchurDecompositionComplex<ffi::DataType::C128>>(zgees_);
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F32>>(sgehrd_);
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F64>>(dgehrd_);
AssignKernelFn<HessenbergDecomposition<ffi::DataType::C64>>(cgehrd_);