Merge pull request #16668 from LenaMartens:check-cleanup

PiperOrigin-RevId: 546898495
This commit is contained in:
jax authors 2023-07-10 09:41:50 -07:00
commit b19c63278d

View File

@ -804,12 +804,11 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
err_tree, *new_in_aval)
new_in_flat = [*consts, *err_vals, *carry, *xs]
new_linear = (*[False] * len(err_vals), *linear)
tomove = ([False] * len(err_vals) + [True] * len(consts)
+ [False] * (len(carry) + len(xs)))
checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
new_in_flat = [*consts, *err_vals, *carry, *xs]
new_linear = (*[False] * len(err_vals), *linear)
err_and_out = lax.scan_p.bind(
*new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
num_consts=len(consts), num_carry=len(carry)+len(err_vals),