From 54fe6e68a0cc1c0cb42a6910a8024d7db9dcb05e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Wed, 24 Jul 2024 02:14:40 -0700 Subject: [PATCH] Port Triangular Solve to XLA's FFI This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 655484166 --- jaxlib/cpu/cpu_kernels.cc | 4 +++ jaxlib/cpu/lapack.cc | 8 +++++ jaxlib/cpu/lapack.h | 18 ++++++++++ jaxlib/cpu/lapack_kernels.cc | 42 +++++++++++++++++++++++ jaxlib/cpu/lapack_kernels.h | 18 ++++++++++ jaxlib/cpu/lapack_kernels_using_lapack.cc | 29 +++++++++++++--- 6 files changed, 115 insertions(+), 4 deletions(-) diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 63bd7ac07..c7bc9dc4b 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -118,6 +118,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( // FFI Kernels +JAX_CPU_REGISTER_HANDLER(blas_strsm_ffi); +JAX_CPU_REGISTER_HANDLER(blas_dtrsm_ffi); +JAX_CPU_REGISTER_HANDLER(blas_ctrsm_ffi); +JAX_CPU_REGISTER_HANDLER(blas_ztrsm_ffi); JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi); diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 602e4dd4e..0aa924486 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -69,6 +69,10 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(blas_ptr("dtrsm")); AssignKernelFn>>(blas_ptr("ctrsm")); AssignKernelFn>>(blas_ptr("ztrsm")); + AssignKernelFn>(blas_ptr("strsm")); + AssignKernelFn>(blas_ptr("dtrsm")); + AssignKernelFn>(blas_ptr("ctrsm")); + AssignKernelFn>(blas_ptr("ztrsm")); nb::module_ cython_lapack = nb::module_::import_("scipy.linalg.cython_lapack"); @@ -219,6 +223,10 @@ nb::dict Registrations() { dict["lapack_zhetrd"] = EncapsulateFunction(Sytrd>::Kernel); + dict["blas_strsm_ffi"] = EncapsulateFunction(blas_strsm_ffi); + dict["blas_dtrsm_ffi"] = EncapsulateFunction(blas_dtrsm_ffi); + dict["blas_ctrsm_ffi"] = EncapsulateFunction(blas_ctrsm_ffi); + dict["blas_ztrsm_ffi"] = EncapsulateFunction(blas_ztrsm_ffi); dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi); dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi); dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi); diff --git a/jaxlib/cpu/lapack.h b/jaxlib/cpu/lapack.h index 828a147d7..81a5acce5 100644 --- a/jaxlib/cpu/lapack.h +++ b/jaxlib/cpu/lapack.h @@ -23,6 +23,18 @@ namespace jax { // FFI Definition Macros (by DataType) +#define JAX_CPU_DEFINE_TRSM(name, data_type) \ + XLA_FFI_DEFINE_HANDLER(name, TriMatrixEquationSolver::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*y*/) \ + .Arg<::xla::ffi::BufferR0>(/*alpha*/) \ + .Ret<::xla::ffi::Buffer>(/*y_out*/) \ + .Attr("side") \ + .Attr("uplo") \ + .Attr("trans_x") \ + .Attr("diag")) + #define JAX_CPU_DEFINE_GETRF(name, data_type) \ XLA_FFI_DEFINE_HANDLER( \ name, LuDecomposition::Kernel, \ @@ -92,6 +104,11 @@ namespace jax { // FFI Handlers +JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128); + JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32); JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64); @@ -117,6 +134,7 @@ JAX_CPU_DEFINE_GESDD(lapack_dgesdd_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128); +#undef JAX_CPU_DEFINE_TRSM #undef JAX_CPU_DEFINE_GETRF #undef JAX_CPU_DEFINE_GEQRF #undef JAX_CPU_DEFINE_ORGQR diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 08a0537eb..85fa02893 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -160,6 +160,48 @@ template struct Trsm; template struct Trsm>; template struct Trsm>; +// FFI Kernel + +template +ffi::Error TriMatrixEquationSolver::Kernel( + ffi::Buffer x, ffi::Buffer y, ffi::BufferR0 alpha, + ffi::ResultBuffer y_out, MatrixParams::Side side, + MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, + MatrixParams::Diag diag) { + CopyIfDiffBuffer(y, y_out); + + auto [batch_count, y_rows, y_cols] = SplitBatch2D(y.dimensions()); + auto* y_out_data = y_out->typed_data(); + lapack_int x_leading_dim_v = + side == MatrixParams::Side::kLeft ? y_rows : y_cols; + lapack_int y_leading_dim_v = y_rows; + + auto side_v = static_cast(side); + auto uplo_v = static_cast(uplo); + auto trans_x_v = static_cast(trans_x); + auto diag_v = static_cast(diag); + FFI_ASSIGN_OR_RETURN(auto y_rows_v, MaybeCastNoOverflow(y_rows)); + FFI_ASSIGN_OR_RETURN(auto y_cols_v, MaybeCastNoOverflow(y_cols)); + + auto* x_data = x.typed_data(); + const int64_t y_out_step{y_rows * y_cols}; + const int64_t x_step{x_leading_dim_v * x_leading_dim_v}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&side_v, &uplo_v, &trans_x_v, &diag_v, &y_rows_v, &y_cols_v, + alpha.typed_data(), x_data, &x_leading_dim_v, y_out_data, + &y_leading_dim_v); + + y_out_data += y_out_step; + x_data += x_step; + } + return ffi::Error::Success(); +} + +template struct TriMatrixEquationSolver; +template struct TriMatrixEquationSolver; +template struct TriMatrixEquationSolver; +template struct TriMatrixEquationSolver; + //== LU Decomposition ==// // lapack getrf diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index ea67bd32c..fd9d8c975 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -109,6 +109,24 @@ struct Trsm { static void Kernel(void* out, void** data, XlaCustomCallStatus*); }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct TriMatrixEquationSolver { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(char* side, char* uplo, char* transa, char* diag, + lapack_int* m, lapack_int* n, ValueType* alpha, + ValueType* a, lapack_int* lda, ValueType* b, + lapack_int* ldb); + + inline static FnType* fn = nullptr; + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, ::xla::ffi::Buffer y, + ::xla::ffi::BufferR0 alpha, ::xla::ffi::ResultBuffer y_out, + MatrixParams::Side side, MatrixParams::UpLo uplo, + MatrixParams::Transpose trans_x, MatrixParams::Diag diag); +}; + //== LU Decomposition ==// // lapack getrf diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 9c8a7b61e..d17925e13 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -26,10 +26,10 @@ namespace ffi = xla::ffi; extern "C" { -jax::Trsm::FnType strsm_; -jax::Trsm::FnType dtrsm_; -jax::Trsm>::FnType ctrsm_; -jax::Trsm>::FnType ztrsm_; +jax::TriMatrixEquationSolver::FnType strsm_; +jax::TriMatrixEquationSolver::FnType dtrsm_; +jax::TriMatrixEquationSolver::FnType ctrsm_; +jax::TriMatrixEquationSolver::FnType ztrsm_; jax::LuDecomposition::FnType sgetrf_; jax::LuDecomposition::FnType dgetrf_; @@ -87,6 +87,22 @@ namespace jax { #define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch" +static_assert( + std::is_same_v::FnType, + jax::Trsm::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Trsm::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Trsm>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Trsm>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); static_assert(std::is_same_v::FnType, jax::Getrf::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); @@ -218,6 +234,11 @@ static auto init = []() -> int { // FFI Kernels + AssignKernelFn>(strsm_); + AssignKernelFn>(dtrsm_); + AssignKernelFn>(ctrsm_); + AssignKernelFn>(ztrsm_); + AssignKernelFn>(sgetrf_); AssignKernelFn>(dgetrf_); AssignKernelFn>(cgetrf_);