From ee6cbafa85d995082d53c9600a018bbea9b8a9c6 Mon Sep 17 00:00:00 2001 From: lenamartens Date: Wed, 17 May 2023 17:26:07 +0100 Subject: [PATCH] Checkify: Fix closing over Tracer in while_loop cond_f. Co-authored-by: Matthew Johnson --- jax/_src/checkify.py | 25 +++++++++++++++---------- tests/checkify_test.py | 7 +++++++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 273928d19..3d66b8bb6 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -805,20 +805,23 @@ error_checks[lax.scan_p] = scan_error_check def checkify_while_body_jaxpr( cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr, enabled_errors, error: Error, - c_consts) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]: + c_consts_num: int) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]: cond_f = core.jaxpr_as_fun(cond_jaxpr) body_f = core.jaxpr_as_fun(body_jaxpr) - def new_body_f(*vals): + def new_body_f(*c_consts_and_vals): + c_consts, vals = split_list(c_consts_and_vals, [c_consts_num]) out = body_f(*vals) # This checks if the next cond application will error _ = cond_f(*c_consts, *out) return out new_body_f_ = lu.wrap_init(new_body_f) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(new_body_f_, body_jaxpr.in_avals) - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + c_consts_avals = cond_jaxpr.in_avals[:c_consts_num] + jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals, + *body_jaxpr.in_avals]) + closed_jaxpr = core.ClosedJaxpr(jaxpr, ()) err_vals, err_tree = jtu.tree_flatten(error) err_vals = map(get_shaped_aval, err_vals) - flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals] + flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals] jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr( closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) return jaxpr, out_tree, error_effects @@ -844,14 +847,16 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry) _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, - enabled_errors, error, c_consts) + enabled_errors, error, + cond_nconsts) # merged error! error = error._add_placeholder_effects(error_effects) err_vals, err_tree = jtu.tree_flatten(error) checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr( - cond_jaxpr, body_jaxpr, enabled_errors, error, c_consts) + cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts) num_error_vals = len(err_vals) - to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry) + to_move = ([False] * num_error_vals + [True] * cond_nconsts + + [True] * body_nconsts + [False] * len(carry)) checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) cond_in_flat = [*err_vals, *c_consts, *carry] @@ -862,10 +867,10 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry) compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move) - new_in_flat = [*c_consts, *b_consts, *err_vals, *carry] + new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry] all_out_vals = lax.while_p.bind( *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, - body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) + body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr) # body_out_tree will have all the metadata of cond because it executes a cond! error, out = tree_unflatten(body_out_tree, all_out_vals) return error, out diff --git a/tests/checkify_test.py b/tests/checkify_test.py index ad58c0cbc..f939c1e9f 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -839,6 +839,13 @@ class CheckifyTransformTests(jtu.JaxTestCase): partial(partial, jax.lax.map)(lambda x: jnp.sin(x * y)))))) f(jnp.array([3.])) # don't crash + def test_while_loop_leaks(self): + def f(x): + n = jnp.minimum(1, 2) + return jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, x) + + jax.jit(checkify.checkify(f))(0) # Does not crash bc of leaked tracer. + @jtu.with_config(jax_check_tracer_leaks=True) class AssertPrimitiveTests(jtu.JaxTestCase):