mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Port Householder Product to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 651691430
This commit is contained in:
parent
ff18dedf99
commit
5cce394428
@ -54,8 +54,12 @@ def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
# Type stubs for the size-query helpers exported by the native module; every
# one returns an int. The *orgqr/*ungqr variants take (m, n, k), while the
# *geqrf variants take only (m, n).
def gesdd_iwork_size_ffi(m: int, n: int) -> int: ...
|
||||
def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_sgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ...
|
||||
|
@ -126,6 +126,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgeqrf_ffi);
|
||||
// One JAX_CPU_REGISTER_HANDLER invocation per FFI handler symbol (geqrf,
// orgqr/ungqr and potrf families shown here).
// NOTE(review): the macro is defined outside this view; presumably it
// registers the named handler as an XLA FFI target for the CPU platform --
// confirm against its definition.
JAX_CPU_REGISTER_HANDLER(lapack_dgeqrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cgeqrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zgeqrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_sorgqr_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dorgqr_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cungqr_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_zungqr_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_spotrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_dpotrf_ffi);
|
||||
JAX_CPU_REGISTER_HANDLER(lapack_cpotrf_ffi);
|
||||
|
@ -98,6 +98,10 @@ void GetLapackKernelsFromScipy() {
|
||||
AssignKernelFn<Orgqr<double>>(lapack_ptr("dorgqr"));
|
||||
AssignKernelFn<Orgqr<std::complex<float>>>(lapack_ptr("cungqr"));
|
||||
AssignKernelFn<Orgqr<std::complex<double>>>(lapack_ptr("zungqr"));
|
||||
AssignKernelFn<OrthogonalQr<DataType::F32>>(lapack_ptr("sorgqr"));
|
||||
AssignKernelFn<OrthogonalQr<DataType::F64>>(lapack_ptr("dorgqr"));
|
||||
AssignKernelFn<OrthogonalQr<DataType::C64>>(lapack_ptr("cungqr"));
|
||||
AssignKernelFn<OrthogonalQr<DataType::C128>>(lapack_ptr("zungqr"));
|
||||
|
||||
AssignKernelFn<Potrf<float>>(lapack_ptr("spotrf"));
|
||||
AssignKernelFn<Potrf<double>>(lapack_ptr("dpotrf"));
|
||||
@ -223,6 +227,10 @@ nb::dict Registrations() {
|
||||
dict["lapack_dgeqrf_ffi"] = EncapsulateFunction(lapack_dgeqrf_ffi);
|
||||
dict["lapack_cgeqrf_ffi"] = EncapsulateFunction(lapack_cgeqrf_ffi);
|
||||
dict["lapack_zgeqrf_ffi"] = EncapsulateFunction(lapack_zgeqrf_ffi);
|
||||
dict["lapack_sorgqr_ffi"] = EncapsulateFunction(lapack_sorgqr_ffi);
|
||||
dict["lapack_dorgqr_ffi"] = EncapsulateFunction(lapack_dorgqr_ffi);
|
||||
dict["lapack_cungqr_ffi"] = EncapsulateFunction(lapack_cungqr_ffi);
|
||||
dict["lapack_zungqr_ffi"] = EncapsulateFunction(lapack_zungqr_ffi);
|
||||
dict["lapack_spotrf_ffi"] = EncapsulateFunction(lapack_spotrf_ffi);
|
||||
dict["lapack_dpotrf_ffi"] = EncapsulateFunction(lapack_dpotrf_ffi);
|
||||
dict["lapack_cpotrf_ffi"] = EncapsulateFunction(lapack_cpotrf_ffi);
|
||||
@ -314,6 +322,18 @@ NB_MODULE(_lapack, m) {
|
||||
m.def("lapack_zgeqrf_workspace_ffi",
|
||||
&QrFactorization<DataType::C128>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_sorgqr_workspace_ffi",
|
||||
&OrthogonalQr<DataType::F32>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
m.def("lapack_dorgqr_workspace_ffi",
|
||||
&OrthogonalQr<DataType::F64>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
m.def("lapack_cungqr_workspace_ffi",
|
||||
&OrthogonalQr<DataType::C64>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
m.def("lapack_zungqr_workspace_ffi",
|
||||
&OrthogonalQr<DataType::C128>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
m.def("gesdd_iwork_size_ffi", &svd::GetIntWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("sgesdd_work_size_ffi", &svd::SVDType<DataType::F32>::GetWorkspaceSize,
|
||||
|
@ -42,6 +42,16 @@ namespace jax {
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/))
|
||||
|
||||
// Defines the XLA FFI handler `handler_name`, backed by
// OrthogonalQr<dtype>::Kernel. The binding order mirrors that kernel's
// parameter list: two input buffers (x, tau) followed by three result
// buffers (x_out, info, work).
#define JAX_CPU_DEFINE_ORGQR(handler_name, dtype)                   \
  XLA_FFI_DEFINE_HANDLER(                                           \
      handler_name, OrthogonalQr<dtype>::Kernel,                    \
      ::xla::ffi::Ffi::Bind()                                       \
          .Arg<::xla::ffi::Buffer<dtype>>(/*x*/)                    \
          .Arg<::xla::ffi::Buffer<dtype>>(/*tau*/)                  \
          .Ret<::xla::ffi::Buffer<dtype>>(/*x_out*/)                \
          .Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/)        \
          .Ret<::xla::ffi::Buffer<dtype>>(/*work*/))
|
||||
|
||||
#define JAX_CPU_DEFINE_POTRF(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER( \
|
||||
name, CholeskyFactorization<data_type>::Kernel, \
|
||||
@ -92,6 +102,11 @@ JAX_CPU_DEFINE_GEQRF(lapack_dgeqrf_ffi, ::xla::ffi::DataType::F64);
|
||||
// Handler definitions produced by the JAX_CPU_DEFINE_* helper macros. Each
// LAPACK precision prefix is paired with its FFI data type:
// s -> F32, d -> F64, c -> C64, z -> C128.
JAX_CPU_DEFINE_GEQRF(lapack_cgeqrf_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_GEQRF(lapack_zgeqrf_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_ORGQR(lapack_sorgqr_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_ORGQR(lapack_dorgqr_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_ORGQR(lapack_cungqr_ffi, ::xla::ffi::DataType::C64);
|
||||
JAX_CPU_DEFINE_ORGQR(lapack_zungqr_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
JAX_CPU_DEFINE_POTRF(lapack_spotrf_ffi, ::xla::ffi::DataType::F32);
|
||||
JAX_CPU_DEFINE_POTRF(lapack_dpotrf_ffi, ::xla::ffi::DataType::F64);
|
||||
JAX_CPU_DEFINE_POTRF(lapack_cpotrf_ffi, ::xla::ffi::DataType::C64);
|
||||
@ -104,6 +119,7 @@ JAX_CPU_DEFINE_GESDD_COMPLEX(lapack_zgesdd_ffi, ::xla::ffi::DataType::C128);
|
||||
|
||||
// Undefine the handler-definition helper macros once they have been used.
#undef JAX_CPU_DEFINE_GETRF
|
||||
#undef JAX_CPU_DEFINE_GEQRF
|
||||
#undef JAX_CPU_DEFINE_ORGQR
|
||||
#undef JAX_CPU_DEFINE_POTRF
|
||||
#undef JAX_CPU_DEFINE_GESDD
|
||||
#undef JAX_CPU_DEFINE_GESDD_COMPLEX
|
||||
|
@ -381,6 +381,60 @@ template struct Orgqr<double>;
|
||||
template struct Orgqr<std::complex<float>>;
|
||||
template struct Orgqr<std::complex<double>>;
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error OrthogonalQr<dtype>::Kernel(ffi::Buffer<dtype> x,
|
||||
ffi::Buffer<dtype> tau,
|
||||
ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<LapackIntDtype> info,
|
||||
ffi::ResultBuffer<dtype> work) {
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
auto* tau_data = tau.typed_data();
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
auto* work_data = work->typed_data();
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto tau_size_v, MaybeCastNoOverflow<lapack_int>(
|
||||
tau.dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
work->dimensions().back()));
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
const int64_t tau_step{tau_size_v};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&x_rows_v, &x_cols_v, &tau_size_v, x_out_data, &x_leading_dim_v,
|
||||
tau_data, work_data, &workspace_dim_v, info_data);
|
||||
x_out_data += x_out_step;
|
||||
tau_data += tau_step;
|
||||
++info_data;
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
template <ffi::DataType dtype>
|
||||
int64_t OrthogonalQr<dtype>::GetWorkspaceSize(lapack_int x_rows,
|
||||
lapack_int x_cols,
|
||||
lapack_int tau_size) {
|
||||
ValueType optimal_size = {};
|
||||
lapack_int x_leading_dim_v = x_rows;
|
||||
lapack_int info = 0;
|
||||
lapack_int workspace_query = -1;
|
||||
fn(&x_rows, &x_cols, &tau_size, nullptr, &x_leading_dim_v, nullptr,
|
||||
&optimal_size, &workspace_query, &info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(optimal_size)) : -1;
|
||||
}
|
||||
|
||||
// Explicit instantiations of OrthogonalQr for the four FFI data types
// (F32, F64, C64, C128).
template struct OrthogonalQr<ffi::DataType::F32>;
|
||||
template struct OrthogonalQr<ffi::DataType::F64>;
|
||||
template struct OrthogonalQr<ffi::DataType::C64>;
|
||||
template struct OrthogonalQr<ffi::DataType::C128>;
|
||||
|
||||
//== Cholesky Factorization ==//
|
||||
|
||||
// lapack potrf
|
||||
|
@ -186,6 +186,27 @@ struct Orgqr {
|
||||
static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k);
|
||||
};
|
||||
|
||||
// FFI Kernel
|
||||
|
||||
template <::xla::ffi::DataType dtype>
|
||||
struct OrthogonalQr {
|
||||
using ValueType = ::xla::ffi::NativeType<dtype>;
|
||||
using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, ValueType* a,
|
||||
lapack_int* lda, ValueType* tau, ValueType* work,
|
||||
lapack_int* lwork, lapack_int* info);
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(::xla::ffi::Buffer<dtype> x,
|
||||
::xla::ffi::Buffer<dtype> tau,
|
||||
::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> work);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols,
|
||||
lapack_int tau_size);
|
||||
};
|
||||
|
||||
//== Cholesky Factorization ==//
|
||||
|
||||
// lapack potrf
|
||||
|
@ -41,10 +41,10 @@ jax::QrFactorization<ffi::DataType::F64>::FnType dgeqrf_;
|
||||
// Declarations of the LAPACK routines by symbol name, each typed with the
// FnType of the kernel struct that will call it.
// NOTE(review): sorgqr_/dorgqr_/cungqr_/zungqr_ are listed twice below, once
// with the legacy jax::Orgqr<T>::FnType and once with
// jax::OrthogonalQr<dtype>::FnType. This looks like the removed and added
// lines of a diff shown together -- confirm against the real file.
jax::QrFactorization<ffi::DataType::C64>::FnType cgeqrf_;
|
||||
jax::QrFactorization<ffi::DataType::C128>::FnType zgeqrf_;
|
||||
|
||||
jax::Orgqr<float>::FnType sorgqr_;
|
||||
jax::Orgqr<double>::FnType dorgqr_;
|
||||
jax::Orgqr<std::complex<float>>::FnType cungqr_;
|
||||
jax::Orgqr<std::complex<double>>::FnType zungqr_;
|
||||
jax::OrthogonalQr<ffi::DataType::F32>::FnType sorgqr_;
|
||||
jax::OrthogonalQr<ffi::DataType::F64>::FnType dorgqr_;
|
||||
jax::OrthogonalQr<ffi::DataType::C64>::FnType cungqr_;
|
||||
jax::OrthogonalQr<ffi::DataType::C128>::FnType zungqr_;
|
||||
|
||||
jax::CholeskyFactorization<ffi::DataType::F32>::FnType spotrf_;
|
||||
jax::CholeskyFactorization<ffi::DataType::F64>::FnType dpotrf_;
|
||||
@ -111,6 +111,18 @@ static_assert(std::is_same_v<jax::QrFactorization<ffi::DataType::C64>::FnType,
|
||||
// Compile-time checks that each FFI kernel struct declares exactly the same
// FnType as the legacy kernel struct for the same LAPACK routine.
static_assert(std::is_same_v<jax::QrFactorization<ffi::DataType::C128>::FnType,
|
||||
jax::Geqrf<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::OrthogonalQr<ffi::DataType::F32>::FnType,
|
||||
jax::Orgqr<float>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::OrthogonalQr<ffi::DataType::F64>::FnType,
|
||||
jax::Orgqr<double>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::OrthogonalQr<ffi::DataType::C64>::FnType,
|
||||
jax::Orgqr<std::complex<float>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(std::is_same_v<jax::OrthogonalQr<ffi::DataType::C128>::FnType,
|
||||
jax::Orgqr<std::complex<double>>::FnType>,
|
||||
JAX_KERNEL_FNTYPE_MISMATCH_MSG);
|
||||
static_assert(
|
||||
std::is_same_v<jax::CholeskyFactorization<ffi::DataType::F32>::FnType,
|
||||
jax::Potrf<float>::FnType>,
|
||||
@ -216,6 +228,11 @@ static auto init = []() -> int {
|
||||
AssignKernelFn<QrFactorization<ffi::DataType::C64>>(cgeqrf_);
|
||||
AssignKernelFn<QrFactorization<ffi::DataType::C128>>(zgeqrf_);
|
||||
|
||||
AssignKernelFn<OrthogonalQr<ffi::DataType::F32>>(sorgqr_);
|
||||
AssignKernelFn<OrthogonalQr<ffi::DataType::F64>>(dorgqr_);
|
||||
AssignKernelFn<OrthogonalQr<ffi::DataType::C64>>(cungqr_);
|
||||
AssignKernelFn<OrthogonalQr<ffi::DataType::C128>>(zungqr_);
|
||||
|
||||
AssignKernelFn<CholeskyFactorization<ffi::DataType::F32>>(spotrf_);
|
||||
AssignKernelFn<CholeskyFactorization<ffi::DataType::F64>>(dpotrf_);
|
||||
AssignKernelFn<CholeskyFactorization<ffi::DataType::C64>>(cpotrf_);
|
||||
|
Loading…
x
Reference in New Issue
Block a user