mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Port Tridiagonal Reduction to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 685679646
This commit is contained in:
parent
b6f38bcc4b
commit
ec68d420fe
@ -149,6 +149,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_ssytrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dsytrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_chetrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zhetrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi);
|
||||
|
@ -151,6 +151,10 @@ void GetLapackKernelsFromScipy() {
|
||||
AssignKernelFn<Sytrd<double>>(lapack_ptr("dsytrd"));
|
||||
AssignKernelFn<Sytrd<std::complex<float>>>(lapack_ptr("chetrd"));
|
||||
AssignKernelFn<Sytrd<std::complex<double>>>(lapack_ptr("zhetrd"));
|
||||
AssignKernelFn<TridiagonalReduction<DataType::F32>>(lapack_ptr("ssytrd"));
|
||||
AssignKernelFn<TridiagonalReduction<DataType::F64>>(lapack_ptr("dsytrd"));
|
||||
AssignKernelFn<TridiagonalReduction<DataType::C64>>(lapack_ptr("chetrd"));
|
||||
AssignKernelFn<TridiagonalReduction<DataType::C128>>(lapack_ptr("zhetrd"));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
@ -257,6 +261,10 @@ nb::dict Registrations() {
|
||||
dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi);
|
||||
dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi);
|
||||
dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi);
|
||||
dict["lapack_ssytrd_ffi"] = EncapsulateFunction(lapack_ssytrd_ffi);
|
||||
dict["lapack_dsytrd_ffi"] = EncapsulateFunction(lapack_dsytrd_ffi);
|
||||
dict["lapack_chetrd_ffi"] = EncapsulateFunction(lapack_chetrd_ffi);
|
||||
dict["lapack_zhetrd_ffi"] = EncapsulateFunction(lapack_zhetrd_ffi);
|
||||
dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi);
|
||||
dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi);
|
||||
dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi);
|
||||
|
@ -1741,6 +1741,68 @@ template struct Sytrd<double>;
|
||||
template struct Sytrd<std::complex<float>>;
|
||||
template struct Sytrd<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error TridiagonalReduction<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> diagonal,
|
||||
ffi::ResultBuffer<ffi::ToReal(dtype)> off_diagonal,
|
||||
ffi::ResultBuffer<dtype> tau, ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
ValueType* x_out_data = x_out->typed_data();
|
||||
RealType* diagonal_data = diagonal->typed_data();
|
||||
RealType* off_diagonal_data = off_diagonal->typed_data();
|
||||
ValueType* tau_data = tau->typed_data();
|
||||
lapack_int* info_data = info->typed_data();
|
||||
|
||||
// Prepare LAPACK workspaces.
|
||||
const auto work_size = GetWorkspaceSize(x_rows, x_cols);
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size);
|
||||
|
||||
auto uplo_v = static_cast<char>(uplo);
|
||||
FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto work_size_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_order_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
|
||||
int64_t x_size = x_rows * x_cols;
|
||||
int64_t tau_step = {tau->dimensions().back()};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, diagonal_data,
|
||||
off_diagonal_data, tau_data, work_data.get(), &work_size_v, info_data);
|
||||
x_out_data += x_size;
|
||||
diagonal_data += x_cols;
|
||||
off_diagonal_data += x_cols - 1;
|
||||
tau_data += tau_step;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
int64_t TridiagonalReduction<dtype>::GetWorkspaceSize(lapack_int x_rows,
|
||||
lapack_int x_cols) {
|
||||
ValueType optimal_size = {};
|
||||
lapack_int workspace_query = -1;
|
||||
lapack_int info = 0;
|
||||
char uplo_v = 'L';
|
||||
fn(&uplo_v, &x_cols, nullptr, &x_rows, nullptr, nullptr, nullptr,
|
||||
&optimal_size, &workspace_query, &info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
|
||||
}
|
||||
|
||||
template struct TridiagonalReduction<ffi::DataType::F32>;
|
||||
template struct TridiagonalReduction<ffi::DataType::F64>;
|
||||
template struct TridiagonalReduction<ffi::DataType::C64>;
|
||||
template struct TridiagonalReduction<ffi::DataType::C128>;
|
||||
|
||||
// FFI Definition Macros (by DataType)
|
||||
|
||||
#define JAX_CPU_DEFINE_TRSM(name, data_type) \
|
||||
@ -1864,6 +1926,20 @@ template struct Sytrd<std::complex<double>>;
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*eigvecs_right*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
// Defines the XLA FFI handler symbol `name` and binds it to
// TridiagonalReduction<data_type>::Kernel. The binding takes one input buffer
// (x), the "uplo" attribute, and five result buffers in this order: x_out,
// diagonal, off_diagonal, tau, info. The diagonal and off-diagonal results use
// the real counterpart of `data_type`.
#define JAX_CPU_DEFINE_SYTRD_HETRD(name, data_type)                            \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(                                               \
      name, TridiagonalReduction<data_type>::Kernel,                           \
      ::xla::ffi::Ffi::Bind()                                                  \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)                           \
          .Attr<MatrixParams::UpLo>("uplo")                                    \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)                       \
          .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(             \
              /*diagonal*/)                                                    \
          .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(             \
              /*off_diagonal*/)                                                \
          .Ret<::xla::ffi::Buffer<data_type>>(/*tau*/)                         \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
#define JAX_CPU_DEFINE_GEHRD(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, HessenbergDecomposition<data_type>::Kernel, \
|
||||
@ -1917,6 +1993,11 @@ JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_ssytrd_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_dsytrd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_chetrd_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_SYTRD_HETRD(lapack_zhetrd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64);
|
||||
@ -1933,6 +2014,7 @@ JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128);
|
||||
#undef JAX_CPU_DEFINE_HEEVD
|
||||
#undef JAX_CPU_DEFINE_GEEV
|
||||
#undef JAX_CPU_DEFINE_GEEV_COMPLEX
|
||||
#undef JAX_CPU_DEFINE_SYTRD_HETRD
|
||||
#undef JAX_CPU_DEFINE_GEHRD
|
||||
|
||||
} // namespace jax
|
||||
|
@ -617,6 +617,29 @@ struct Sytrd {
|
||||
static int64_t Workspace(lapack_int lda, lapack_int n);
|
||||
};
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct TridiagonalReduction {
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using RealType = ::xla::ffi::NativeType<::xla::ffi::ToReal(dtype)>;
|
||||
using FnType = void(char* uplo, lapack_int* n, ValueType* a, lapack_int* lda,
|
||||
RealType* d, RealType* e, ValueType* tau, ValueType* work,
|
||||
lapack_int* lwork, lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
|
||||
::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> diagonal,
|
||||
::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> off_diagonal,
|
||||
::xla::ffi::ResultBuffer<dtype> tau,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols);
|
||||
};
|
||||
|
||||
// Declare all the handler symbols
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_strsm_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_dtrsm_ffi);
|
||||
@ -650,6 +673,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ssytrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dsytrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_chetrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zhetrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi);
|
||||
|
@ -76,10 +76,10 @@ jax::HessenbergDecomposition<ffi::DataType::F64>::FnType dgehrd_;
|
||||
jax::HessenbergDecomposition<ffi::DataType::C64>::FnType cgehrd_;
|
||||
jax::HessenbergDecomposition<ffi::DataType::C128>::FnType zgehrd_;
|
||||
|
||||
jax::Sytrd<float>::FnType ssytrd_;
|
||||
jax::Sytrd<double>::FnType dsytrd_;
|
||||
jax::Sytrd<std::complex<float>>::FnType chetrd_;
|
||||
jax::Sytrd<std::complex<double>>::FnType zhetrd_;
|
||||
jax::TridiagonalReduction<ffi::DataType::F32>::FnType ssytrd_;
|
||||
jax::TridiagonalReduction<ffi::DataType::F64>::FnType dsytrd_;
|
||||
jax::TridiagonalReduction<ffi::DataType::C64>::FnType chetrd_;
|
||||
jax::TridiagonalReduction<ffi::DataType::C128>::FnType zhetrd_;
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@ -211,6 +211,22 @@ static_assert(
|
||||
jax::EigenvalueDecompositionComplex<ffi::DataType::C128>::FnType,
|
||||
jax::ComplexGeev<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::TridiagonalReduction<ffi::DataType::F32>::FnType,
|
||||
jax::Sytrd<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::TridiagonalReduction<ffi::DataType::F64>::FnType,
|
||||
jax::Sytrd<double>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::TridiagonalReduction<ffi::DataType::C64>::FnType,
|
||||
jax::Sytrd<std::complex<float>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::TridiagonalReduction<ffi::DataType::C128>::FnType,
|
||||
jax::Sytrd<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::HessenbergDecomposition<ffi::DataType::F32>::FnType,
|
||||
jax::Gehrd<float>::FnType>,
|
||||
@ -331,6 +347,11 @@ static auto init = []() -> int {
|
||||
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C64>>(cgeev_);
|
||||
AssignKernelFn<EigenvalueDecompositionComplex<ffi::DataType::C128>>(zgeev_);
|
||||
|
||||
AssignKernelFn<TridiagonalReduction<ffi::DataType::F32>>(ssytrd_);
|
||||
AssignKernelFn<TridiagonalReduction<ffi::DataType::F64>>(dsytrd_);
|
||||
AssignKernelFn<TridiagonalReduction<ffi::DataType::C64>>(chetrd_);
|
||||
AssignKernelFn<TridiagonalReduction<ffi::DataType::C128>>(zhetrd_);
|
||||
|
||||
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F32>>(sgehrd_);
|
||||
AssignKernelFn<HessenbergDecomposition<ffi::DataType::F64>>(dgehrd_);
|
||||
AssignKernelFn<HessenbergDecomposition<ffi::DataType::C64>>(cgehrd_);
|
||||
|
Loading…
x
Reference in New Issue
Block a user