mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Determine LAPACK workspace during QR Factorization Kernel runtime
PiperOrigin-RevId: 663641199
This commit is contained in:
parent
9785368c7f
commit
acacf8884e
@ -51,13 +51,9 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr
|
||||
# FFI Kernel LAPACK Workspace Size Queries
|
||||
def heevd_rwork_size_ffi(n: int) -> int: ...
|
||||
def heevd_work_size_ffi(n: int) -> int: ...
|
||||
def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_sgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ...
|
||||
def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ...
|
||||
def syevd_iwork_size_ffi(n: int) -> int: ...
|
||||
def syevd_work_size_ffi(n: int) -> int: ...
|
||||
|
@ -336,18 +336,6 @@ NB_MODULE(_lapack, m) {
|
||||
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace,
|
||||
nb::arg("lda"), nb::arg("n"));
|
||||
// FFI Kernel LAPACK Workspace Size Queries
|
||||
m.def("lapack_sgeqrf_workspace_ffi",
|
||||
&QrFactorization<DataType::F32>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_dgeqrf_workspace_ffi",
|
||||
&QrFactorization<DataType::F64>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_cgeqrf_workspace_ffi",
|
||||
&QrFactorization<DataType::C64>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_zgeqrf_workspace_ffi",
|
||||
&QrFactorization<DataType::C128>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"));
|
||||
m.def("lapack_sorgqr_workspace_ffi",
|
||||
&OrthogonalQr<DataType::F32>::GetWorkspaceSize, nb::arg("m"),
|
||||
nb::arg("n"), nb::arg("k"));
|
||||
|
@ -309,17 +309,17 @@ template struct Geqrf<std::complex<double>>;
|
||||
template <ffi::DataType dtype>
|
||||
ffi::Error QrFactorization<dtype>::Kernel(
|
||||
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
|
||||
ffi::ResultBuffer<dtype> tau, ffi::ResultBuffer<LapackIntDtype> info,
|
||||
ffi::ResultBuffer<dtype> work) {
|
||||
ffi::ResultBuffer<dtype> tau, ffi::ResultBuffer<LapackIntDtype> info) {
|
||||
auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions());
|
||||
auto* x_out_data = x_out->typed_data();
|
||||
auto* tau_data = tau->typed_data();
|
||||
auto* info_data = info->typed_data();
|
||||
auto* work_data = work->typed_data();
|
||||
const int64_t work_size = GetWorkspaceSize(x_rows, x_cols);
|
||||
auto work_data = AllocateScratchMemory<dtype>(work_size);
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow<lapack_int>(
|
||||
work->dimensions().back()));
|
||||
FFI_ASSIGN_OR_RETURN(auto workspace_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work_size));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
@ -327,8 +327,8 @@ ffi::Error QrFactorization<dtype>::Kernel(
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
const int64_t tau_step{std::min(x_rows, x_cols)};
|
||||
for (int64_t i = 0; i < batch_count; ++i) {
|
||||
fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, work_data,
|
||||
&workspace_dim_v, info_data);
|
||||
fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data,
|
||||
work_data.get(), &workspace_dim_v, info_data);
|
||||
x_out_data += x_out_step;
|
||||
tau_data += tau_step;
|
||||
++info_data;
|
||||
@ -1701,15 +1701,14 @@ template struct Sytrd<std::complex<double>>;
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*ipiv*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
#define JAX_CPU_DEFINE_GEQRF(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, QrFactorization<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*work*/))
|
||||
#define JAX_CPU_DEFINE_GEQRF(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
name, QrFactorization<data_type>::Kernel, \
|
||||
::xla::ffi::Ffi::Bind() \
|
||||
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
|
||||
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
|
||||
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
|
||||
|
||||
#define JAX_CPU_DEFINE_ORGQR(name, data_type) \
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL( \
|
||||
|
@ -192,11 +192,10 @@ struct QrFactorization {
|
||||
|
||||
inline static FnType* fn = nullptr;
|
||||
|
||||
static ::xla::ffi::Error Kernel(::xla::ffi::Buffer<dtype> x,
|
||||
::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<dtype> tau,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info,
|
||||
::xla::ffi::ResultBuffer<dtype> work);
|
||||
static ::xla::ffi::Error Kernel(
|
||||
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
|
||||
::xla::ffi::ResultBuffer<dtype> tau,
|
||||
::xla::ffi::ResultBuffer<LapackIntDtype> info);
|
||||
|
||||
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols);
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user