Checkify: Fix closing over Tracer in while_loop cond_f.

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
lenamartens 2023-05-17 17:26:07 +01:00
parent 2a1de46527
commit ee6cbafa85
2 changed files with 22 additions and 10 deletions

View File

@ -805,20 +805,23 @@ error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(
cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
enabled_errors, error: Error,
c_consts) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
c_consts_num: int) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
cond_f = core.jaxpr_as_fun(cond_jaxpr)
body_f = core.jaxpr_as_fun(body_jaxpr)
def new_body_f(*vals):
def new_body_f(*c_consts_and_vals):
c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
out = body_f(*vals)
# This checks if the next cond application will error
_ = cond_f(*c_consts, *out)
return out
new_body_f_ = lu.wrap_init(new_body_f)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(new_body_f_, body_jaxpr.in_avals)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
*body_jaxpr.in_avals])
closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
err_vals, err_tree = jtu.tree_flatten(error)
err_vals = map(get_shaped_aval, err_vals)
flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]
flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
return jaxpr, out_tree, error_effects
@ -844,14 +847,16 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)
_, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
enabled_errors, error, c_consts)
enabled_errors, error,
cond_nconsts)
# merged error!
error = error._add_placeholder_effects(error_effects)
err_vals, err_tree = jtu.tree_flatten(error)
checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
cond_jaxpr, body_jaxpr, enabled_errors, error, c_consts)
cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
num_error_vals = len(err_vals)
to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry)
to_move = ([False] * num_error_vals + [True] * cond_nconsts
+ [True] * body_nconsts + [False] * len(carry))
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
cond_in_flat = [*err_vals, *c_consts, *carry]
@ -862,10 +867,10 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)
new_in_flat = [*c_consts, *b_consts, *err_vals, *carry]
new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
all_out_vals = lax.while_p.bind(
*new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
# body_out_tree will have all the metadata of cond because it executes a cond!
error, out = tree_unflatten(body_out_tree, all_out_vals)
return error, out

View File

@ -839,6 +839,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
partial(partial, jax.lax.map)(lambda x: jnp.sin(x * y))))))
f(jnp.array([3.])) # don't crash
def test_while_loop_leaks(self):
def f(x):
n = jnp.minimum(1, 2)
return jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, x)
jax.jit(checkify.checkify(f))(0) # Does not crash bc of leaked tracer.
@jtu.with_config(jax_check_tracer_leaks=True)
class AssertPrimitiveTests(jtu.JaxTestCase):