Port Cholesky Factorization to XLA's FFI

This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 642954763
This commit is contained in:
Paweł Paruzel 2024-06-13 05:43:40 -07:00 committed by jax authors
parent c5c7fa7089
commit 3d39b6e752
7 changed files with 373 additions and 35 deletions

View File

@ -35,8 +35,12 @@ cc_library(
copts = ["-fexceptions"],
features = ["-use_header_modules"],
deps = [
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/strings:str_format",
],
)
@ -50,6 +54,7 @@ cc_library(
pybind_extension(
name = "_lapack",
srcs = ["lapack.cc"],
hdrs = ["lapack.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
@ -63,6 +68,7 @@ pybind_extension(
deps = [
":lapack_kernels",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/ffi/api:ffi",
"@nanobind",
],
)
@ -70,10 +76,13 @@ pybind_extension(
cc_library(
name = "cpu_kernels",
srcs = ["cpu_kernels.cc"],
hdrs = ["lapack.h"],
visibility = ["//visibility:public"],
deps = [
":lapack_kernels",
":lapack_kernels_using_lapack",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_target_registry",
],
alwayslink = 1,

View File

@ -15,13 +15,24 @@ limitations under the License.
// This file is not used by JAX itself, but exists to assist with running
// JAX-generated HLO code from outside of JAX.
#include <complex>
#include "jaxlib/cpu/lapack.h"
#include "jaxlib/cpu/lapack_kernels.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/service/custom_call_target_registry.h"
#define JAX_CPU_REGISTER_HANDLER(name) \
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name);
namespace jax {
namespace {
// Old-style kernels
// TODO(b/344892332): To be removed after the 6M compatibility period is over.
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm<float>::Kernel,
"Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm<double>::Kernel,
@ -105,5 +116,14 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
"lapack_zgees", ComplexGees<std::complex<double>>::Kernel, "Host");
// FFI Kernels
JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zpotrf_ffi);
#undef JAX_CPU_REGISTER_HANDLER
} // namespace
} // namespace jax

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cpu/lapack.h"
#include <complex>
#include "nanobind/nanobind.h"
@ -24,6 +26,8 @@ namespace {
namespace nb = nanobind;
using ::xla::ffi::DataType;
void GetLapackKernelsFromScipy() {
static bool initialized = false; // Protected by GIL
if (initialized) return;
@ -66,6 +70,10 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<Potrf<double>>(lapack_ptr("dpotrf"));
AssignKernelFn<Potrf<std::complex<float>>>(lapack_ptr("cpotrf"));
AssignKernelFn<Potrf<std::complex<double>>>(lapack_ptr("zpotrf"));
AssignKernelFn<CholeskyFactorization<DataType::F32>>(lapack_ptr("spotrf"));
AssignKernelFn<CholeskyFactorization<DataType::F64>>(lapack_ptr("dpotrf"));
AssignKernelFn<CholeskyFactorization<DataType::C64>>(lapack_ptr("cpotrf"));
AssignKernelFn<CholeskyFactorization<DataType::C128>>(lapack_ptr("zpotrf"));
AssignKernelFn<RealGesdd<float>>(lapack_ptr("sgesdd"));
AssignKernelFn<RealGesdd<double>>(lapack_ptr("dgesdd"));
@ -170,14 +178,20 @@ nb::dict Registrations() {
dict["lapack_zhetrd"] =
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi);
dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi);
dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi);
dict["lapack_zpotrf_ffi"] = EncapsulateFunction(lapack_zpotrf_ffi);
return dict;
}
NB_MODULE(_lapack, m) {
// Populates the LAPACK kernels from scipy on first call.
m.def("initialize", GetLapackKernelsFromScipy);
m.def("registrations", &Registrations);
// Old-style LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
nb::arg("n"));
m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace, nb::arg("m"),

46
jaxlib/cpu/lapack.h Normal file
View File

@ -0,0 +1,46 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_CPU_LAPACK_H_
#define JAXLIB_CPU_LAPACK_H_
#include "jaxlib/cpu/lapack_kernels.h"
#include "xla/ffi/api/ffi.h"
namespace jax {
// FFI Definition Macros (by DataType)
#define JAX_CPU_DEFINE_POTRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER( \
name, CholeskyFactorization<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Attr<MatrixParams::UpLo>("uplo") \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
// FFI Handlers
JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_POTRF
} // namespace jax
#endif // JAXLIB_CPU_LAPACK_H_

View File

@ -15,32 +15,98 @@ limitations under the License.
#include "jaxlib/cpu/lapack_kernels.h"
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/strings/str_format.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
"Expected LAPACK integers to be 32-bit");
namespace ffi = xla::ffi;
namespace {
inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) {
if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) {
template <typename T>
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
if constexpr (sizeof(T) == sizeof(int64_t)) {
return value;
} else {
if (value > std::numeric_limits<jax::lapack_int>::max()) {
throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int");
if (value > std::numeric_limits<T>::max()) [[unlikely]] {
throw std::overflow_error{
absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
"value of the desired type",
source, value)};
}
return value;
return static_cast<T>(value);
}
}
template <typename T>
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
if (dims.size() < 2) {
throw std::invalid_argument("Matrix must have at least 2 dimensions");
}
auto matrix_dims = dims.last(2);
return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1,
std::multiplies<int64_t>()),
matrix_dims.front(), matrix_dims.back());
}
template <ffi::DataType dtype>
void CopyIfDiffBuffer(ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out) {
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
if (x.data != x_out->data) {
const auto x_size = batch_count * x_rows * x_cols;
std::copy_n(x.data, x_size, x_out->data);
}
}
} // namespace
#define REGISTER_CHAR_ENUM_ATTR_DECODING(type) \
std::optional<type> xla::ffi::AttrDecoding<type>::Decode( \
XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \
if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] { \
return diagnostic.Emit("Wrong attribute type: expected ") \
<< XLA_FFI_AttrType_SCALAR << " but got" << attr_type; \
} \
auto* scalar = reinterpret_cast<XLA_FFI_Scalar*>(attr); \
if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] { \
return diagnostic.Emit("Wrong scalar data type: expected ") \
<< XLA_FFI_DataType_U8 << " but got " << scalar->dtype; \
} \
auto underlying = \
*reinterpret_cast<std::underlying_type_t<type>*>(scalar->value); \
return static_cast<type>(underlying); \
}
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
#undef REGISTER_CHAR_ENUM_ATTR_DECODING
namespace jax {
static_assert(sizeof(lapack_int) == sizeof(int32_t),
"Expected LAPACK integers to be 32-bit");
//== Triangular System Solver ==//
// lapack trsm
template <typename T>
typename Trsm<T>::FnType* Trsm<T>::fn = nullptr;
@ -92,7 +158,9 @@ template struct Trsm<double>;
template struct Trsm<std::complex<float>>;
template struct Trsm<std::complex<double>>;
// Getrf
//== LU Decomposition ==//
// lapack getrf
template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;
@ -126,7 +194,9 @@ template struct Getrf<double>;
template struct Getrf<std::complex<float>>;
template struct Getrf<std::complex<double>>;
// Geqrf
//== QR Factorization ==//
// lapack geqrf
template <typename T>
typename Geqrf<T>::FnType* Geqrf<T>::fn = nullptr;
@ -173,7 +243,10 @@ template struct Geqrf<double>;
template struct Geqrf<std::complex<float>>;
template struct Geqrf<std::complex<double>>;
// Orgqr
//== Orthogonal QR ==//
//== Computes orthogonal matrix Q from QR Decomposition ==//
// lapack orgqr
template <typename T>
typename Orgqr<T>::FnType* Orgqr<T>::fn = nullptr;
@ -221,7 +294,9 @@ template struct Orgqr<double>;
template struct Orgqr<std::complex<float>>;
template struct Orgqr<std::complex<double>>;
// Potrf
//== Cholesky Factorization ==//
// lapack potrf
template <typename T>
typename Potrf<T>::FnType* Potrf<T>::fn = nullptr;
@ -255,7 +330,40 @@ template struct Potrf<double>;
template struct Potrf<std::complex<float>>;
template struct Potrf<std::complex<double>>;
// Gesdd
// FFI Kernel
template <ffi::DataType dtype>
ffi::Error CholeskyFactorization<dtype>::Kernel(
ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
auto* x_out_data = x_out->data;
auto* info_data = info->data;
CopyIfDiffBuffer(x, x_out);
auto uplo_v = static_cast<char>(uplo);
auto x_order_v = CastNoOverflow<lapack_int>(x.dimensions.back());
auto x_leading_dim_v = x_order_v;
const int64_t x_out_step{x_rows * x_cols};
for (int64_t i = 0; i < batch_count; ++i) {
fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, info_data);
x_out_data += x_out_step;
++info_data;
}
return ffi::Error::Success();
}
template struct CholeskyFactorization<ffi::DataType::F32>;
template struct CholeskyFactorization<ffi::DataType::F64>;
template struct CholeskyFactorization<ffi::DataType::C64>;
template struct CholeskyFactorization<ffi::DataType::C128>;
//== Singular Value Decomposition (SVD) ==//
//== using a divide and conquer method ==//
// lapack gesdd
static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
if (!job_opt_compute_uv) {
@ -267,7 +375,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
}
lapack_int GesddIworkSize(int64_t m, int64_t n) {
return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n));
return CastNoOverflow<lapack_int>(8 * std::min(m, n), "gesdd iwork");
}
template <typename T>
@ -333,11 +441,12 @@ int64_t RealGesdd<T>::Workspace(lapack_int m, lapack_int n,
lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) {
int64_t mn = std::min(m, n);
if (compute_uv == 0) {
return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn);
return CastNoOverflow<lapack_int>(7 * mn, "complex gesdd rwork");
}
int64_t mx = std::max(m, n);
return catch_lapack_int_overflow("complex gesdd rwork",
std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn));
return CastNoOverflow<lapack_int>(
std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
"complex gesdd rwork");
}
template <typename T>
@ -408,13 +517,17 @@ template struct RealGesdd<double>;
template struct ComplexGesdd<std::complex<float>>;
template struct ComplexGesdd<std::complex<double>>;
//== Eigenvalues and eigenvectors ==//
// lapack syevd/heevd
// # Workspace sizes, taken from the LAPACK documentation.
lapack_int SyevdWorkSize(int64_t n) {
return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n);
return CastNoOverflow<lapack_int>(1 + 6 * n + 2 * n * n, "syevd lwork");
}
lapack_int SyevdIworkSize(int64_t n) {
return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n);
return CastNoOverflow<lapack_int>(3 + 5 * n, "syevd iwork");
}
template <typename T>
@ -454,11 +567,11 @@ void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
// Workspace sizes, taken from the LAPACK documentation.
lapack_int HeevdWorkSize(int64_t n) {
return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n);
return CastNoOverflow<lapack_int>(1 + 2 * n + n * n, "heevd work");
}
lapack_int HeevdRworkSize(int64_t n) {
return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n);
return CastNoOverflow<lapack_int>(1 + 5 * n + 2 * n * n, "heevd rwork");
}
template <typename T>
@ -534,6 +647,8 @@ static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed,
}
}
// lapack geev
template <typename T>
typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;
@ -679,7 +794,9 @@ template struct RealGeev<double>;
template struct ComplexGeev<std::complex<float>>;
template struct ComplexGeev<std::complex<double>>;
// Gees
//== Schur Decomposition ==//
// lapack gees
template <typename T>
typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;
@ -809,6 +926,10 @@ template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>;
//== Hessenberg Decomposition ==//
// lapack gehrd
template <typename T>
typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;
@ -859,6 +980,10 @@ template struct Gehrd<double>;
template struct Gehrd<std::complex<float>>;
template struct Gehrd<std::complex<double>>;
//== Tridiagonal Reduction ==//
// lapack sytrd/hetrd
template <typename T>
typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;

View File

@ -16,19 +16,32 @@ limitations under the License.
#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_
#define JAXLIB_CPU_LAPACK_KERNELS_H_
#include <complex>
#include <cstdint>
#include <optional>
#include <type_traits>
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/api/c_api.h"
#include "xla/service/custom_call_status.h"
// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
// Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either
// by the pybind wrapper that links them to an existing SciPy lapack instance,
// or using the lapack_kernels_strong.cc static initialization to link them
// directly to lapack for use in a pure C++ context.
namespace jax {
typedef int lapack_int;
struct MatrixParams {
enum class Side : char { kLeft = 'L', kRight = 'R' };
enum class UpLo : char { kLower = 'L', kUpper = 'U' };
enum class Diag : char { kNonUnit = 'N', kUnit = 'U' };
enum class Transpose : char {
kNoTrans = 'N',
kTrans = 'T',
kConjTrans = 'C'
};
};
template <typename KernelType>
void AssignKernelFn(void* func) {
KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);
@ -39,6 +52,34 @@ void AssignKernelFn(typename KernelType::FnType* func) {
KernelType::fn = func;
}
} // namespace jax
#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR) \
template <> \
struct xla::ffi::AttrDecoding<ATTR> { \
using Type = ATTR; \
static std::optional<Type> Decode(XLA_FFI_AttrType type, void* attr, \
DiagnosticEngine& diagnostic); \
}
// XLA needs attributes to have deserialization method specified
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
#undef DEFINE_CHAR_ENUM_ATTR_DECODING
namespace jax {
using lapack_int = int;
inline constexpr auto LapackIntDtype = ::xla::ffi::DataType::S32;
static_assert(
std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
//== Triangular System Solver ==//
// lapack trsm
template <typename T>
struct Trsm {
@ -50,6 +91,10 @@ struct Trsm {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
//== LU Decomposition ==//
// lapack getrf
template <typename T>
struct Getrf {
using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda,
@ -59,6 +104,10 @@ struct Getrf {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
//== QR Factorization ==//
// lapack geqrf
template <typename T>
struct Geqrf {
using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda,
@ -70,6 +119,10 @@ struct Geqrf {
static int64_t Workspace(lapack_int m, lapack_int n);
};
//== Orthogonal QR ==//
// lapack orgqr
template <typename T>
struct Orgqr {
using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a,
@ -80,6 +133,10 @@ struct Orgqr {
static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k);
};
//== Cholesky Factorization ==//
// lapack potrf
template <typename T>
struct Potrf {
using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
@ -88,6 +145,24 @@ struct Potrf {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
template <::xla::ffi::DataType dtype>
struct CholeskyFactorization {
using ValueType = ::xla::ffi::NativeType<dtype>;
using FnType = void(char* uplo, lapack_int* n, ValueType* a, lapack_int* lda,
lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<LapackIntDtype> info);
};
//== Singular Value Decomposition (SVD) ==//
// lapack gesdd
lapack_int GesddIworkSize(int64_t m, int64_t n);
template <typename T>
@ -119,6 +194,10 @@ struct ComplexGesdd {
bool job_opt_full_matrices);
};
//== Eigenvalues and eigenvectors ==//
// lapack syevd/heevd
lapack_int SyevdWorkSize(int64_t n);
lapack_int SyevdIworkSize(int64_t n);
@ -145,6 +224,8 @@ struct ComplexHeevd {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
// lapack geev
template <typename T>
struct RealGeev {
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a,
@ -165,6 +246,10 @@ struct ComplexGeev {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
//== Schur Decomposition ==//
// lapack gees
template <typename T>
struct RealGees {
using FnType = void(char* jobvs, char* sort, bool (*select)(T, T),
@ -186,7 +271,11 @@ struct ComplexGees {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form.
//== Hessenberg Decomposition ==//
//== Reduces a non-symmetric square matrix to upper Hessenberg form ==//
// lapack gehrd
template <typename T>
struct Gehrd {
using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a,
@ -209,14 +298,16 @@ struct real_type<std::complex<T>> {
typedef T type;
};
// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal
// form.
//== Tridiagonal Reduction ==//
//== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==//
// lapack sytrd/hetrd
template <typename T>
struct Sytrd {
using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
typename real_type<T>::type* d,
typename real_type<T>::type* e,
T* tau, T* work,
typename real_type<T>::type* e, T* tau, T* work,
lapack_int* lwork, lapack_int* info);
static FnType* fn;

View File

@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include <type_traits>
#include "jaxlib/cpu/lapack_kernels.h"
// From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but
// a C++ user should link against LAPACK directly. This is needed when using
// JAX-generated HLO from C++.
namespace ffi = xla::ffi;
extern "C" {
jax::Trsm<float>::FnType strsm_;
@ -41,10 +46,10 @@ jax::Orgqr<double>::FnType dorgqr_;
jax::Orgqr<std::complex<float>>::FnType cungqr_;
jax::Orgqr<std::complex<double>>::FnType zungqr_;
jax::Potrf<float>::FnType spotrf_;
jax::Potrf<double>::FnType dpotrf_;
jax::Potrf<std::complex<float>>::FnType cpotrf_;
jax::Potrf<std::complex<double>>::FnType zpotrf_;
jax::CholeskyFactorization<ffi::DataType::F32>::FnType spotrf_;
jax::CholeskyFactorization<ffi::DataType::F64>::FnType dpotrf_;
jax::CholeskyFactorization<ffi::DataType::C64>::FnType cpotrf_;
jax::CholeskyFactorization<ffi::DataType::C128>::FnType zpotrf_;
jax::RealGesdd<float>::FnType sgesdd_;
jax::RealGesdd<double>::FnType dgesdd_;
@ -80,6 +85,27 @@ jax::Sytrd<std::complex<double>>::FnType zhetrd_;
namespace jax {
#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch"
static_assert(
std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F32>::FnType,
jax::Potrf<float>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F64>::FnType,
jax::Potrf<double>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::CholeskyFactorization<ffi::DataType::C64>::FnType,
jax::Potrf<std::complex<float>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::CholeskyFactorization<ffi::DataType::C128>::FnType,
jax::Potrf<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG
static auto init = []() -> int {
AssignKernelFn<Trsm<float>>(strsm_);
AssignKernelFn<Trsm<double>>(dtrsm_);
@ -136,6 +162,13 @@ static auto init = []() -> int {
AssignKernelFn<Sytrd<std::complex<float>>>(chetrd_);
AssignKernelFn<Sytrd<std::complex<double>>>(zhetrd_);
// FFI Kernels
AssignKernelFn<CholeskyFactorization<ffi::DataType::F32>>(spotrf_);
AssignKernelFn<CholeskyFactorization<ffi::DataType::F64>>(dpotrf_);
AssignKernelFn<CholeskyFactorization<ffi::DataType::C64>>(cpotrf_);
AssignKernelFn<CholeskyFactorization<ffi::DataType::C128>>(zpotrf_);
return 0;
}();