mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Port LU Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 644845277
This commit is contained in:
parent
2ac1cfada9
commit
63aab133f1
@ -118,6 +118,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
|
||||
|
||||
// FFI Kernels
|
||||
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgetrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi);
|
||||
|
@ -55,6 +55,10 @@ void GetLapackKernelsFromScipy() {
|
||||
AssignKernelFn<Getrf<double>>(lapack_ptr("dgetrf"));
|
||||
AssignKernelFn<Getrf<std::complex<float>>>(lapack_ptr("cgetrf"));
|
||||
AssignKernelFn<Getrf<std::complex<double>>>(lapack_ptr("zgetrf"));
|
||||
AssignKernelFn<LuDecomposition<DataType::F32>>(lapack_ptr("sgetrf"));
|
||||
AssignKernelFn<LuDecomposition<DataType::F64>>(lapack_ptr("dgetrf"));
|
||||
AssignKernelFn<LuDecomposition<DataType::C64>>(lapack_ptr("cgetrf"));
|
||||
AssignKernelFn<LuDecomposition<DataType::C128>>(lapack_ptr("zgetrf"));
|
||||
|
||||
AssignKernelFn<Geqrf<float>>(lapack_ptr("sgeqrf"));
|
||||
AssignKernelFn<Geqrf<double>>(lapack_ptr("dgeqrf"));
|
||||
@ -178,6 +182,10 @@ nb::dict Registrations() {
|
||||
dict["lapack_zhetrd"] =
|
||||
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
|
||||
|
||||
dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi);
|
||||
dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi);
|
||||
dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi);
|
||||
dict["lapack_zgetrf_ffi"] = EncapsulateFunction(lapack_zgetrf_ffi);
|
||||
dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi);
|
||||
dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi);
|
||||
dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi);
|
||||
|
@ -23,6 +23,15 @@ namespace jax {
|
||||
|
||||
// FFI Definition Macros (by DataType)
|
||||
|
||||
// Defines an XLA FFI handler called `name` that dispatches to
// LuDecomposition<data_type>::Kernel.
//
// The binding declares, in order:
//   - one argument buffer of `data_type` (x),
//   - one result buffer of `data_type` (x_out),
//   - two result buffers of LapackIntDtype (ipiv, then info).
// This order must match the parameter order of the Kernel being bound.
#define JAX_CPU_DEFINE_GETRF(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER( \
|
||||
name, LuDecomposition<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*ipiv*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
#define JAX_CPU_DEFINE_POTRF(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER( \
|
||||
name, CholeskyFactorization<data_type>::Kernel, \
|
||||
@ -34,11 +43,17 @@ namespace jax {
|
||||
|
||||
// FFI Handlers
|
||||
|
||||
JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
#undef JAX_CPU_DEFINE_GETRF
|
||||
#undef JAX_CPU_DEFINE_POTRF
|
||||
|
||||
} // namespace jax
|
||||
|
@ -194,6 +194,41 @@ template struct Getrf<double>;
|
||||
template struct Getrf<std::complex<float>>;
|
||||
template struct Getrf<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error LuDecomposition<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<LapackIntDtype> ipiv,
|
||||
ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
|
||||
auto* x_out_data = x_out->data;
|
||||
auto* ipiv_data = ipiv->data;
|
||||
auto* info_data = info->data;
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto x_rows_v = CastNoOverflow<lapack_int>(x_rows);
|
||||
auto x_cols_v = CastNoOverflow<lapack_int>(x_cols);
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
const int64_t ipiv_step{std::min(x_rows, x_cols)};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, ipiv_data,
|
||||
info_data);
|
||||
x_out_data += x_out_step;
|
||||
ipiv_data += ipiv_step;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template struct LuDecomposition<ffi::DataType::F32>;
|
||||
template struct LuDecomposition<ffi::DataType::F64>;
|
||||
template struct LuDecomposition<ffi::DataType::C64>;
|
||||
template struct LuDecomposition<ffi::DataType::C128>;
|
||||
|
||||
//== QR Factorization ==//
|
||||
|
||||
// lapack geqrf
|
||||
|
@ -104,6 +104,21 @@ struct Getrf {
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
};
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct LuDecomposition {
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using FnType = void(lapack_int* m, lapack_int* n, ValueType* a,
|
||||
lapack_int* lda, lapack_int* ipiv, lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> ipiv,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info);
|
||||
};
|
||||
|
||||
//== QR Factorization ==//
|
||||
|
||||
// lapack geqrf
|
||||
|
@ -31,10 +31,10 @@ jax::Trsm<double>::FnType dtrsm_;
|
||||
jax::Trsm<std::complex<float>>::FnType ctrsm_;
|
||||
jax::Trsm<std::complex<double>>::FnType ztrsm_;
|
||||
|
||||
jax::Getrf<float>::FnType sgetrf_;
|
||||
jax::Getrf<double>::FnType dgetrf_;
|
||||
jax::Getrf<std::complex<float>>::FnType cgetrf_;
|
||||
jax::Getrf<std::complex<double>>::FnType zgetrf_;
|
||||
jax::LuDecomposition<ffi::DataType::F32>::FnType sgetrf_;
|
||||
jax::LuDecomposition<ffi::DataType::F64>::FnType dgetrf_;
|
||||
jax::LuDecomposition<ffi::DataType::C64>::FnType cgetrf_;
|
||||
jax::LuDecomposition<ffi::DataType::C128>::FnType zgetrf_;
|
||||
|
||||
jax::Geqrf<float>::FnType sgeqrf_;
|
||||
jax::Geqrf<double>::FnType dgeqrf_;
|
||||
@ -87,6 +87,18 @@ namespace jax {
|
||||
|
||||
#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch"
|
||||
|
||||
static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::F32>::FnType,
|
||||
jax::Getrf<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::F64>::FnType,
|
||||
jax::Getrf<double>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::C64>::FnType,
|
||||
jax::Getrf<std::complex<float>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::C128>::FnType,
|
||||
jax::Getrf<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F32>::FnType,
|
||||
jax::Potrf<float>::FnType>,
|
||||
@ -164,6 +176,11 @@ static auto init = []() -> int {
|
||||
|
||||
// FFI Kernels
|
||||
|
||||
AssignKernelFn<LuDecomposition<ffi::DataType::F32>>(sgetrf_);
|
||||
AssignKernelFn<LuDecomposition<ffi::DataType::F64>>(dgetrf_);
|
||||
AssignKernelFn<LuDecomposition<ffi::DataType::C64>>(cgetrf_);
|
||||
AssignKernelFn<LuDecomposition<ffi::DataType::C128>>(zgetrf_);
|
||||
|
||||
AssignKernelFn<CholeskyFactorization<ffi::DataType::F32>>(spotrf_);
|
||||
AssignKernelFn<CholeskyFactorization<ffi::DataType::F64>>(dpotrf_);
|
||||
AssignKernelFn<CholeskyFactorization<ffi::DataType::C64>>(cpotrf_);
|
||||
|
Loading…
x
Reference in New Issue
Block a user