Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
Port Cholesky Factorization to XLA's FFI
This CL contains only the C++ changes; the Python lowering code will be added after the 3-week forward-compatibility window.

PiperOrigin-RevId: 642954763
This commit is contained in: parent c5c7fa7089, commit 3d39b6e752
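In short, the old untyped custom-call entry point is kept for the compatibility window while a typed FFI handler is added alongside it. A minimal sketch of the two kernel signatures, as they appear in lapack_kernels.h below:

template <typename T>
struct Potrf {
  // Old-style custom call: untyped pointers, errors reported through
  // XlaCustomCallStatus.
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <::xla::ffi::DataType dtype>
struct CholeskyFactorization {
  // New-style FFI handler: typed buffers and attributes, errors returned as
  // ::xla::ffi::Error.
  static ::xla::ffi::Error Kernel(
      ::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
      ::xla::ffi::ResultBuffer<dtype> x_out,
      ::xla::ffi::ResultBuffer<LapackIntDtype> info);
};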
@@ -35,8 +35,12 @@ cc_library(
    copts = ["-fexceptions"],
    features = ["-use_header_modules"],
    deps = [
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
        "@xla//xla/service:custom_call_status",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:dynamic_annotations",
        "@com_google_absl//absl/strings:str_format",
    ],
)

@@ -50,6 +54,7 @@ cc_library(
pybind_extension(
    name = "_lapack",
    srcs = ["lapack.cc"],
    hdrs = ["lapack.h"],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",

@@ -63,6 +68,7 @@ pybind_extension(
    deps = [
        ":lapack_kernels",
        "//jaxlib:kernel_nanobind_helpers",
        "@xla//xla/ffi/api:ffi",
        "@nanobind",
    ],
)

@@ -70,10 +76,13 @@ pybind_extension(
cc_library(
    name = "cpu_kernels",
    srcs = ["cpu_kernels.cc"],
    hdrs = ["lapack.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":lapack_kernels",
        ":lapack_kernels_using_lapack",
        "@xla//xla/ffi/api:c_api",
        "@xla//xla/ffi/api:ffi",
        "@xla//xla/service:custom_call_target_registry",
    ],
    alwayslink = 1,
@@ -15,13 +15,24 @@ limitations under the License.

// This file is not used by JAX itself, but exists to assist with running
// JAX-generated HLO code from outside of JAX.

#include <complex>

#include "jaxlib/cpu/lapack.h"
#include "jaxlib/cpu/lapack_kernels.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/service/custom_call_target_registry.h"

#define JAX_CPU_REGISTER_HANDLER(name) \
  XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name);

namespace jax {
namespace {

// Old-style kernels
// TODO(b/344892332): To be removed after the 6M compatibility period is over.

XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm<float>::Kernel,
                                         "Host");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm<double>::Kernel,

@@ -105,5 +116,14 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
    "lapack_zgees", ComplexGees<std::complex<double>>::Kernel, "Host");

// FFI Kernels

JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zpotrf_ffi);

#undef JAX_CPU_REGISTER_HANDLER

}  // namespace
}  // namespace jax
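For reference, each JAX_CPU_REGISTER_HANDLER line above expands, via the macro defined at the top of this file, to a registration against XLA's FFI registry; a sketch of the first expansion:

// Registers the handler under the custom-call target name "lapack_spotrf_ffi"
// for the "Host" platform.
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "lapack_spotrf_ffi", "Host",
                         lapack_spotrf_ffi);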
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/cpu/lapack.h"

#include <complex>

#include "nanobind/nanobind.h"

@@ -24,6 +26,8 @@ namespace {

namespace nb = nanobind;

using ::xla::ffi::DataType;

void GetLapackKernelsFromScipy() {
  static bool initialized = false;  // Protected by GIL
  if (initialized) return;

@@ -66,6 +70,10 @@ void GetLapackKernelsFromScipy() {
  AssignKernelFn<Potrf<double>>(lapack_ptr("dpotrf"));
  AssignKernelFn<Potrf<std::complex<float>>>(lapack_ptr("cpotrf"));
  AssignKernelFn<Potrf<std::complex<double>>>(lapack_ptr("zpotrf"));
  AssignKernelFn<CholeskyFactorization<DataType::F32>>(lapack_ptr("spotrf"));
  AssignKernelFn<CholeskyFactorization<DataType::F64>>(lapack_ptr("dpotrf"));
  AssignKernelFn<CholeskyFactorization<DataType::C64>>(lapack_ptr("cpotrf"));
  AssignKernelFn<CholeskyFactorization<DataType::C128>>(lapack_ptr("zpotrf"));

  AssignKernelFn<RealGesdd<float>>(lapack_ptr("sgesdd"));
  AssignKernelFn<RealGesdd<double>>(lapack_ptr("dgesdd"));

@@ -170,14 +178,20 @@ nb::dict Registrations() {
  dict["lapack_zhetrd"] =
      EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);

  dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi);
  dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi);
  dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi);
  dict["lapack_zpotrf_ffi"] = EncapsulateFunction(lapack_zpotrf_ffi);

  return dict;
}

NB_MODULE(_lapack, m) {
  // Populates the LAPACK kernels from scipy on first call.
  m.def("initialize", GetLapackKernelsFromScipy);

  m.def("registrations", &Registrations);

  // Old-style LAPACK Workspace Size Queries
  m.def("lapack_sgeqrf_workspace", &Geqrf<float>::Workspace, nb::arg("m"),
        nb::arg("n"));
  m.def("lapack_dgeqrf_workspace", &Geqrf<double>::Workspace, nb::arg("m"),
jaxlib/cpu/lapack.h (new file, 46 lines)
@@ -0,0 +1,46 @@
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAXLIB_CPU_LAPACK_H_
#define JAXLIB_CPU_LAPACK_H_

#include "jaxlib/cpu/lapack_kernels.h"
#include "xla/ffi/api/ffi.h"

namespace jax {

// FFI Definition Macros (by DataType)

#define JAX_CPU_DEFINE_POTRF(name, data_type)                \
  XLA_FFI_DEFINE_HANDLER(                                    \
      name, CholeskyFactorization<data_type>::Kernel,        \
      ::xla::ffi::Ffi::Bind()                                \
          .Arg<::xla::ffi::Buffer<data_type>>(/*x*/)         \
          .Attr<MatrixParams::UpLo>("uplo")                  \
          .Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/)     \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))

// FFI Handlers

JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_POTRF(lapack_zpotrf_ffi, ::xla::ffi::DataType::C128);

#undef JAX_CPU_DEFINE_POTRF

}  // namespace jax

#endif  // JAXLIB_CPU_LAPACK_H_
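For clarity, a sketch of what the first JAX_CPU_DEFINE_POTRF instantiation above expands to: the binding takes one input buffer, a char-enum "uplo" attribute, and returns the factored matrix plus the per-batch LAPACK info codes.

XLA_FFI_DEFINE_HANDLER(
    lapack_spotrf_ffi,
    CholeskyFactorization<::xla::ffi::DataType::F32>::Kernel,
    ::xla::ffi::Ffi::Bind()
        .Arg<::xla::ffi::Buffer<::xla::ffi::DataType::F32>>(/*x*/)
        .Attr<MatrixParams::UpLo>("uplo")
        .Ret<::xla::ffi::Buffer<::xla::ffi::DataType::F32>>(/*x_out*/)
        .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/));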
@@ -15,32 +15,98 @@ limitations under the License.

#include "jaxlib/cpu/lapack_kernels.h"

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>

#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/strings/str_format.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

static_assert(sizeof(jax::lapack_int) == sizeof(int32_t),
              "Expected LAPACK integers to be 32-bit");

namespace ffi = xla::ffi;

namespace {

inline int64_t catch_lapack_int_overflow(const std::string& source, int64_t value) {
  if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) {
template <typename T>
inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) {
  if constexpr (sizeof(T) == sizeof(int64_t)) {
    return value;
  } else {
    if (value > std::numeric_limits<jax::lapack_int>::max()) {
      throw std::overflow_error(source + "(=" + std::to_string(value) + ") exceeds maximum value of jax::lapack_int");
    if (value > std::numeric_limits<T>::max()) [[unlikely]] {
      throw std::overflow_error{
          absl::StrFormat("%s: Value (=%d) exceeds the maximum representable "
                          "value of the desired type",
                          source, value)};
    }
    return value;
    return static_cast<T>(value);
  }
}

template <typename T>
std::tuple<int64_t, int64_t, int64_t> SplitBatch2D(ffi::Span<T> dims) {
  if (dims.size() < 2) {
    throw std::invalid_argument("Matrix must have at least 2 dimensions");
  }
  auto matrix_dims = dims.last(2);
  return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1,
                                            std::multiplies<int64_t>()),
                         matrix_dims.front(), matrix_dims.back());
}

template <ffi::DataType dtype>
void CopyIfDiffBuffer(ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out) {
  auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
  if (x.data != x_out->data) {
    const auto x_size = batch_count * x_rows * x_cols;
    std::copy_n(x.data, x_size, x_out->data);
  }
}
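A small worked example of the two helpers above, under the batched layout the FFI kernel assumes (the shape values here are illustrative, not taken from the diff):

// dims = {5, 3, 4, 4}: a 5 x 3 batch of 4 x 4 matrices.
// SplitBatch2D(dims) -> (batch_count = 5 * 3 = 15, rows = 4, cols = 4).
// CopyIfDiffBuffer copies batch_count * rows * cols = 240 elements, and only
// when x.data != x_out->data, i.e. when XLA did not alias input and output.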
}  // namespace

#define REGISTER_CHAR_ENUM_ATTR_DECODING(type)                                \
  std::optional<type> xla::ffi::AttrDecoding<type>::Decode(                   \
      XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \
    if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] {                  \
      return diagnostic.Emit("Wrong attribute type: expected ")               \
             << XLA_FFI_AttrType_SCALAR << " but got" << attr_type;           \
    }                                                                         \
    auto* scalar = reinterpret_cast<XLA_FFI_Scalar*>(attr);                   \
    if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] {                  \
      return diagnostic.Emit("Wrong scalar data type: expected ")             \
             << XLA_FFI_DataType_U8 << " but got " << scalar->dtype;          \
    }                                                                         \
    auto underlying =                                                         \
        *reinterpret_cast<std::underlying_type_t<type>*>(scalar->value);      \
    return static_cast<type>(underlying);                                     \
  }

REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);
REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);

#undef REGISTER_CHAR_ENUM_ATTR_DECODING
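Since all four MatrixParams enums are char-backed (see lapack_kernels.h below), the u8 scalar decoded above maps directly onto the character LAPACK expects; a sketch for the "uplo" attribute used by the Cholesky kernel:

// The decoded byte is the LAPACK character argument.
static_assert(static_cast<char>(jax::MatrixParams::UpLo::kLower) == 'L');
static_assert(static_cast<char>(jax::MatrixParams::UpLo::kUpper) == 'U');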
namespace jax {

static_assert(sizeof(lapack_int) == sizeof(int32_t),
              "Expected LAPACK integers to be 32-bit");
//== Triangular System Solver ==//

// lapack trsm

template <typename T>
typename Trsm<T>::FnType* Trsm<T>::fn = nullptr;

@@ -92,7 +158,9 @@ template struct Trsm<double>;
template struct Trsm<std::complex<float>>;
template struct Trsm<std::complex<double>>;

// Getrf
//== LU Decomposition ==//

// lapack getrf

template <typename T>
typename Getrf<T>::FnType* Getrf<T>::fn = nullptr;

@@ -126,7 +194,9 @@ template struct Getrf<double>;
template struct Getrf<std::complex<float>>;
template struct Getrf<std::complex<double>>;

// Geqrf
//== QR Factorization ==//

// lapack geqrf

template <typename T>
typename Geqrf<T>::FnType* Geqrf<T>::fn = nullptr;

@@ -173,7 +243,10 @@ template struct Geqrf<double>;
template struct Geqrf<std::complex<float>>;
template struct Geqrf<std::complex<double>>;

// Orgqr
//== Orthogonal QR ==//
//== Computes orthogonal matrix Q from QR Decomposition ==//

// lapack orgqr

template <typename T>
typename Orgqr<T>::FnType* Orgqr<T>::fn = nullptr;

@@ -221,7 +294,9 @@ template struct Orgqr<double>;
template struct Orgqr<std::complex<float>>;
template struct Orgqr<std::complex<double>>;

// Potrf
//== Cholesky Factorization ==//

// lapack potrf

template <typename T>
typename Potrf<T>::FnType* Potrf<T>::fn = nullptr;

@@ -255,7 +330,40 @@ template struct Potrf<double>;
template struct Potrf<std::complex<float>>;
template struct Potrf<std::complex<double>>;

// Gesdd
// FFI Kernel

template <ffi::DataType dtype>
ffi::Error CholeskyFactorization<dtype>::Kernel(
    ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
    ffi::ResultBuffer<dtype> x_out, ffi::ResultBuffer<LapackIntDtype> info) {
  auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions);
  auto* x_out_data = x_out->data;
  auto* info_data = info->data;

  CopyIfDiffBuffer(x, x_out);

  auto uplo_v = static_cast<char>(uplo);
  auto x_order_v = CastNoOverflow<lapack_int>(x.dimensions.back());
  auto x_leading_dim_v = x_order_v;

  const int64_t x_out_step{x_rows * x_cols};
  for (int64_t i = 0; i < batch_count; ++i) {
    fn(&uplo_v, &x_order_v, x_out_data, &x_leading_dim_v, info_data);
    x_out_data += x_out_step;
    ++info_data;
  }
  return ffi::Error::Success();
}

template struct CholeskyFactorization<ffi::DataType::F32>;
template struct CholeskyFactorization<ffi::DataType::F64>;
template struct CholeskyFactorization<ffi::DataType::C64>;
template struct CholeskyFactorization<ffi::DataType::C128>;

//== Singular Value Decomposition (SVD) ==//
//== using a divide and conquer method ==//

// lapack gesdd

static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
  if (!job_opt_compute_uv) {

@@ -267,7 +375,7 @@ static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) {
}

lapack_int GesddIworkSize(int64_t m, int64_t n) {
  return catch_lapack_int_overflow("gesdd iwork", 8 * std::min(m, n));
  return CastNoOverflow<lapack_int>(8 * std::min(m, n), "gesdd iwork");
}

template <typename T>

@@ -333,11 +441,12 @@ int64_t RealGesdd<T>::Workspace(lapack_int m, lapack_int n,
lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) {
  int64_t mn = std::min(m, n);
  if (compute_uv == 0) {
    return catch_lapack_int_overflow("complex gesdd rwork", 7 * mn);
    return CastNoOverflow<lapack_int>(7 * mn, "complex gesdd rwork");
  }
  int64_t mx = std::max(m, n);
  return catch_lapack_int_overflow("complex gesdd rwork",
      std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn));
  return CastNoOverflow<lapack_int>(
      std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn),
      "complex gesdd rwork");
}

template <typename T>

@@ -408,13 +517,17 @@ template struct RealGesdd<double>;
template struct ComplexGesdd<std::complex<float>>;
template struct ComplexGesdd<std::complex<double>>;

//== Eigenvalues and eigenvectors ==//

// lapack syevd/heevd

// # Workspace sizes, taken from the LAPACK documentation.
lapack_int SyevdWorkSize(int64_t n) {
  return catch_lapack_int_overflow("syevd lwork", 1 + 6 * n + 2 * n * n);
  return CastNoOverflow<lapack_int>(1 + 6 * n + 2 * n * n, "syevd lwork");
}

lapack_int SyevdIworkSize(int64_t n) {
  return catch_lapack_int_overflow("syevd iwork", 3 + 5 * n);
  return CastNoOverflow<lapack_int>(3 + 5 * n, "syevd iwork");
}

template <typename T>

@@ -454,11 +567,11 @@ void RealSyevd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {

// Workspace sizes, taken from the LAPACK documentation.
lapack_int HeevdWorkSize(int64_t n) {
  return catch_lapack_int_overflow("heevd work", 1 + 2 * n + n * n);
  return CastNoOverflow<lapack_int>(1 + 2 * n + n * n, "heevd work");
}

lapack_int HeevdRworkSize(int64_t n) {
  return catch_lapack_int_overflow("heevd rwork", 1 + 5 * n + 2 * n * n);
  return CastNoOverflow<lapack_int>(1 + 5 * n + 2 * n * n, "heevd rwork");
}

template <typename T>

@@ -534,6 +647,8 @@ static void UnpackEigenvectors(int n, const T* im_eigenvalues, const T* packed,
  }
}

// lapack geev

template <typename T>
typename RealGeev<T>::FnType* RealGeev<T>::fn = nullptr;

@@ -679,7 +794,9 @@ template struct RealGeev<double>;
template struct ComplexGeev<std::complex<float>>;
template struct ComplexGeev<std::complex<double>>;

// Gees
//== Schur Decomposition ==//

// lapack gees

template <typename T>
typename RealGees<T>::FnType* RealGees<T>::fn = nullptr;

@@ -809,6 +926,10 @@ template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>;

//== Hessenberg Decomposition ==//

// lapack gehrd

template <typename T>
typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;

@@ -859,6 +980,10 @@ template struct Gehrd<double>;
template struct Gehrd<std::complex<float>>;
template struct Gehrd<std::complex<double>>;

//== Tridiagonal Reduction ==//

// lapack sytrd/hetrd

template <typename T>
typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;
@@ -16,19 +16,32 @@ limitations under the License.
#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_
#define JAXLIB_CPU_LAPACK_KERNELS_H_

#include <complex>
#include <cstdint>
#include <optional>
#include <type_traits>

#include "xla/ffi/api/ffi.h"
#include "xla/ffi/api/c_api.h"
#include "xla/service/custom_call_status.h"

// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
// Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either
// by the pybind wrapper that links them to an existing SciPy lapack instance,
// or using the lapack_kernels_strong.cc static initialization to link them
// directly to lapack for use in a pure C++ context.

namespace jax {

typedef int lapack_int;
struct MatrixParams {
  enum class Side : char { kLeft = 'L', kRight = 'R' };
  enum class UpLo : char { kLower = 'L', kUpper = 'U' };
  enum class Diag : char { kNonUnit = 'N', kUnit = 'U' };
  enum class Transpose : char {
    kNoTrans = 'N',
    kTrans = 'T',
    kConjTrans = 'C'
  };
};

template <typename KernelType>
void AssignKernelFn(void* func) {
  KernelType::fn = reinterpret_cast<typename KernelType::FnType*>(func);

@@ -39,6 +52,34 @@ void AssignKernelFn(typename KernelType::FnType* func) {
  KernelType::fn = func;
}

}  // namespace jax
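As the header comment above explains, each kernel's fn pointer is filled in from one of two places. A minimal sketch of both paths for the new Cholesky kernel; the helper function and the scipy_ptr argument are illustrative only, not part of this change:

#include "jaxlib/cpu/lapack_kernels.h"

// Fortran symbol, as declared by the pure-C++ path (the
// lapack_kernels_using_lapack target shown at the end of this diff).
extern "C" jax::CholeskyFactorization<::xla::ffi::DataType::F32>::FnType spotrf_;

// Hypothetical helper, for illustration only.
void InitCholeskyF32(void* scipy_ptr) {
  if (scipy_ptr != nullptr) {
    // Untyped path: lapack.cc passes pointers looked up from SciPy's LAPACK.
    jax::AssignKernelFn<jax::CholeskyFactorization<::xla::ffi::DataType::F32>>(
        scipy_ptr);
  } else {
    // Typed path: link against LAPACK directly and use the declared symbol.
    jax::AssignKernelFn<jax::CholeskyFactorization<::xla::ffi::DataType::F32>>(
        spotrf_);
  }
}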
#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR)                              \
  template <>                                                             \
  struct xla::ffi::AttrDecoding<ATTR> {                                   \
    using Type = ATTR;                                                    \
    static std::optional<Type> Decode(XLA_FFI_AttrType type, void* attr,  \
                                      DiagnosticEngine& diagnostic);      \
  }

// XLA needs attributes to have deserialization method specified
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose);
DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag);

#undef DEFINE_CHAR_ENUM_ATTR_DECODING

namespace jax {

using lapack_int = int;
inline constexpr auto LapackIntDtype = ::xla::ffi::DataType::S32;
static_assert(
    std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);

//== Triangular System Solver ==//

// lapack trsm

template <typename T>
struct Trsm {

@@ -50,6 +91,10 @@ struct Trsm {
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

//== LU Decomposition ==//

// lapack getrf

template <typename T>
struct Getrf {
  using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda,

@@ -59,6 +104,10 @@ struct Getrf {
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

//== QR Factorization ==//

// lapack geqrf

template <typename T>
struct Geqrf {
  using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda,

@@ -70,6 +119,10 @@ struct Geqrf {
  static int64_t Workspace(lapack_int m, lapack_int n);
};

//== Orthogonal QR ==//

// lapack orgqr

template <typename T>
struct Orgqr {
  using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a,

@@ -80,6 +133,10 @@ struct Orgqr {
  static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k);
};

//== Cholesky Factorization ==//

// lapack potrf

template <typename T>
struct Potrf {
  using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,

@@ -88,6 +145,24 @@ struct Potrf {
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

template <::xla::ffi::DataType dtype>
struct CholeskyFactorization {
  using ValueType = ::xla::ffi::NativeType<dtype>;
  using FnType = void(char* uplo, lapack_int* n, ValueType* a, lapack_int* lda,
                      lapack_int* info);

  inline static FnType* fn = nullptr;

  static ::xla::ffi::Error Kernel(
      ::xla::ffi::Buffer<dtype> x, MatrixParams::UpLo uplo,
      ::xla::ffi::ResultBuffer<dtype> x_out,
      ::xla::ffi::ResultBuffer<LapackIntDtype> info);
};

//== Singular Value Decomposition (SVD) ==//

// lapack gesdd

lapack_int GesddIworkSize(int64_t m, int64_t n);

template <typename T>

@@ -119,6 +194,10 @@ struct ComplexGesdd {
                        bool job_opt_full_matrices);
};

//== Eigenvalues and eigenvectors ==//

// lapack syevd/heevd

lapack_int SyevdWorkSize(int64_t n);
lapack_int SyevdIworkSize(int64_t n);

@@ -145,6 +224,8 @@ struct ComplexHeevd {
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

// lapack geev

template <typename T>
struct RealGeev {
  using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a,

@@ -165,6 +246,10 @@ struct ComplexGeev {
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

//== Schur Decomposition ==//

// lapack gees

template <typename T>
struct RealGees {
  using FnType = void(char* jobvs, char* sort, bool (*select)(T, T),

@@ -186,7 +271,11 @@ struct ComplexGees {
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};

// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form.
//== Hessenberg Decomposition ==//
//== Reduces a non-symmetric square matrix to upper Hessenberg form ==//

// lapack gehrd

template <typename T>
struct Gehrd {
  using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a,

@@ -209,14 +298,16 @@ struct real_type<std::complex<T>> {
  typedef T type;
};

// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal
// form.
//== Tridiagonal Reduction ==//
//== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==//

// lapack sytrd/hetrd

template <typename T>
struct Sytrd {
  using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
                      typename real_type<T>::type* d,
                      typename real_type<T>::type* e,
                      T* tau, T* work,
                      typename real_type<T>::type* e, T* tau, T* work,
                      lapack_int* lwork, lapack_int* info);

  static FnType* fn;
@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <complex>
#include <type_traits>

#include "jaxlib/cpu/lapack_kernels.h"

// From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but
// a C++ user should link against LAPACK directly. This is needed when using
// JAX-generated HLO from C++.

namespace ffi = xla::ffi;

extern "C" {

jax::Trsm<float>::FnType strsm_;

@@ -41,10 +46,10 @@ jax::Orgqr<double>::FnType dorgqr_;
jax::Orgqr<std::complex<float>>::FnType cungqr_;
jax::Orgqr<std::complex<double>>::FnType zungqr_;

jax::Potrf<float>::FnType spotrf_;
jax::Potrf<double>::FnType dpotrf_;
jax::Potrf<std::complex<float>>::FnType cpotrf_;
jax::Potrf<std::complex<double>>::FnType zpotrf_;
jax::CholeskyFactorization<ffi::DataType::F32>::FnType spotrf_;
jax::CholeskyFactorization<ffi::DataType::F64>::FnType dpotrf_;
jax::CholeskyFactorization<ffi::DataType::C64>::FnType cpotrf_;
jax::CholeskyFactorization<ffi::DataType::C128>::FnType zpotrf_;

jax::RealGesdd<float>::FnType sgesdd_;
jax::RealGesdd<double>::FnType dgesdd_;

@@ -80,6 +85,27 @@ jax::Sytrd<std::complex<double>>::FnType zhetrd_;

namespace jax {

#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch"

static_assert(
    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F32>::FnType,
                   jax::Potrf<float>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F64>::FnType,
                   jax::Potrf<double>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::C64>::FnType,
                   jax::Potrf<std::complex<float>>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
    std::is_same_v<jax::CholeskyFactorization<ffi::DataType::C128>::FnType,
                   jax::Potrf<std::complex<double>>::FnType>,
    JAX_KERNEL_FNTYPE_MISMATCH_MSG);

#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG

static auto init = []() -> int {
  AssignKernelFn<Trsm<float>>(strsm_);
  AssignKernelFn<Trsm<double>>(dtrsm_);

@@ -136,6 +162,13 @@ static auto init = []() -> int {
  AssignKernelFn<Sytrd<std::complex<float>>>(chetrd_);
  AssignKernelFn<Sytrd<std::complex<double>>>(zhetrd_);

  // FFI Kernels

  AssignKernelFn<CholeskyFactorization<ffi::DataType::F32>>(spotrf_);
  AssignKernelFn<CholeskyFactorization<ffi::DataType::F64>>(dpotrf_);
  AssignKernelFn<CholeskyFactorization<ffi::DataType::C64>>(cpotrf_);
  AssignKernelFn<CholeskyFactorization<ffi::DataType::C128>>(zpotrf_);

  return 0;
}();