mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16668 from LenaMartens:check-cleanup
PiperOrigin-RevId: 546898495
This commit is contained in:
commit
b19c63278d
@ -804,12 +804,11 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
|
||||
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
|
||||
err_tree, *new_in_aval)
|
||||
|
||||
new_in_flat = [*consts, *err_vals, *carry, *xs]
|
||||
new_linear = (*[False] * len(err_vals), *linear)
|
||||
tomove = ([False] * len(err_vals) + [True] * len(consts)
|
||||
+ [False] * (len(carry) + len(xs)))
|
||||
checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
|
||||
new_in_flat = [*consts, *err_vals, *carry, *xs]
|
||||
new_linear = (*[False] * len(err_vals), *linear)
|
||||
err_and_out = lax.scan_p.bind(
|
||||
*new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
|
||||
num_consts=len(consts), num_carry=len(carry)+len(err_vals),
|
||||
|
Loading…
x
Reference in New Issue
Block a user