From a72d46c54963bf967dab240be07f55437d6ff93f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Paruzel?= Date: Thu, 22 Aug 2024 01:37:59 -0700 Subject: [PATCH] Ignore LAPACK info parameter for QR Factorization The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call. PiperOrigin-RevId: 666241215 --- jaxlib/cpu/lapack_kernels.cc | 14 ++++++-------- jaxlib/cpu/lapack_kernels.h | 3 +-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 9765b227d..c1475d1f2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -306,14 +306,14 @@ template struct Geqrf>; // FFI Kernel template -ffi::Error QrFactorization::Kernel( - ffi::Buffer x, ffi::ResultBuffer x_out, - ffi::ResultBuffer tau, ffi::ResultBuffer info) { +ffi::Error QrFactorization::Kernel(ffi::Buffer x, + ffi::ResultBuffer x_out, + ffi::ResultBuffer tau) { FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* tau_data = tau->typed_data(); - auto* info_data = info->typed_data(); + lapack_int info; const int64_t work_size = GetWorkspaceSize(x_rows, x_cols); auto work_data = AllocateScratchMemory(work_size); @@ -328,10 +328,9 @@ ffi::Error QrFactorization::Kernel( const int64_t tau_step{std::min(x_rows, x_cols)}; for (int64_t i = 0; i < batch_count; ++i) { fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, - work_data.get(), &workspace_dim_v, info_data); + work_data.get(), &workspace_dim_v, &info); x_out_data += x_out_step; tau_data += tau_step; - ++info_data; } return ffi::Error::Success(); } @@ -1713,8 +1712,7 @@ template struct Sytrd>; ::xla::ffi::Ffi::Bind() \ .Arg<::xla::ffi::Buffer>(/*x*/) \ .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/)) + .Ret<::xla::ffi::Buffer>(/*tau*/)) #define JAX_CPU_DEFINE_ORGQR(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 8abf8e22d..20823e785 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -194,8 +194,7 @@ struct QrFactorization { static ::xla::ffi::Error Kernel( ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer tau, - ::xla::ffi::ResultBuffer info); + ::xla::ffi::ResultBuffer tau); static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); };