From 63aab133f1b6943fcf660bc2fc31f0fe8209ee0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Wed, 19 Jun 2024 17:30:50 -0700 Subject: [PATCH] Port LU Decomposition to XLA's FFI This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 644845277 --- jaxlib/cpu/cpu_kernels.cc | 4 +++ jaxlib/cpu/lapack.cc | 8 ++++++ jaxlib/cpu/lapack.h | 15 ++++++++++ jaxlib/cpu/lapack_kernels.cc | 35 +++++++++++++++++++++++ jaxlib/cpu/lapack_kernels.h | 15 ++++++++++ jaxlib/cpu/lapack_kernels_using_lapack.cc | 25 +++++++++++++--- 6 files changed, 98 insertions(+), 4 deletions(-) diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 774d03aa8..0cb9e7cb3 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -118,6 +118,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( // FFI Kernels +JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgetrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi); diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 540e90f9e..d01efa7f7 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -55,6 +55,10 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("dgetrf")); AssignKernelFn>>(lapack_ptr("cgetrf")); AssignKernelFn>>(lapack_ptr("zgetrf")); + AssignKernelFn>(lapack_ptr("sgetrf")); + AssignKernelFn>(lapack_ptr("dgetrf")); + AssignKernelFn>(lapack_ptr("cgetrf")); + AssignKernelFn>(lapack_ptr("zgetrf")); AssignKernelFn>(lapack_ptr("sgeqrf")); AssignKernelFn>(lapack_ptr("dgeqrf")); @@ -178,6 +182,10 @@ nb::dict Registrations() { dict["lapack_zhetrd"] = EncapsulateFunction(Sytrd>::Kernel); + dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi); + dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi); + dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi); + dict["lapack_zgetrf_ffi"] = EncapsulateFunction(lapack_zgetrf_ffi); dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi); dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi); dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi); diff --git a/jaxlib/cpu/lapack.h b/jaxlib/cpu/lapack.h index 0b59c729a..b00440616 100644 --- a/jaxlib/cpu/lapack.h +++ b/jaxlib/cpu/lapack.h @@ -23,6 +23,15 @@ namespace jax { // FFI Definition Macros (by DataType) +#define JAX_CPU_DEFINE_GETRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER( \ + name, LuDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*ipiv*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) + #define JAX_CPU_DEFINE_POTRF(name, data_type) \ XLA_FFI_DEFINE_HANDLER( \ name, CholeskyFactorization::Kernel, \ @@ -34,11 +43,17 @@ namespace jax { // FFI Handlers +JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128); + JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32); JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128); +#undef JAX_CPU_DEFINE_GETRF #undef JAX_CPU_DEFINE_POTRF } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 57d078f21..247d5e963 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -194,6 +194,41 @@ template struct Getrf; template struct Getrf>; template struct Getrf>; +// FFI Kernel + +template +ffi::Error LuDecomposition::Kernel( + ffi::Buffer x, ffi::ResultBuffer x_out, + ffi::ResultBuffer ipiv, + ffi::ResultBuffer info) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + auto* x_out_data = x_out->data; + auto* ipiv_data = ipiv->data; + auto* info_data = info->data; + + CopyIfDiffBuffer(x, x_out); + + auto x_rows_v = CastNoOverflow(x_rows); + auto x_cols_v = CastNoOverflow(x_cols); + auto x_leading_dim_v = x_rows_v; + + const int64_t x_out_step{x_rows * x_cols}; + const int64_t ipiv_step{std::min(x_rows, x_cols)}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, ipiv_data, + info_data); + x_out_data += x_out_step; + ipiv_data += ipiv_step; + ++info_data; + } + return ffi::Error::Success(); +} + +template struct LuDecomposition; +template struct LuDecomposition; +template struct LuDecomposition; +template struct LuDecomposition; + //== QR Factorization ==// // lapack geqrf diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 86bfc8e94..4119f6ba0 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -104,6 +104,21 @@ struct Getrf { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct LuDecomposition { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(lapack_int* m, lapack_int* n, ValueType* a, + lapack_int* lda, lapack_int* ipiv, lapack_int* info); + + inline static FnType* fn = nullptr; + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer ipiv, + ::xla::ffi::ResultBuffer info); +}; + //== QR Factorization ==// // lapack geqrf diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 8d69e93f9..48b1d5bff 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -31,10 +31,10 @@ jax::Trsm::FnType dtrsm_; jax::Trsm>::FnType ctrsm_; jax::Trsm>::FnType ztrsm_; -jax::Getrf::FnType sgetrf_; -jax::Getrf::FnType dgetrf_; -jax::Getrf>::FnType cgetrf_; -jax::Getrf>::FnType zgetrf_; +jax::LuDecomposition::FnType sgetrf_; +jax::LuDecomposition::FnType dgetrf_; +jax::LuDecomposition::FnType cgetrf_; +jax::LuDecomposition::FnType zgetrf_; jax::Geqrf::FnType sgeqrf_; jax::Geqrf::FnType dgeqrf_; @@ -87,6 +87,18 @@ namespace jax { #define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch" +static_assert(std::is_same_v::FnType, + jax::Getrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert(std::is_same_v::FnType, + jax::Getrf>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); static_assert( std::is_same_v::FnType, jax::Potrf::FnType>, @@ -164,6 +176,11 @@ static auto init = []() -> int { // FFI Kernels + AssignKernelFn>(sgetrf_); + AssignKernelFn>(dgetrf_); + AssignKernelFn>(cgetrf_); + AssignKernelFn>(zgetrf_); + AssignKernelFn>(spotrf_); AssignKernelFn>(dpotrf_); AssignKernelFn>(cpotrf_);