Port Hessenberg Decomposition to XLA's FFI

This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 671283487
This commit is contained in:
Paweł Paruzel 2024-09-05 01:58:58 -07:00 committed by jax authors
parent 720dfd7e43
commit 2082662bb1
5 changed files with 136 additions and 9 deletions

View File

@ -149,6 +149,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgehrd_ffi);
#undef JAX_CPU_REGISTER_HANDLER

View File

@ -142,6 +142,10 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<Gehrd<double>>(lapack_ptr("dgehrd"));
AssignKernelFn<Gehrd<std::complex<float>>>(lapack_ptr("cgehrd"));
AssignKernelFn<Gehrd<std::complex<double>>>(lapack_ptr("zgehrd"));
AssignKernelFn<HessenbergDecomposition<DataType::F32>>(lapack_ptr("sgehrd"));
AssignKernelFn<HessenbergDecomposition<DataType::F64>>(lapack_ptr("dgehrd"));
AssignKernelFn<HessenbergDecomposition<DataType::C64>>(lapack_ptr("cgehrd"));
AssignKernelFn<HessenbergDecomposition<DataType::C128>>(lapack_ptr("zgehrd"));
AssignKernelFn<Sytrd<float>>(lapack_ptr("ssytrd"));
AssignKernelFn<Sytrd<double>>(lapack_ptr("dsytrd"));
@ -253,6 +257,10 @@ nb::dict Registrations() {
dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi);
dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi);
dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi);
dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi);
dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi);
dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi);
dict["lapack_zgehrd_ffi"] = EncapsulateFunction(lapack_zgehrd_ffi);
return dict;
}

View File

@ -1627,6 +1627,59 @@ template struct Gehrd<double>;
template struct Gehrd<std::complex<float>>;
template struct Gehrd<std::complex<double>>;
// FFI Kernel
// Runs the LAPACK routine bound to `fn` once per matrix of the batch in `x`.
//
// Parameters:
//   x      Input buffer; SplitBatch2D splits its dimensions into
//          (batch_count, x_rows, x_cols).
//   low    Forwarded unchanged to the routine as its `ilo` argument.
//   high   Forwarded unchanged to the routine as its `ihi` argument.
//   x_out  Result buffer the routine operates on in place.
//          CopyIfDiffBuffer(x, x_out) is applied first (presumably a copy
//          unless both refer to the same memory -- confirm in the helper).
//   tau    Result buffer; the write position advances by `x_cols - 1`
//          elements per matrix.
//   info   Result buffer; one lapack_int status per matrix.
//
// Returns Success once every matrix has been processed. Per-matrix failures
// are reported only through `info`. An error is returned up front when a
// dimension or the workspace size does not fit in lapack_int, or when the
// workspace query itself fails.
template <ffi::DataType dtype>
ffi::Error HessenbergDecomposition<dtype>::Kernel(
    ffi::Buffer<dtype> x, lapack_int low, lapack_int high,
    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<dtype> tau,
    ffi::ResultBuffer<LapackIntDtype> info) {
  FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
                       SplitBatch2D(x.dimensions()));

  CopyIfDiffBuffer(x, x_out);

  ValueType* x_out_data = x_out->typed_data();
  ValueType* tau_data = tau->typed_data();
  lapack_int* info_data = info->typed_data();
  FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
  FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
                       MaybeCastNoOverflow<lapack_int>(x_rows));

  // Prepare LAPACK workspaces. The query is issued with the range-checked
  // lapack_int dimensions so no implicit int64_t -> lapack_int narrowing
  // happens at the call site.
  const int64_t work_size =
      GetWorkspaceSize(x_leading_dim_v, x_cols_v, low, high);
  if (work_size < 0) {
    // GetWorkspaceSize signals a failed query with -1. Using that value as an
    // allocation size or as `lwork` would be invalid, so fail the call.
    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                      "Hessenberg decomposition workspace query failed; "
                      "check that `low` and `high` are valid for the input");
  }
  FFI_ASSIGN_OR_RETURN(auto work_size_v,
                       MaybeCastNoOverflow<lapack_int>(work_size));
  auto work_data = AllocateScratchMemory<dtype>(work_size);

  // Number of elements in one matrix; used to step through the batch.
  const int64_t x_size{x_rows * x_cols};
  for (int64_t i = 0; i < batch_count; ++i) {
    fn(&x_cols_v, &low, &high, x_out_data, &x_leading_dim_v, tau_data,
       work_data.get(), &work_size_v, info_data);
    x_out_data += x_size;
    tau_data += x_cols - 1;
    ++info_data;
  }
  return ffi::Error::Success();
}
// Asks the routine bound to `fn` how much workspace it wants for a matrix
// with the given dimensions and `low`/`high` bounds.
//
// The routine is called with `lwork == -1` and null matrix/tau pointers, so
// it only writes the requested size into the single-element work slot.
// Returns the real part of that size converted to int64_t, or -1 when the
// routine reports a non-zero `info`.
template <ffi::DataType dtype>
int64_t HessenbergDecomposition<dtype>::GetWorkspaceSize(lapack_int x_rows,
                                                         lapack_int x_cols,
                                                         lapack_int low,
                                                         lapack_int high) {
  ValueType reported_size = {};
  lapack_int query_flag = -1;
  lapack_int status = 0;
  fn(&x_cols, &low, &high, /*a=*/nullptr, &x_rows, /*tau=*/nullptr,
     &reported_size, &query_flag, &status);
  if (status != 0) {
    return -1;
  }
  return static_cast<int64_t>(std::real(reported_size));
}
// Explicit instantiations for the four data types that get a gehrd FFI
// handler (F32, F64, C64, C128).
template struct HessenbergDecomposition<ffi::DataType::F32>;
template struct HessenbergDecomposition<ffi::DataType::F64>;
template struct HessenbergDecomposition<ffi::DataType::C64>;
template struct HessenbergDecomposition<ffi::DataType::C128>;
//== Tridiagonal Reduction ==//
// lapack sytrd/hetrd
@ -1811,6 +1864,17 @@ template struct Sytrd<std::complex<double>>;
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
// Defines the FFI handler symbol `name` for
// HessenbergDecomposition<data_type>::Kernel. The binding order must match
// the Kernel parameter list: the input buffer, the "low" and "high"
// attributes, then the x_out, tau and info results.
#define JAX_CPU_DEFINE_GEHRD(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, HessenbergDecomposition<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<lapack_int>("low") \
.Attr<lapack_int>("high") \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
// FFI Handlers
JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
@ -1853,6 +1917,11 @@ JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128);
// Hessenberg decomposition handlers, one per supported data type
// (s = F32, d = F64, c = C64, z = C128).
JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_TRSM
#undef JAX_CPU_DEFINE_GETRF
#undef JAX_CPU_DEFINE_GEQRF
@ -1864,5 +1933,6 @@ JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_HEEVD
#undef JAX_CPU_DEFINE_GEEV
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
#undef JAX_CPU_DEFINE_GEHRD
} // namespace jax

View File

@ -192,8 +192,8 @@ struct QrFactorization {
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
static ::xla::ffi::Error Kernel(::xla::ffi::Buffer<dtype> x,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> tau);
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols);
@ -444,8 +444,7 @@ struct EigenvalueDecompositionHermitian {
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
eig::ComputationMode mode);
::xla::ffi::ResultBuffer<LapackIntDtype> info, eig::ComputationMode mode);
};
// lapack geev
@ -579,6 +578,27 @@ struct real_type<std::complex<T>> {
typedef T type;
};
// FFI Kernel
// FFI kernel wrapper around a gehrd-style LAPACK routine for element type
// `dtype`.
template <::xla::ffi::DataType dtype>
struct HessenbergDecomposition {
// Native C++ element type corresponding to `dtype`.
using ValueType = ::xla::ffi::NativeType<dtype>;
// Signature of the underlying routine. All arguments are passed by pointer.
using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi,
ValueType* a, lapack_int* lda, ValueType* tau,
ValueType* work, lapack_int* lwork, lapack_int* info);
// Pointer to the routine. It starts out null and has to be assigned before
// Kernel or GetWorkspaceSize is called.
inline static FnType* fn = nullptr;
// FFI entry point: takes the input buffer `x` and the `low`/`high` bounds,
// and fills the `x_out`, `tau` and `info` result buffers.
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, lapack_int low, lapack_int high,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> tau,
::xla::ffi::ResultBuffer<LapackIntDtype> info);
// Workspace size to use with `fn` for the given matrix dimensions and
// `low`/`high` bounds.
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols,
lapack_int low, lapack_int high);
};
//== Tridiagonal Reduction ==//
//== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==//
@ -630,6 +650,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi);
// Hessenberg decomposition (gehrd) FFI handler symbols.
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgehrd_ffi);
} // namespace jax

View File

@ -71,10 +71,10 @@ jax::RealGees<double>::FnType dgees_;
jax::ComplexGees<std::complex<float>>::FnType cgees_;
jax::ComplexGees<std::complex<double>>::FnType zgees_;
jax::Gehrd<float>::FnType sgehrd_;
jax::Gehrd<double>::FnType dgehrd_;
jax::Gehrd<std::complex<float>>::FnType cgehrd_;
jax::Gehrd<std::complex<double>>::FnType zgehrd_;
jax::HessenbergDecomposition<ffi::DataType::F32>::FnType sgehrd_;
jax::HessenbergDecomposition<ffi::DataType::F64>::FnType dgehrd_;
jax::HessenbergDecomposition<ffi::DataType::C64>::FnType cgehrd_;
jax::HessenbergDecomposition<ffi::DataType::C128>::FnType zgehrd_;
jax::Sytrd<float>::FnType ssytrd_;
jax::Sytrd<double>::FnType dsytrd_;
@ -211,6 +211,22 @@ static_assert(
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType,
jax::ComplexGeev<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
// Compile-time check that each HessenbergDecomposition FFI kernel has exactly
// the same function type as the corresponding Gehrd kernel.
static_assert(
std::is_same_v<jax::HessenbergDecomposition<ffi::DataType::F32>::FnType,
jax::Gehrd<float>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::HessenbergDecomposition<ffi::DataType::F64>::FnType,
jax::Gehrd<double>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::HessenbergDecomposition<ffi::DataType::C64>::FnType,
jax::Gehrd<std::complex<float>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::HessenbergDecomposition<ffi::DataType::C128>::FnType,
jax::Gehrd<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG
@ -315,6 +331,11 @@ static auto init = []() -> int {
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C64>>(cgeev_);
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C128>>(zgeev_);
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F32>>(sgehrd_);
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F64>>(dgehrd_);
AssignKernelFn<HessenbergDecomposition<ffi::DataType::C64>>(cgehrd_);
AssignKernelFn<HessenbergDecomposition<ffi::DataType::C128>>(zgehrd_);
return 0;
}();