diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index da2815d2d..d88c82f37 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -626,16 +626,21 @@ static ffi::Error SvdKernel( CopyIfDiffBuffer(x, x_out); - auto x_rows_v = CastNoOverflow(x_rows); - auto x_cols_v = CastNoOverflow(x_cols); + ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v, + MaybeCastNoOverflow(x_rows)); + ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v, + MaybeCastNoOverflow(x_cols)); auto mode_v = static_cast(mode); - auto workspace_dim_v = CastNoOverflow(work->dimensions.back()); + ASSIGN_OR_RETURN_FFI_ERROR( + auto workspace_dim_v, + MaybeCastNoOverflow(work->dimensions.back())); auto x_leading_dim_v = x_rows_v; auto u_leading_dim_v = x_rows_v; auto u_dims = u->dimensions.last(2); auto vt_dims = vt->dimensions.last(2); - auto vt_leading_dim_v = CastNoOverflow(vt_dims.front()); + ASSIGN_OR_RETURN_FFI_ERROR(auto vt_leading_dim_v, + MaybeCastNoOverflow(vt_dims.front())); const int64_t x_out_step{x_rows * x_cols}; const int64_t singular_values_step{singular_values->dimensions.back()};