mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Avoid throwing exceptions in LAPACK kernel code
PiperOrigin-RevId: 650569943
This commit is contained in:
parent
0da9b69285
commit
4e1a66ea21
@ -626,16 +626,21 @@ static ffi::Error SvdKernel(
|
||||
|
||||
CopyIfDiffBuffer(x, x_out);
|
||||
|
||||
auto x_rows_v = CastNoOverflow<lapack_int>(x_rows);
|
||||
auto x_cols_v = CastNoOverflow<lapack_int>(x_cols);
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_rows_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_rows));
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto x_cols_v,
|
||||
MaybeCastNoOverflow<lapack_int>(x_cols));
|
||||
auto mode_v = static_cast<char>(mode);
|
||||
auto workspace_dim_v = CastNoOverflow<lapack_int>(work->dimensions.back());
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(
|
||||
auto workspace_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(work->dimensions.back()));
|
||||
auto x_leading_dim_v = x_rows_v;
|
||||
auto u_leading_dim_v = x_rows_v;
|
||||
|
||||
auto u_dims = u->dimensions.last(2);
|
||||
auto vt_dims = vt->dimensions.last(2);
|
||||
auto vt_leading_dim_v = CastNoOverflow<lapack_int>(vt_dims.front());
|
||||
ASSIGN_OR_RETURN_FFI_ERROR(auto vt_leading_dim_v,
|
||||
MaybeCastNoOverflow<lapack_int>(vt_dims.front()));
|
||||
|
||||
const int64_t x_out_step{x_rows * x_cols};
|
||||
const int64_t singular_values_step{singular_values->dimensions.back()};
|
||||
|
Loading…
x
Reference in New Issue
Block a user