mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Port Schur Decomposition to XLA's FFI
This CL contains only the C++ changes. The Python lowering code will be added once the three-week forward-compatibility window has passed.

PiperOrigin-RevId: 685689593
This commit is contained in:
parent
ec68d420fe
commit
23fdb91252
@ -65,6 +65,7 @@ pybind_extension(
|
||||
module_name = "_lapack",
|
||||
pytype_srcs = [
|
||||
"_lapack/__init__.pyi",
|
||||
"_lapack/schur.pyi",
|
||||
"_lapack/svd.pyi",
|
||||
"_lapack/eig.pyi",
|
||||
],
|
||||
|
26
jaxlib/cpu/_lapack/schur.pyi
Normal file
26
jaxlib/cpu/_lapack/schur.pyi
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class ComputationMode(enum.Enum):
    # Stub for the `ComputationMode` enum that the native `_lapack.schur`
    # submodule registers; the member names mirror the values bound there.
    kComputeSchurVectors: ClassVar["ComputationMode"]
    kNoComputeSchurVectors: ClassVar["ComputationMode"]
|
||||
|
||||
|
||||
class Sort(enum.Enum):
    # Stub for the `Sort` enum that the native `_lapack.schur` submodule
    # registers; the member names mirror the values bound there.
    kNoSortEigenvalues: ClassVar["Sort"]
    kSortEigenvalues: ClassVar["Sort"]
|
@ -153,6 +153,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_ssytrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dsytrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_chetrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zhetrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sgees_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgees_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgees_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgees_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi);
|
||||
|
@ -137,6 +137,11 @@ void GetLapackKernelsFromScipy() {
|
||||
AssignKernelFn<RealGees<double>>(lapack_ptr("dgees"));
|
||||
AssignKernelFn<ComplexGees<std::complex<float>>>(lapack_ptr("cgees"));
|
||||
AssignKernelFn<ComplexGees<std::complex<double>>>(lapack_ptr("zgees"));
|
||||
AssignKernelFn<SchurDecomposition<DataType::F32>>(lapack_ptr("sgees"));
|
||||
AssignKernelFn<SchurDecomposition<DataType::F64>>(lapack_ptr("dgees"));
|
||||
AssignKernelFn<SchurDecompositionComplex<DataType::C64>>(lapack_ptr("cgees"));
|
||||
AssignKernelFn<SchurDecompositionComplex<DataType::C128>>(
|
||||
lapack_ptr("zgees"));
|
||||
|
||||
AssignKernelFn<Gehrd<float>>(lapack_ptr("sgehrd"));
|
||||
AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
|
||||
@ -265,6 +270,10 @@ nb::dict Registrations() {
|
||||
dict["lapack_dsytrd_ffi"] = EncapsulateFunction(lapack_dsytrd_ffi);
|
||||
dict["lapack_chetrd_ffi"] = EncapsulateFunction(lapack_chetrd_ffi);
|
||||
dict["lapack_zhetrd_ffi"] = EncapsulateFunction(lapack_zhetrd_ffi);
|
||||
dict["lapack_sgees_ffi"] = EncapsulateFunction(lapack_sgees_ffi);
|
||||
dict["lapack_dgees_ffi"] = EncapsulateFunction(lapack_dgees_ffi);
|
||||
dict["lapack_cgees_ffi"] = EncapsulateFunction(lapack_cgees_ffi);
|
||||
dict["lapack_zgees_ffi"] = EncapsulateFunction(lapack_zgees_ffi);
|
||||
dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi);
|
||||
dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi);
|
||||
dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi);
|
||||
@ -280,6 +289,7 @@ NB_MODULE(_lapack, m) {
|
||||
// Submodules
|
||||
auto svd = m.def_submodule("svd");
|
||||
auto eig = m.def_submodule("eig");
|
||||
auto schur = m.def_submodule("schur");
|
||||
// Enums
|
||||
nb::enum_<svd::ComputationMode>(svd, "ComputationMode")
|
||||
// kComputeVtOverwriteXPartialU is not implemented
|
||||
@ -289,6 +299,14 @@ NB_MODULE(_lapack, m) {
|
||||
nb::enum_<eig::ComputationMode>(eig, "ComputationMode")
|
||||
.value("kComputeEigenvectors", eig::ComputationMode::kComputeEigenvectors)
|
||||
.value("kNoEigenvectors", eig::ComputationMode::kNoEigenvectors);
|
||||
nb::enum_<schur::ComputationMode>(schur, "ComputationMode")
|
||||
.value("kNoComputeSchurVectors",
|
||||
schur::ComputationMode::kNoComputeSchurVectors)
|
||||
.value("kComputeSchurVectors",
|
||||
schur::ComputationMode::kComputeSchurVectors);
|
||||
nb::enum_<schur::Sort>(schur, "Sort")
|
||||
.value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues)
|
||||
.value("kSortEigenvalues", schur::Sort::kSortEigenvalues);
|
||||
|
||||
// Old-style LAPACK Workspace Size Queries
|
||||
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
|
||||
|
@ -64,6 +64,8 @@ REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode);
|
||||
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort);
|
||||
|
||||
#undef REGISTER_CHAR_ENUM_ATTR_DECODING
|
||||
|
||||
@ -1573,6 +1575,180 @@ template struct RealGees<double>;
|
||||
template struct ComplexGees<std::complex<float>>;
|
||||
template struct ComplexGees<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error SchurDecomposition<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
|
||||
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
|
||||
ffi::ResultBuffer<dtype> eigvals_real,
|
||||
ffi::ResultBuffer<dtype> eigvals_imag,
|
||||
// TODO(paruzelp): Sort is not implemented because select function is not
|
||||
// supplied. For that reason, this parameter will always be zero!
|
||||
ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
|
||||
ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
if (sort != schur::Sort::kNoSortEigenvalues) {
|
||||
return ffi::Error(
|
||||
ffi::ErrorCode::kUnimplemented,
|
||||
"Ordering eigenvalues on the diagonal is not implemented");
|
||||
}
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
// TODO(paruzelp): `select` should be passed as an execution context
|
||||
bool (*select)(ValueType, ValueType) = nullptr;
|
||||
ValueType* x_out_data = x_out->typed_data();
|
||||
ValueType* eigvals_real_data = eigvals_real->typed_data();
|
||||
ValueType* eigvals_imag_data = eigvals_imag->typed_data();
|
||||
ValueType* schur_vectors_data = schur_vectors->typed_data();
|
||||
lapack_int* selected_data = selected_eigvals->typed_data();
|
||||
lapack_int* info_data = info->typed_data();
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto sort_v = static_cast<char>(sort);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
// Prepare LAPACK workspaces.
|
||||
std::unique_ptr<bool[]> bwork =
|
||||
sort != schur::Sort::kNoSortEigenvalues
|
||||
? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
|
||||
: nullptr;
|
||||
auto work_size = GetWorkspaceSize(x_cols, mode, sort);
|
||||
FFI_ASSIGN_OR_RETURN(auto work_size_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size);
|
||||
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
[[maybe_unused]] const auto x_size_bytes =
|
||||
static_cast<unsigned long>(x_size) * sizeof(ValueType);
|
||||
[[maybe_unused]] const auto x_cols_bytes =
|
||||
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
|
||||
selected_data, eigvals_real_data, eigvals_imag_data, schur_vectors_data,
|
||||
&x_cols_v, work_data.get(), &work_size_v, bwork.get(), info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
|
||||
|
||||
x_out_data += x_size;
|
||||
eigvals_real_data += x_cols;
|
||||
eigvals_imag_data += x_cols;
|
||||
schur_vectors_data += x_size;
|
||||
++selected_data;
|
||||
++info_data;
|
||||
}
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error SchurDecompositionComplex<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, schur::ComputationMode mode, schur::Sort sort,
|
||||
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> schur_vectors,
|
||||
ffi::ResultBuffer<dtype> eigvals,
|
||||
// TODO(paruzelp): Sort is not implemented because select function is not
|
||||
// supplied. For that reason, this parameter will always be zero!
|
||||
ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
|
||||
ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
if (sort != schur::Sort::kNoSortEigenvalues) {
|
||||
return ffi::Error(
|
||||
ffi::ErrorCode::kUnimplemented,
|
||||
"Ordering eigenvalues on the diagonal is not implemented");
|
||||
}
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
// TODO(paruzelp): `select` should be passed as an execution context
|
||||
bool (*select)(ValueType) = nullptr;
|
||||
ValueType* x_out_data = x_out->typed_data();
|
||||
ValueType* eigvals_data = eigvals->typed_data();
|
||||
ValueType* schur_vectors_data = schur_vectors->typed_data();
|
||||
lapack_int* selected_data = selected_eigvals->typed_data();
|
||||
lapack_int* info_data = info->typed_data();
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto sort_v = static_cast<char>(sort);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
// Prepare LAPACK workspaces.
|
||||
std::unique_ptr<bool[]> bwork =
|
||||
sort != schur::Sort::kNoSortEigenvalues
|
||||
? AllocateScratchMemory<ffi::DataType::PRED>(x_cols)
|
||||
: nullptr;
|
||||
auto work_size = GetWorkspaceSize(x_cols, mode, sort);
|
||||
FFI_ASSIGN_OR_RETURN(auto work_size_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size);
|
||||
auto rwork_data = AllocateScratchMemory<ffi::ToReal(dtype)>(x_cols);
|
||||
|
||||
const int64_t x_size{x_cols * x_cols};
|
||||
[[maybe_unused]] const auto x_size_bytes =
|
||||
static_cast<unsigned long>(x_size) * sizeof(ValueType);
|
||||
[[maybe_unused]] const auto x_cols_bytes =
|
||||
static_cast<unsigned long>(x_cols) * sizeof(ValueType);
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&mode_v, &sort_v, select, &x_cols_v, x_out_data, &x_cols_v,
|
||||
selected_data, eigvals_data, schur_vectors_data, &x_cols_v,
|
||||
work_data.get(), &work_size_v, rwork_data.get(), bwork.get(), info_data);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(schur_vectors_data, x_size_bytes);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int));
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(selected_data, sizeof(lapack_int));
|
||||
|
||||
x_out_data += x_size;
|
||||
eigvals_data += x_cols;
|
||||
schur_vectors_data += x_size;
|
||||
++selected_data;
|
||||
++info_data;
|
||||
}
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
int64_t SchurDecomposition<dtype>::GetWorkspaceSize(lapack_int x_cols,
|
||||
schur::ComputationMode mode,
|
||||
schur::Sort sort) {
|
||||
ValueType optimal_size = {};
|
||||
lapack_int workspace_query = -1;
|
||||
lapack_int info = 0;
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto sort_v = static_cast<char>(sort);
|
||||
fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
|
||||
nullptr, nullptr, &x_cols, &optimal_size, &workspace_query, nullptr,
|
||||
&info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
|
||||
};
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
int64_t SchurDecompositionComplex<dtype>::GetWorkspaceSize(
|
||||
lapack_int x_cols, schur::ComputationMode mode, schur::Sort sort) {
|
||||
ValueType optimal_size = {};
|
||||
lapack_int workspace_query = -1;
|
||||
lapack_int info = 0;
|
||||
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto sort_v = static_cast<char>(sort);
|
||||
fn(&mode_v, &sort_v, nullptr, &x_cols, nullptr, &x_cols, nullptr, nullptr,
|
||||
nullptr, &x_cols, &optimal_size, &workspace_query, nullptr, nullptr,
|
||||
&info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
|
||||
};
|
||||
|
||||
template struct SchurDecomposition<ffi::DataType::F32>;
|
||||
template struct SchurDecomposition<ffi::DataType::F64>;
|
||||
template struct SchurDecompositionComplex<ffi::DataType::C64>;
|
||||
template struct SchurDecompositionComplex<ffi::DataType::C128>;
|
||||
|
||||
//== Hessenberg Decomposition ==//
|
||||
|
||||
// lapack gehrd
|
||||
@ -1926,6 +2102,33 @@ template struct TridiagonalReduction<ffi::DataType::C128>;
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
// Defines the FFI handler symbol `name` for
// SchurDecomposition<data_type>::Kernel. The binding order below (one
// argument, two attributes, six results) must match the kernel's parameter
// list.
#define JAX_CPU_DEFINE_GEES(name, data_type)                                 \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(                                             \
      name, SchurDecomposition<data_type>::Kernel,                           \
      ::xla::ffi::Ffi::Bind()                                                \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)                         \
          .Attr<schur::ComputationMode>("mode")                              \
          .Attr<schur::Sort>("sort")                                         \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)                     \
          .Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/)             \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_real*/)              \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals_imag*/)              \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/)     \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
// Defines the FFI handler symbol `name` for
// SchurDecompositionComplex<data_type>::Kernel. The binding order below (one
// argument, two attributes, five results) must match the kernel's parameter
// list.
#define JAX_CPU_DEFINE_GEES_COMPLEX(name, data_type)                         \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(                                             \
      name, SchurDecompositionComplex<data_type>::Kernel,                    \
      ::xla::ffi::Ffi::Bind()                                                \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)                         \
          .Attr<schur::ComputationMode>("mode")                              \
          .Attr<schur::Sort>("sort")                                         \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)                     \
          .Ret<::xla::ffi::Buffer<data_type>>(/*schur_vectors*/)             \
          .Ret<::xla::ffi::Buffer<data_type>>(/*eigvals*/)                   \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*selected_eigvals*/)     \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
#define JAX_CPU_DEFINE_SYTRD_HETRD(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, TridiagonalReduction<data_type>::Kernel, \
|
||||
@ -1998,6 +2201,11 @@ JAX_CPU_DEFINE_SYTRD_HETRD(lapack_dsytrd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_chetrd_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_zhetrd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_GEES(lapack_sgees_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_GEES(lapack_dgees_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_cgees_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GEES_COMPLEX(lapack_zgees_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64);
|
||||
@ -2015,6 +2223,8 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128);
|
||||
#undef JAX_CPU_DEFINE_GEEV
|
||||
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
|
||||
#undef JAX_CPU_DEFINE_SYTRD_HETRD
|
||||
#undef JAX_CPU_DEFINE_GEES
|
||||
#undef JAX_CPU_DEFINE_GEES_COMPLEX
|
||||
#undef JAX_CPU_DEFINE_GEHRD
|
||||
|
||||
} // namespace jax
|
||||
|
@ -67,7 +67,18 @@ enum class ComputationMode : char {
|
||||
kComputeEigenvectors = 'V',
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace eig
|
||||
|
||||
namespace schur {

// Whether Schur vectors are requested. The underlying character is what the
// kernels hand to the LAPACK routine as its `jobvs` argument.
enum class ComputationMode : char {
  kNoComputeSchurVectors = 'N',
  kComputeSchurVectors = 'V',
};

// Whether eigenvalue ordering is requested. The underlying character is what
// the kernels hand to the LAPACK routine as its `sort` argument.
enum class Sort : char {
  kNoSortEigenvalues = 'N',
  kSortEigenvalues = 'S',
};

}  // namespace schur
|
||||
|
||||
template <typename KernelType>
|
||||
void AssignKernelFn(void* func) {
|
||||
@ -96,6 +107,8 @@ DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode);
|
||||
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort);
|
||||
|
||||
#undef DEFINE_CHAR_ENUM_ATTR_DECODING
|
||||
|
||||
@ -551,6 +564,64 @@ struct ComplexGees {
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
};
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct SchurDecomposition {
|
||||
static_assert(!::xla::ffi::IsComplexType<dtype>(),
|
||||
"There exists a separate implementation for Complex types");
|
||||
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using FnType = void(char* jobvs, char* sort,
|
||||
bool (*select)(ValueType, ValueType), lapack_int* n,
|
||||
ValueType* a, lapack_int* lda, lapack_int* sdim,
|
||||
ValueType* wr, ValueType* wi, ValueType* vs,
|
||||
lapack_int* ldvs, ValueType* work, lapack_int* lwork,
|
||||
bool* bwork, lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, schur::ComputationMode mode,
|
||||
schur::Sort sort, ::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<dtype> schur_vectors,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals_real,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals_imag,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_cols,
|
||||
schur::ComputationMode mode,
|
||||
schur::Sort sort);
|
||||
};
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct SchurDecompositionComplex {
|
||||
static_assert(::xla::ffi::IsComplexType<dtype>());
|
||||
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
|
||||
using FnType = void(char* jobvs, char* sort, bool (*select)(ValueType),
|
||||
lapack_int* n, ValueType* a, lapack_int* lda,
|
||||
lapack_int* sdim, ValueType* w, ValueType* vs,
|
||||
lapack_int* ldvs, ValueType* work, lapack_int* lwork,
|
||||
RealType* rwork, bool* bwork, lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, schur::ComputationMode mode,
|
||||
schur::Sort sort, ::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<dtype> schur_vectors,
|
||||
::xla::ffi::ResultBuffer<dtype> eigvals,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> selected_eigvals,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_cols,
|
||||
schur::ComputationMode mode,
|
||||
schur::Sort sort);
|
||||
};
|
||||
|
||||
//== Hessenberg Decomposition ==//
|
||||
//== Reduces a non-symmetric square matrix to upper Hessenberg form ==//
|
||||
|
||||
@ -677,6 +748,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssytrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsytrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_chetrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zhetrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgees_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgees_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgees_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgees_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi);
|
||||
|
@ -66,10 +66,10 @@ jax::EigenvalueDecomposition<ffi::DataType::F64>::FnType dgeev_;
|
||||
jax::EigenvalueDecompositionComplex<ffi::DataType::C64>::FnType cgeev_;
|
||||
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType zgeev_;
|
||||
|
||||
jax::RealGees<float>::FnType sgees_;
|
||||
jax::RealGees<double>::FnType dgees_;
|
||||
jax::ComplexGees<std::complex<float>>::FnType cgees_;
|
||||
jax::ComplexGees<std::complex<double>>::FnType zgees_;
|
||||
jax::SchurDecomposition<ffi::DataType::F32>::FnType sgees_;
|
||||
jax::SchurDecomposition<ffi::DataType::F64>::FnType dgees_;
|
||||
jax::SchurDecompositionComplex<ffi::DataType::C64>::FnType cgees_;
|
||||
jax::SchurDecompositionComplex<ffi::DataType::C128>::FnType zgees_;
|
||||
|
||||
jax::HessenbergDecomposition<ffi::DataType::F32>::FnType sgehrd_;
|
||||
jax::HessenbergDecomposition<ffi::DataType::F64>::FnType dgehrd_;
|
||||
@ -227,6 +227,22 @@ static_assert(
|
||||
std::is_same_v<jax::TridiagonalReduction<ffi::DataType::C128>::FnType,
|
||||
jax::Sytrd<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::SchurDecomposition<ffi::DataType::F32>::FnType,
|
||||
jax::RealGees<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::SchurDecomposition<ffi::DataType::F64>::FnType,
|
||||
jax::RealGees<double>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::SchurDecompositionComplex<ffi::DataType::C64>::FnType,
|
||||
jax::ComplexGees<std::complex<float>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::SchurDecompositionComplex<ffi::DataType::C128>::FnType,
|
||||
jax::ComplexGees<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::HessenbergDecomposition<ffi::DataType::F32>::FnType,
|
||||
jax::Gehrd<float>::FnType>,
|
||||
@ -352,6 +368,11 @@ static auto init = []() -> int {
|
||||
AssignKernelFn<TridiagonalReduction<ffi::DataType::C64>>(chetrd_);
|
||||
AssignKernelFn<TridiagonalReduction<ffi::DataType::C128>>(zhetrd_);
|
||||
|
||||
AssignKernelFn<SchurDecomposition<ffi::DataType::F32>>(sgees_);
|
||||
AssignKernelFn<SchurDecomposition<ffi::DataType::F64>>(dgees_);
|
||||
AssignKernelFn<SchurDecompositionComplex<ffi::DataType::C64>>(cgees_);
|
||||
AssignKernelFn<SchurDecompositionComplex<ffi::DataType::C128>>(zgees_);
|
||||
|
||||
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F32>>(sgehrd_);
|
||||
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F64>>(dgehrd_);
|
||||
AssignKernelFn<HessenbergDecomposition<ffi::DataType::C64>>(cgehrd_);
|
||||
|
Loading…
x
Reference in New Issue
Block a user