Avoid throwing exceptions in LAPACK kernel code

PiperOrigin-RevId: 650569943
This commit is contained in:
Paweł Paruzel 2024-07-09 03:57:11 -07:00 committed by jax authors
parent 0da9b69285
commit 4e1a66ea21

View File

@ -626,16 +626,21 @@ static ffi::Error SvdKernel(
CopyIfDiffBuffer(x, x_out);
auto x_rows_v = CastNoOverflow<lapack_int>(x_rows);
auto x_cols_v = CastNoOverflow<lapack_int>(x_cols);
ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
MaybeCastNoOverflow<lapack_int>(x_rows));
ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
MaybeCastNoOverflow<lapack_int>(x_cols));
auto mode_v = static_cast<char>(mode);
auto workspace_dim_v = CastNoOverflow<lapack_int>(work->dimensions.back());
ASSIGN_OR_RETURN_FFI_ERROR(
auto workspace_dim_v,
MaybeCastNoOverflow<lapack_int>(work->dimensions.back()));
auto x_leading_dim_v = x_rows_v;
auto u_leading_dim_v = x_rows_v;
auto u_dims = u->dimensions.last(2);
auto vt_dims = vt->dimensions.last(2);
auto vt_leading_dim_v = CastNoOverflow<lapack_int>(vt_dims.front());
ASSIGN_OR_RETURN_FFI_ERROR(auto vt_leading_dim_v,
MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
const int64_t x_out_step{x_rows * x_cols};
const int64_t singular_values_step{singular_values->dimensions.back()};