Ignore LAPACK info parameter for QR Factorization

The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call.

PiperOrigin-RevId: 666241215
This commit is contained in:
Paweł Paruzel 2024-08-22 01:37:59 -07:00 committed by jax authors
parent 3713b966c2
commit a72d46c549
2 changed files with 7 additions and 10 deletions

View File

@ -306,14 +306,14 @@ template struct Geqrf<std::complex<double>>;
// FFI Kernel
template <ffi::DataType dtype>
ffi::Error QrFactorization<dtype>::Kernel(
ffi::Buffer<dtype> x, ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<dtype> tau, ffi::ResultBuffer<LapackIntDtype> info) {
ffi::Error QrFactorization<dtype>::Kernel(ffi::Buffer<dtype> x,
ffi::ResultBuffer<dtype> x_out,
ffi::ResultBuffer<dtype> tau) {
FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]),
SplitBatch2D(x.dimensions()));
auto* x_out_data = x_out->typed_data();
auto* tau_data = tau->typed_data();
auto* info_data = info->typed_data();
lapack_int info;
const int64_t work_size = GetWorkspaceSize(x_rows, x_cols);
auto work_data = AllocateScratchMemory<dtype>(work_size);
@ -328,10 +328,9 @@ ffi::Error QrFactorization<dtype>::Kernel(
const int64_t tau_step{std::min(x_rows, x_cols)};
for (int64_t i = 0; i < batch_count; ++i) {
fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data,
work_data.get(), &workspace_dim_v, info_data);
work_data.get(), &workspace_dim_v, &info);
x_out_data += x_out_step;
tau_data += tau_step;
++info_data;
}
return ffi::Error::Success();
}
@ -1713,8 +1712,7 @@ template struct Sytrd<std::complex<double>>;
::xla::ffi::Ffi::Bind() \
.Arg<::xla::ffi::Buffer<data_type>>(/*x*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*x_out*/) \
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/) \
.Ret<::xla::ffi::Buffer<LapackIntDtype>>(/*info*/))
.Ret<::xla::ffi::Buffer<data_type>>(/*tau*/))
#define JAX_CPU_DEFINE_ORGQR(name, data_type) \
XLA_FFI_DEFINE_HANDLER_SYMBOL( \

View File

@ -194,8 +194,7 @@ struct QrFactorization {
static ::xla::ffi::Error Kernel(
::xla::ffi::Buffer<dtype> x, ::xla::ffi::ResultBuffer<dtype> x_out,
::xla::ffi::ResultBuffer<dtype> tau,
::xla::ffi::ResultBuffer<LapackIntDtype> info);
::xla::ffi::ResultBuffer<dtype> tau);
static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols);
};