Port QR Factorization to XLA's FFI

PiperOrigin-RevId: 651396166
This commit is contained in:
Paweł Paruzel 2024-07-11 07:03:31 -07:00 committed by jax authors
parent b2a3edb9d5
commit 86ab50d92f
7 changed files with 134 additions and 4 deletions

View File

@ -53,5 +53,9 @@ def cgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def gesdd_iwork_size_ffi(m: int, n: int) -> int: ...
def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
# Workspace-size queries for the FFI geqrf (QR factorization) kernels, one per
# LAPACK type prefix (c, d, s, z). Each takes the matrix dimensions `m` and `n`
# and returns an integer size; the native implementation returns -1 when the
# underlying query reports a non-zero status.
def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_sgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...

View File

@ -122,6 +122,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgetrf_ffi);
// QR factorization (geqrf) FFI handlers, one per LAPACK type prefix
// (s, d, c, z).
JAX_CPU_REGISTER_HANDLER(lapack_sgeqrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgeqrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgeqrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgeqrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi);

View File

@ -89,6 +89,10 @@ void GetLapackKernelsFromScipy() {
AssignKernelFn<Geqrf<double>>(lapack_ptr("dgeqrf"));
AssignKernelFn<Geqrf<std::complex<float>>>(lapack_ptr("cgeqrf"));
AssignKernelFn<Geqrf<std::complex<double>>>(lapack_ptr("zgeqrf"));
AssignKernelFn<QrFactorization<DataType::F32>>(lapack_ptr("sgeqrf"));
AssignKernelFn<QrFactorization<DataType::F64>>(lapack_ptr("dgeqrf"));
AssignKernelFn<QrFactorization<DataType::C64>>(lapack_ptr("cgeqrf"));
AssignKernelFn<QrFactorization<DataType::C128>>(lapack_ptr("zgeqrf"));
AssignKernelFn<Orgqr<float>>(lapack_ptr("sorgqr"));
AssignKernelFn<Orgqr<double>>(lapack_ptr("dorgqr"));
@ -215,6 +219,10 @@ nb::dict Registrations() {
dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi);
dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi);
dict["lapack_zgetrf_ffi"] = EncapsulateFunction(lapack_zgetrf_ffi);
dict["lapack_sgeqrf_ffi"] = EncapsulateFunction(lapack_sgeqrf_ffi);
dict["lapack_dgeqrf_ffi"] = EncapsulateFunction(lapack_dgeqrf_ffi);
dict["lapack_cgeqrf_ffi"] = EncapsulateFunction(lapack_cgeqrf_ffi);
dict["lapack_zgeqrf_ffi"] = EncapsulateFunction(lapack_zgeqrf_ffi);
dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi);
dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi);
dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi);
@ -294,6 +302,18 @@ NB_MODULE(_lapack, m) {
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace,
nb::arg("lda"), nb::arg("n"));
// FFI Kernel LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace_ffi",
&QrFactorization<DataType::F32>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("lapack_dgeqrf_workspace_ffi",
&QrFactorization<DataType::F64>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("lapack_cgeqrf_workspace_ffi",
&QrFactorization<DataType::C64>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("lapack_zgeqrf_workspace_ffi",
&QrFactorization<DataType::C128>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("gesdd_iwork_size_ffi", &svd::GetIntWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("sgesdd_work_size_ffi", &svd::SVDType<DataType::F32>::GetWorkspaceSize,

View File

@ -32,6 +32,16 @@ namespace jax {
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*ipiv*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
// Defines the XLA FFI handler `name`, backed by
// QrFactorization<data_type>::Kernel. The binding takes one argument buffer
// (x) and four result buffers in this order: x_out, tau, info, work. The
// info buffer uses LapackIntDtype; every other buffer uses `data_type`.
#define JAX_CPU_DEFINE_GEQRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER( \
name, QrFactorization<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/))
#define JAX_CPU_DEFINE_POTRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER( \
name, CholeskyFactorization<data_type>::Kernel, \
@ -77,6 +87,11 @@ JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GETRF(lapack_cgetrf_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GETRF(lapack_zgetrf_ffi, ::xla::ffi::DataType::C128);
// geqrf (QR factorization) handlers for F32, F64, C64 and C128.
JAX_CPU_DEFINE_GEQRF(lapack_sgeqrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GEQRF(lapack_dgeqrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_GEQRF(lapack_cgeqrf_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GEQRF(lapack_zgeqrf_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64);
@ -88,6 +103,7 @@ JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_cgesdd_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
#undef JAX_CPU_DEFINE_GETRF
#undef JAX_CPU_DEFINE_GEQRF
#undef JAX_CPU_DEFINE_POTRF
#undef JAX_CPU_DEFINE_GESDD
#undef JAX_CPU_DEFINE_GESDD_COMPLEX

View File

@ -281,6 +281,55 @@ template struct Geqrf<double>;
template struct Geqrf<std::complex<float>>;
template struct Geqrf<std::complex<double>>;
// FFI Kernel
// Runs the geqrf routine stored in `fn` on every matrix of a batched input.
//
// `x` is split into (batch, rows, cols) by SplitBatch2D. Its contents are
// copied into `x_out` when the two buffers differ, and the factorization is
// then performed in place on `x_out`. For each batch element the routine
// writes min(rows, cols) values into `tau` and one status value into `info`.
// The single `work` buffer is shared by all batch elements; its last
// dimension is passed to the routine as the workspace length.
//
// Returns an error if the workspace length, the row count or the column
// count does not fit in `lapack_int`; otherwise returns success. Per-matrix
// status is only reported through `info`.
template <ffi::DataType dtype>
ffi::Error QrFactorization<dtype>::Kernel(
    ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
    ffi::ResultBuffer<dtype> tau, ffi::ResultBuffer<LapackIntDtype> info,
    ffi::ResultBuffer<dtype> work) {
  auto [batches, rows, cols] = SplitBatch2D(x.dimensions());
  auto* const a_base = x_out->typed_data();
  auto* const tau_base = tau->typed_data();
  auto* const info_base = info->typed_data();
  auto* const scratch = work->typed_data();

  // The routine factorizes in place, so the input has to be in x_out first.
  CopyIfDiffBuffer(x, x_out);

  // Narrow the 64-bit sizes to the integer width the routine expects.
  FFI_ASSIGN_OR_RETURN(
      auto lwork, MaybeCastNoOverflow<lapack_int>(work->dimensions().back()));
  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<lapack_int>(rows));
  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<lapack_int>(cols));
  auto lda = m;

  // Element counts separating consecutive batch entries in each output.
  const int64_t matrix_stride = rows * cols;
  const int64_t tau_stride = std::min(rows, cols);

  for (int64_t batch = 0; batch < batches; ++batch) {
    fn(&m, &n, a_base + batch * matrix_stride, &lda,
       tau_base + batch * tau_stride, scratch, &lwork, info_base + batch);
  }
  return ffi::Error::Success();
}
// Asks the geqrf routine stored in `fn` how much workspace it wants for an
// `x_rows` by `x_cols` matrix.
//
// The routine is called with a workspace length of -1 and null matrix/tau
// pointers, so it only writes the requested size into a single ValueType
// slot. Returns the real part of that value as int64_t, or -1 if the routine
// reports a non-zero status.
template <ffi::DataType dtype>
int64_t QrFactorization<dtype>::GetWorkspaceSize(lapack_int x_rows,
                                                 lapack_int x_cols) {
  lapack_int lwork_query = -1;  // -1 selects the size-query mode.
  lapack_int lda = x_rows;
  lapack_int status = 0;
  ValueType reported_size{};

  fn(&x_rows, &x_cols, nullptr, &lda, nullptr, &reported_size, &lwork_query,
     &status);

  if (status != 0) {
    return -1;
  }
  return static_cast<int64_t>(std::real(reported_size));
}
// Explicit instantiations for the four supported element types.
template struct QrFactorization<ffi::DataType::F32>;
template struct QrFactorization<ffi::DataType::F64>;
template struct QrFactorization<ffi::DataType::C64>;
template struct QrFactorization<ffi::DataType::C128>;
//== Orthogonal QR ==//
//== Computes orthogonal matrix Q from QR Decomposition ==//

View File

@ -152,6 +152,26 @@ struct Geqrf {
static int64_t Workspace(lapack_int m, lapack_int n);
};
// FFI Kernel
template <::xla::ffi::DataType dtype>
struct QrFactorization {
using ValueType = ::xla::ffi::NativeType<dtype>;
using FnType = void(lapack_int* m, lapack_int* n, ValueType* a,
lapack_int* lda, ValueType* tau, ValueType* work,
lapack_int* lwork, lapack_int* info);
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(::xla::ffi::Buffer<dtype> x,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> tau,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> work);
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols);
};
//== Orthogonal QR ==//
// lapack orgqr

View File

@ -36,10 +36,10 @@ jax::LuDecomposition<ffi::DataType::F64>::FnType dgetrf_;
jax::LuDecomposition<ffi::DataType::C64>::FnType cgetrf_;
jax::LuDecomposition<ffi::DataType::C128>::FnType zgetrf_;
jax::Geqrf<float>::FnType sgeqrf_;
jax::Geqrf<double>::FnType dgeqrf_;
jax::Geqrf<std::complex<float>>::FnType cgeqrf_;
jax::Geqrf<std::complex<double>>::FnType zgeqrf_;
// geqrf entry points, declared with the FFI QrFactorization function type.
jax::QrFactorization<ffi::DataType::F32>::FnType sgeqrf_;
jax::QrFactorization<ffi::DataType::F64>::FnType dgeqrf_;
jax::QrFactorization<ffi::DataType::C64>::FnType cgeqrf_;
jax::QrFactorization<ffi::DataType::C128>::FnType zgeqrf_;
jax::Orgqr<float>::FnType sorgqr_;
jax::Orgqr<double>::FnType dorgqr_;
@ -99,6 +99,18 @@ static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::C64>::FnType,
static_assert(std::is_same_v<jax::LuDecomposition<ffi::DataType::C128>::FnType,
jax::Getrf<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
// The FFI QrFactorization kernels must use exactly the same function type as
// the corresponding Geqrf kernels, for each of the four element types.
static_assert(std::is_same_v<jax::QrFactorization<ffi::DataType::F32>::FnType,
jax::Geqrf<float>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(std::is_same_v<jax::QrFactorization<ffi::DataType::F64>::FnType,
jax::Geqrf<double>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(std::is_same_v<jax::QrFactorization<ffi::DataType::C64>::FnType,
jax::Geqrf<std::complex<float>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(std::is_same_v<jax::QrFactorization<ffi::DataType::C128>::FnType,
jax::Geqrf<std::complex<double>>::FnType>,
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
static_assert(
std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F32>::FnType,
jax::Potrf<float>::FnType>,
@ -199,6 +211,11 @@ static auto init = []() -> int {
AssignKernelFn<LuDecomposition<ffi::DataType::C64>>(cgetrf_);
AssignKernelFn<LuDecomposition<ffi::DataType::C128>>(zgetrf_);
AssignKernelFn<QrFactorization<ffi::DataType::F32>>(sgeqrf_);
AssignKernelFn<QrFactorization<ffi::DataType::F64>>(dgeqrf_);
AssignKernelFn<QrFactorization<ffi::DataType::C64>>(cgeqrf_);
AssignKernelFn<QrFactorization<ffi::DataType::C128>>(zgeqrf_);
AssignKernelFn<CholeskyFactorization<ffi::DataType::F32>>(spotrf_);
AssignKernelFn<CholeskyFactorization<ffi::DataType::F64>>(dpotrf_);
AssignKernelFn<CholeskyFactorization<ffi::DataType::C64>>(cpotrf_);