From 3d39b6e752ac21ce984c17d5bae60a3e1857695b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Paruzel?=
Date: Thu, 13 Jun 2024 05:43:40 -0700
Subject: [PATCH] Port Cholesky Factorization to XLA's FFI

This CL only contains the C++ changes. Python lowering code will be added
after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 642954763
---
 jaxlib/cpu/BUILD                          |   9 ++
 jaxlib/cpu/cpu_kernels.cc                 |  20 +++
 jaxlib/cpu/lapack.cc                      |  16 +-
 jaxlib/cpu/lapack.h                       |  46 ++++++
 jaxlib/cpu/lapack_kernels.cc              | 169 +++++++++++++++++++---
 jaxlib/cpu/lapack_kernels.h               | 107 +++++++++++++-
 jaxlib/cpu/lapack_kernels_using_lapack.cc |  41 +++++-
 7 files changed, 373 insertions(+), 35 deletions(-)
 create mode 100644 jaxlib/cpu/lapack.h

diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD
index 7d0293571..d21c50101 100644
--- a/jaxlib/cpu/BUILD
+++ b/jaxlib/cpu/BUILD
@@ -35,8 +35,12 @@ cc_library(
     copts = ["-fexceptions"],
     features = ["-use_header_modules"],
     deps = [
+        "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
         "@xla//xla/service:custom_call_status",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:dynamic_annotations",
+        "@com_google_absl//absl/strings:str_format",
     ],
 )
@@ -50,6 +54,7 @@ cc_library(
 pybind_extension(
     name = "_lapack",
     srcs = ["lapack.cc"],
+    hdrs = ["lapack.h"],
     copts = [
         "-fexceptions",
         "-fno-strict-aliasing",
@@ -63,6 +68,7 @@ pybind_extension(
     deps = [
         ":lapack_kernels",
         "//jaxlib:kernel_nanobind_helpers",
+        "@xla//xla/ffi/api:ffi",
         "@nanobind",
     ],
 )
@@ -70,10 +76,13 @@ pybind_extension(
 cc_library(
     name = "cpu_kernels",
     srcs = ["cpu_kernels.cc"],
+    hdrs = ["lapack.h"],
     visibility = ["//visibility:public"],
     deps = [
         ":lapack_kernels",
         ":lapack_kernels_using_lapack",
+        "@xla//xla/ffi/api:c_api",
+        "@xla//xla/ffi/api:ffi",
         "@xla//xla/service:custom_call_target_registry",
     ],
     alwayslink = 1,
diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc
index e1f9e4fdc..774d03aa8 100644
--- a/jaxlib/cpu/cpu_kernels.cc
+++ b/jaxlib/cpu/cpu_kernels.cc
@@ -15,13 +15,24 @@ limitations under the License.
 // This file is not used by JAX itself, but exists to assist with running
 // JAX-generated HLO code from outside of JAX.
+
 #include
+
+#include "jaxlib/cpu/lapack.h"
 #include "jaxlib/cpu/lapack_kernels.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
 #include "xla/service/custom_call_target_registry.h"

+#define JAX_CPU_REGISTER_HANDLER(name) \
+  XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name);
+
 namespace jax {
 namespace {

+// Old-style kernels
+// TODO(b/344892332): To be removed after the 6M compatibility period is over.
+
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm<float>::Kernel,
                                          "Host");
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm<double>::Kernel,
                                          "Host");
@@ -105,5 +116,14 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
     "lapack_zgees", ComplexGees<std::complex<double>>::Kernel, "Host");

+// FFI Kernels
+
+JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi);
+JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi);
+JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi);
+JAX_CPU_REGISTER_HANDLER(lapack_zpotrf_ffi);
+
+#undef JAX_CPU_REGISTER_HANDLER
+
 }  // namespace
 }  // namespace jax
diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc
index 5ef42cfc5..540e90f9e 100644
--- a/jaxlib/cpu/lapack.cc
+++ b/jaxlib/cpu/lapack.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include "jaxlib/cpu/lapack.h"
+
 #include

 #include "nanobind/nanobind.h"
@@ -24,6 +26,8 @@ namespace {

 namespace nb = nanobind;

+using ::xla::ffi::DataType;
+
 void GetLapackKernelsFromScipy() {
   static bool initialized = false;  // Protected by GIL
   if (initialized) return;
@@ -66,6 +70,10 @@ void GetLapackKernelsFromScipy() {
   AssignKernelFn<Potrf<double>>(lapack_ptr("dpotrf"));
   AssignKernelFn<Potrf<std::complex<float>>>(lapack_ptr("cpotrf"));
   AssignKernelFn<Potrf<std::complex<double>>>(lapack_ptr("zpotrf"));
+  AssignKernelFn<CholeskyFactorization<DataType::F32>>(lapack_ptr("spotrf"));
+  AssignKernelFn<CholeskyFactorization<DataType::F64>>(lapack_ptr("dpotrf"));
+  AssignKernelFn<CholeskyFactorization<DataType::C64>>(lapack_ptr("cpotrf"));
+  AssignKernelFn<CholeskyFactorization<DataType::C128>>(lapack_ptr("zpotrf"));

   AssignKernelFn<RealGesdd<float>>(lapack_ptr("sgesdd"));
   AssignKernelFn<RealGesdd<double>>(lapack_ptr("dgesdd"));
@@ -170,14 +178,20 @@ nb::dict Registrations() {
   dict["lapack_zhetrd"] = EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);

+  dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi);
+  dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi);
+  dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi);
+  dict["lapack_zpotrf_ffi"] = EncapsulateFunction(lapack_zpotrf_ffi);
+
   return dict;
 }

 NB_MODULE(_lapack, m) {
   // Populates the LAPACK kernels from scipy on first call.
   m.def("initialize", GetLapackKernelsFromScipy);
-
   m.def("registrations", &Registrations);
+
+  // Old-style LAPACK Workspace Size Queries
   m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
         nb::arg("n"));
   m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace, nb::arg("m"),
diff --git a/jaxlib/cpu/lapack.h b/jaxlib/cpu/lapack.h
new file mode 100644
index 000000000..0b59c729a
--- /dev/null
+++ b/jaxlib/cpu/lapack.h
@@ -0,0 +1,46 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_CPU_LAPACK_H_
+#define JAXLIB_CPU_LAPACK_H_
+
+#include "jaxlib/cpu/lapack_kernels.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+
+// FFI Definition Macros (by DataType)
+
+#define JAX_CPU_DEFINE_POTRF(name, data_type)                   \
+  XLA_FFI_DEFINE_HANDLER(                                       \
+      name, CholeskyFactorization<data_type>::Kernel,           \
+      ::xla::ffi::Ffi::Bind()                                   \
+          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)            \
+          .Attr<MatrixParams::UpLo>("uplo")                     \
+          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)        \
+          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
+
+// FFI Handlers
+
+JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32);
+JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64);
+JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64);
+JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128);
+
+#undef JAX_CPU_DEFINE_POTRF
+
+}  // namespace jax
+
+#endif  // JAXLIB_CPU_LAPACK_H_
diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc
index 00b54bab0..57d078f21 100644
--- a/jaxlib/cpu/lapack_kernels.cc
+++ b/jaxlib/cpu/lapack_kernels.cc
@@ -15,32 +15,98 @@ limitations under the License.
#include "jaxlib/cpu/lapack_kernels.h" +#include #include +#include +#include #include -#include +#include #include +#include +#include +#include +#include +#include +#include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" +#include "absl/strings/str_format.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +static_assert(sizeof(jax::lapack_int) == sizeof(int32_t), + "Expected LAPACK integers to be 32-bit"); + +namespace ffi = xla::ffi; namespace { -inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) { - if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) { +template +inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { + if constexpr (sizeof(T) == sizeof(int64_t)) { return value; } else { - if (value > std::numeric_limits::max()) { - throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int"); + if (value > std::numeric_limits::max()) [[unlikely]] { + throw std::overflow_error{ + absl::StrFormat("%s: Value (=%d) exceeds the maximum representable " + "value of the desired type", + source, value)}; } - return value; + return static_cast(value); } } +template +std::tuple SplitBatch2D(ffi::Span dims) { + if (dims.size() < 2) { + throw std::invalid_argument("Matrix must have at least 2 dimensions"); + } + auto matrix_dims = dims.last(2); + return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1, + std::multiplies()), + matrix_dims.front(), matrix_dims.back()); } +template +void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) { + auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions); + if (x.data != x_out->data) { + const auto x_size = batch_count * x_rows * x_cols; + std::copy_n(x.data, x_size, x_out->data); + } +} + +} // namespace + +#define REGISTER_CHAR_ENUM_ATTR_DECODING(type) \ + std::optional xla::ffi::AttrDecoding::Decode( \ + XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \ + if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] { \ + return diagnostic.Emit("Wrong attribute type: expected ") \ + << XLA_FFI_AttrType_SCALAR << " but got" << attr_type; \ + } \ + auto* scalar = reinterpret_cast(attr); \ + if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] { \ + return diagnostic.Emit("Wrong scalar data type: expected ") \ + << XLA_FFI_DataType_U8 << " but got " << scalar->dtype; \ + } \ + auto underlying = \ + *reinterpret_cast*>(scalar->value); \ + return static_cast(underlying); \ + } + +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); +REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); + +#undef REGISTER_CHAR_ENUM_ATTR_DECODING + namespace jax { -static_assert(sizeof(lapack_int) == sizeof(int32_t), - "Expected LAPACK integers to be 32-bit"); +//== Triangular System Solver ==// + +// lapack trsm template typename Trsm::FnType* Trsm::fn = nullptr; @@ -92,7 +158,9 @@ template struct Trsm; template struct Trsm>; template struct Trsm>; -// Getrf +//== LU Decomposition ==// + +// lapack getrf template typename Getrf::FnType* Getrf::fn = nullptr; @@ -126,7 +194,9 @@ template struct Getrf; template struct Getrf>; template struct Getrf>; -// Geqrf +//== QR Factorization ==// + +// lapack geqrf template typename Geqrf::FnType* Geqrf::fn = nullptr; @@ -173,7 +243,10 @@ template struct Geqrf; template struct Geqrf>; template struct Geqrf>; 
-// Orgqr
+//== Orthogonal QR ==//
+//== Computes orthogonal matrix Q from QR Decomposition ==//
+
+// lapack orgqr

 template <typename T>
 typename Orgqr<T>::FnType* Orgqr<T>::fn = nullptr;
@@ -221,7 +294,9 @@ template struct Orgqr<float>;
 template struct Orgqr<double>;
 template struct Orgqr<std::complex<float>>;
 template struct Orgqr<std::complex<double>>;

-// Potrf
+//== Cholesky Factorization ==//
+
+// lapack potrf

 template <typename T>
 typename Potrf<T>::FnType* Potrf<T>::fn = nullptr;
@@ -255,7 +330,40 @@ template struct Potrf<float>;
 template struct Potrf<double>;
 template struct Potrf<std::complex<float>>;
 template struct Potrf<std::complex<double>>;

-// Gesdd
+// FFI Kernel
+
+template <ffi::DataType dtype>
+ffi::Error CholeskyFactorization<dtype>::Kernel(
+    ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
+    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
+  auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
+  auto* x_out_data = x_out->data;
+  auto* info_data = info->data;
+
+  CopyIfDiffBuffer(x, x_out);
+
+  auto uplo_v = static_cast<char>(uplo);
+  auto x_order_v = CastNoOverflow<lapack_int>(x.dimensions.back());
+  auto x_leading_dim_v = x_order_v;
+
+  const int64_t x_out_step{x_rows * x_cols};
+  for (int64_t i = 0; i < batch_count; ++i) {
+    fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, info_data);
+    x_out_data += x_out_step;
+    ++info_data;
+  }
+  return ffi::Error::Success();
+}
+
+template struct CholeskyFactorization<ffi::DataType::F32>;
+template struct CholeskyFactorization<ffi::DataType::F64>;
+template struct CholeskyFactorization<ffi::DataType::C64>;
+template struct CholeskyFactorization<ffi::DataType::C128>;
+
+//== Singular Value Decomposition (SVD) ==//
+//== using a divide and conquer method ==//
+
+// lapack gesdd

 static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
   if (!job_opt_compute_uv) {
@@ -267,7 +375,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
 }

 lapack_int GesddIworkSize(int64_t m, int64_t n) {
-  return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n));
+  return CastNoOverflow<lapack_int>(8 * std::min(m, n), "gesdd iwork");
 }

 template <typename T>
@@ -333,11 +441,12 @@ int64_t RealGesdd<T>::Workspace(lapack_int m, lapack_int n,
 lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) {
   int64_t mn = std::min(m, n);
   if (compute_uv == 0) {
-    return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn);
+    return CastNoOverflow<lapack_int>(7 * mn, "complex gesdd rwork");
   }
   int64_t mx = std::max(m, n);
-  return catch_lapack_int_overflow("complex gesdd rwork",
-      std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn));
+  return CastNoOverflow<lapack_int>(
+      std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
+      "complex gesdd rwork");
 }

 template <typename T>
@@ -408,13 +517,17 @@ template struct RealGesdd<double>;
 template struct ComplexGesdd<std::complex<float>>;
 template struct ComplexGesdd<std::complex<double>>;

+//== Eigenvalues and eigenvectors ==//
+
+// lapack syevd/heevd
+
 // # Workspace sizes, taken from the LAPACK documentation.
 lapack_int SyevdWorkSize(int64_t n) {
-  return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n);
+  return CastNoOverflow<lapack_int>(1 + 6 * n + 2 * n * n, "syevd lwork");
 }

 lapack_int SyevdIworkSize(int64_t n) {
-  return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n);
+  return CastNoOverflow<lapack_int>(3 + 5 * n, "syevd iwork");
 }

 template <typename T>
@@ -454,11 +567,11 @@ void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {

 // Workspace sizes, taken from the LAPACK documentation.
 lapack_int HeevdWorkSize(int64_t n) {
-  return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n);
+  return CastNoOverflow<lapack_int>(1 + 2 * n + n * n, "heevd work");
 }

 lapack_int HeevdRworkSize(int64_t n) {
-  return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n);
+  return CastNoOverflow<lapack_int>(1 + 5 * n + 2 * n * n, "heevd rwork");
 }

 template <typename T>
@@ -534,6 +647,8 @@ static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed,
   }
 }

+// lapack geev
+
 template <typename T>
 typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;
@@ -679,7 +794,9 @@ template struct RealGeev<double>;
 template struct ComplexGeev<std::complex<float>>;
 template struct ComplexGeev<std::complex<double>>;

-// Gees
+//== Schur Decomposition ==//
+
+// lapack gees

 template <typename T>
 typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;
@@ -809,6 +926,10 @@ template struct RealGees<double>;
 template struct ComplexGees<std::complex<float>>;
 template struct ComplexGees<std::complex<double>>;

+//== Hessenberg Decomposition ==//
+
+// lapack gehrd
+
 template <typename T>
 typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;
@@ -859,6 +980,10 @@ template struct Gehrd<double>;
 template struct Gehrd<std::complex<float>>;
 template struct Gehrd<std::complex<double>>;

+//== Tridiagonal Reduction ==//
+
+// lapack sytrd/hetrd
+
 template <typename T>
 typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;
diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h
index 84d2251bb..86bfc8e94 100644
--- a/jaxlib/cpu/lapack_kernels.h
+++ b/jaxlib/cpu/lapack_kernels.h
@@ -16,19 +16,32 @@ limitations under the License.

 #ifndef JAXLIB_CPU_LAPACK_KERNELS_H_
 #define JAXLIB_CPU_LAPACK_KERNELS_H_

-#include
 #include
+#include
+#include
+#include "xla/ffi/api/ffi.h"
+#include "xla/ffi/api/c_api.h"
 #include "xla/service/custom_call_status.h"

-// Underlying function pointers (e.g., Trsm::Fn) are initialized either
+// Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either
 // by the pybind wrapper that links them to an existing SciPy lapack instance,
 // or using the lapack_kernels_strong.cc static initialization to link them
 // directly to lapack for use in a pure C++ context.

 namespace jax {

-typedef int lapack_int;
+struct MatrixParams {
+  enum class Side : char { kLeft = 'L', kRight = 'R' };
+  enum class UpLo : char { kLower = 'L', kUpper = 'U' };
+  enum class Diag : char { kNonUnit = 'N', kUnit = 'U' };
+  enum class Transpose : char {
+    kNoTrans = 'N',
+    kTrans = 'T',
+    kConjTrans = 'C'
+  };
+};
+
 template <typename KernelType>
 void AssignKernelFn(void* func) {
   KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);
@@ -39,6 +52,34 @@ void AssignKernelFn(typename KernelType::FnType* func) {
   KernelType::fn = func;
 }

+}  // namespace jax
+
+#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR)                              \
+  template <>                                                             \
+  struct xla::ffi::AttrDecoding<ATTR> {                                   \
+    using Type = ATTR;                                                    \
+    static std::optional<Type> Decode(XLA_FFI_AttrType type, void* attr,  \
+                                      DiagnosticEngine& diagnostic);      \
+  }
+
+// XLA needs attributes to have deserialization method specified
+DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side);
+DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
+DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
+DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
+
+#undef DEFINE_CHAR_ENUM_ATTR_DECODING
+
+namespace jax {
+
+using lapack_int = int;
+inline constexpr auto LapackIntDtype = ::xla::ffi::DataType::S32;
+static_assert(
+    std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
+
+//== Triangular System Solver ==//
+
+// lapack trsm

 template <typename T>
 struct Trsm {
@@ -50,6 +91,10 @@ struct Trsm {
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };

+//== LU Decomposition ==//
+
+// lapack getrf
+
 template <typename T>
 struct Getrf {
   using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda,
@@ -59,6 +104,10 @@ struct Getrf {
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };

+//== QR Factorization ==//
+
+// lapack geqrf
+
 template <typename T>
 struct Geqrf {
   using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda,
@@ -70,6 +119,10 @@ struct Geqrf {
   static int64_t Workspace(lapack_int m, lapack_int n);
 };

+//== Orthogonal QR ==//
+
+// lapack orgqr
+
 template <typename T>
 struct Orgqr {
   using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a,
@@ -80,6 +133,10 @@ struct Orgqr {
   static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k);
 };

+//== Cholesky Factorization ==//
+
+// lapack potrf
+
 template <typename T>
 struct Potrf {
   using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
@@ -88,6 +145,24 @@ struct Potrf {
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };

+template <::xla::ffi::DataType dtype>
+struct CholeskyFactorization {
+  using ValueType = ::xla::ffi::NativeType<dtype>;
+  using FnType = void(char* uplo, lapack_int* n, ValueType* a, lapack_int* lda,
+                      lapack_int* info);
+
+  inline static FnType* fn = nullptr;
+
+  static ::xla::ffi::Error Kernel(
+      ::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
+      ::xla::ffi::ResultBuffer<dtype> x_out,
+      ::xla::ffi::ResultBuffer<LapackIntDtype> info);
+};
+
+//== Singular Value Decomposition (SVD) ==//
+
+// lapack gesdd
+
 lapack_int GesddIworkSize(int64_t m, int64_t n);

 template <typename T>
@@ -119,6 +194,10 @@ struct ComplexGesdd {
                         bool job_opt_full_matrices);
 };

+//== Eigenvalues and eigenvectors ==//
+
+// lapack syevd/heevd
+
 lapack_int SyevdWorkSize(int64_t n);
 lapack_int SyevdIworkSize(int64_t n);

@@ -145,6 +224,8 @@ struct ComplexHeevd {
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };

+// lapack geev
+
 template <typename T>
 struct RealGeev {
   using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a,
@@ -165,6 +246,10 @@ struct ComplexGeev {
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };

+//== Schur Decomposition ==//
+
+// lapack gees
+
 template <typename T>
 struct RealGees {
   using FnType = void(char* jobvs, char* sort, bool (*select)(T, T),
@@ -186,7 +271,11 @@ struct ComplexGees {
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };

-// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form.
+//== Hessenberg Decomposition ==//
+//== Reduces a non-symmetric square matrix to upper Hessenberg form ==//
+
+// lapack gehrd
+
 template <typename T>
 struct Gehrd {
   using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a,
@@ -209,14 +298,16 @@ struct real_type<std::complex<T>> {
   typedef T type;
 };

-// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal
-// form.
+//== Tridiagonal Reduction ==//
+//== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==//
+
+// lapack sytrd/hetrd
+
 template <typename T>
 struct Sytrd {
   using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
                       typename real_type<T>::type* d,
-                      typename real_type<T>::type* e,
-                      T* tau, T* work,
+                      typename real_type<T>::type* e, T* tau, T* work,
                       lapack_int* lwork, lapack_int* info);

   static FnType* fn;
diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc
index 4360ec55a..8d69e93f9 100644
--- a/jaxlib/cpu/lapack_kernels_using_lapack.cc
+++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc
@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include
+#include
+
 #include "jaxlib/cpu/lapack_kernels.h"

 // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but
 // a C++ user should link against LAPACK directly. This is needed when using
 // JAX-generated HLO from C++.

+namespace ffi = xla::ffi;
+
 extern "C" {

 jax::Trsm<float>::FnType strsm_;
@@ -41,10 +46,10 @@ jax::Orgqr<double>::FnType dorgqr_;
 jax::Orgqr<std::complex<float>>::FnType cungqr_;
 jax::Orgqr<std::complex<double>>::FnType zungqr_;

-jax::Potrf<float>::FnType spotrf_;
-jax::Potrf<double>::FnType dpotrf_;
-jax::Potrf<std::complex<float>>::FnType cpotrf_;
-jax::Potrf<std::complex<double>>::FnType zpotrf_;
+jax::CholeskyFactorization<ffi::DataType::F32>::FnType spotrf_;
+jax::CholeskyFactorization<ffi::DataType::F64>::FnType dpotrf_;
+jax::CholeskyFactorization<ffi::DataType::C64>::FnType cpotrf_;
+jax::CholeskyFactorization<ffi::DataType::C128>::FnType zpotrf_;

 jax::RealGesdd<float>::FnType sgesdd_;
 jax::RealGesdd<double>::FnType dgesdd_;
@@ -80,6 +85,27 @@ jax::Sytrd<std::complex<double>>::FnType zhetrd_;

 namespace jax {

+#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch"
+
+static_assert(
+    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F32>::FnType,
+                   jax::Potrf<float>::FnType>,
+    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
+static_assert(
+    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F64>::FnType,
+                   jax::Potrf<double>::FnType>,
+    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
+static_assert(
+    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::C64>::FnType,
+                   jax::Potrf<std::complex<float>>::FnType>,
+    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
+static_assert(
+    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::C128>::FnType,
+                   jax::Potrf<std::complex<double>>::FnType>,
+    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
+
+#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG
+
 static auto init = []() -> int {
   AssignKernelFn<Trsm<float>>(strsm_);
   AssignKernelFn<Trsm<double>>(dtrsm_);
@@ -136,6 +162,13 @@ static auto init = []() -> int {
   AssignKernelFn<Sytrd<std::complex<float>>>(chetrd_);
   AssignKernelFn<Sytrd<std::complex<double>>>(zhetrd_);

+  // FFI Kernels
+
+  AssignKernelFn<CholeskyFactorization<ffi::DataType::F32>>(spotrf_);
+  AssignKernelFn<CholeskyFactorization<ffi::DataType::F64>>(dpotrf_);
+  AssignKernelFn<CholeskyFactorization<ffi::DataType::C64>>(cpotrf_);
+  AssignKernelFn<CholeskyFactorization<ffi::DataType::C128>>(zpotrf_);
+
   return 0;
 }();
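
Note (not part of the patch): the two macros introduced above are thin wrappers, and expanding them by hand for the f32 handler gives roughly the following sketch. It is for illustration only and assumes the macro bodies shown in jaxlib/cpu/lapack.h and jaxlib/cpu/cpu_kernels.cc together with the xla/ffi headers they include.

#include "jaxlib/cpu/lapack_kernels.h"  // jax::CholeskyFactorization, jax::MatrixParams
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

// JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32) defines an
// FFI handler bound to one f32 input buffer, a char-backed "uplo" attribute,
// and two results (the factored matrix and the per-batch LAPACK info codes):
XLA_FFI_DEFINE_HANDLER(
    lapack_spotrf_ffi,
    jax::CholeskyFactorization<::xla::ffi::DataType::F32>::Kernel,
    ::xla::ffi::Ffi::Bind()
        .Arg<::xla::ffi::Buffer<::xla::ffi::DataType::F32>>(/*x*/)
        .Attr<jax::MatrixParams::UpLo>("uplo")
        .Ret<::xla::ffi::Buffer<::xla::ffi::DataType::F32>>(/*x_out*/)
        .Ret<::xla::ffi::Buffer<jax::LapackIntDtype>>(/*info*/));

// JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi) then registers that handler with
// XLA's FFI registry for the "Host" platform under the same symbol name:
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "lapack_spotrf_ffi", "Host",
                         lapack_spotrf_ffi);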