Determine LAPACK workspace during QR Factorization Kernel runtime

PiperOrigin-RevId: 663641199
This commit is contained in:
Paweł Paruzel 2024-08-16 01:20:02 -07:00 committed by jax authors
parent 9785368c7f
commit acacf8884e
4 changed files with 19 additions and 37 deletions

View File

@@ -51,13 +51,9 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr
# FFI Kernel LAPACK Workspace Size Queries
def heevd_rwork_size_ffi(n: int) -> int: ...
def heevd_work_size_ffi(n: int) -> int: ...
def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_sgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ...
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
def syevd_iwork_size_ffi(n: int) -> int: ...
def syevd_work_size_ffi(n: int) -> int: ...

View File

@@ -336,18 +336,6 @@ NB_MODULE(_lapack, m) {
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace,
nb::arg("lda"), nb::arg("n"));
// FFI Kernel LAPACK Workspace Size Queries
m.def("lapack_sgeqrf_workspace_ffi",
&QrFactorization<DataType::F32>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("lapack_dgeqrf_workspace_ffi",
&QrFactorization<DataType::F64>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("lapack_cgeqrf_workspace_ffi",
&QrFactorization<DataType::C64>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("lapack_zgeqrf_workspace_ffi",
&QrFactorization<DataType::C128>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"));
m.def("lapack_sorgqr_workspace_ffi",
&OrthogonalQr<DataType::F32>::GetWorkspaceSize, nb::arg("m"),
nb::arg("n"), nb::arg("k"));

View File

@@ -309,17 +309,17 @@ template struct Geqrf<std::complex<double>>;
template <ffi::DataType dtype>
ffi::Error QrFactorization<dtype>::Kernel(
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<dtype> tau, ffi::ResultBuffer<LapackIntDtype> info,
ffi::ResultBuffer<dtype> work) {
ffi::ResultBuffer<dtype> tau, ffi::ResultBuffer<LapackIntDtype> info) {
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
auto* x_out_data = x_out->typed_data();
auto* tau_data = tau->typed_data();
auto* info_data = info->typed_data();
auto* work_data = work->typed_data();
const int64_t work_size = GetWorkspaceSize(x_rows, x_cols);
auto work_data = AllocateScratchMemory<dtype>(work_size);
CopyIfDiffBuffer(x, x_out);
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
work->dimensions().back()));
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v,
MaybeCastNoOverflow<lapack_int>(work_size));
FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
auto x_leading_dim_v = x_rows_v;
@@ -327,8 +327,8 @@ ffi::Error QrFactorization<dtype>::Kernel(
const int64_t x_out_step{x_rows * x_cols};
const int64_t tau_step{std::min(x_rows, x_cols)};
for (int64_t i = 0; i < batch_count; ++i) {
fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, work_data,
&workspace_dim_v, info_data);
fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data,
work_data.get(), &workspace_dim_v, info_data);
x_out_data += x_out_step;
tau_data += tau_step;
++info_data;
@@ -1701,15 +1701,14 @@ template struct Sytrd<std::complex<double>>;
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*ipiv*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
#define JAX_CPU_DEFINE_GEQRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, QrFactorization<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/))
#define JAX_CPU_DEFINE_GEQRF(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
name, QrFactorization<data_type>::Kernel, \
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
#define JAX_CPU_DEFINE_ORGQR(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \

View File

@@ -192,11 +192,10 @@ struct QrFactorization {
inline static FnType* fn = nullptr;
static ::xla::ffi::Error Kernel(::xla::ffi::Buffer<dtype> x,
::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> tau,
::xla::ffi::ResultBuffer<LapackIntDtype> info,
::xla::ffi::ResultBuffer<dtype> work);
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> tau,
::xla::ffi::ResultBuffer<LapackIntDtype> info);
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols);
};