mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Checkify: Fix closing over Tracer in while_loop cond_f.
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
2a1de46527
commit
ee6cbafa85
@ -805,20 +805,23 @@ error_checks[lax.scan_p] = scan_error_check
|
||||
def checkify_while_body_jaxpr(
|
||||
cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
|
||||
enabled_errors, error: Error,
|
||||
c_consts) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
|
||||
c_consts_num: int) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
|
||||
cond_f = core.jaxpr_as_fun(cond_jaxpr)
|
||||
body_f = core.jaxpr_as_fun(body_jaxpr)
|
||||
def new_body_f(*vals):
|
||||
def new_body_f(*c_consts_and_vals):
|
||||
c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
|
||||
out = body_f(*vals)
|
||||
# This checks if the next cond application will error
|
||||
_ = cond_f(*c_consts, *out)
|
||||
return out
|
||||
new_body_f_ = lu.wrap_init(new_body_f)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(new_body_f_, body_jaxpr.in_avals)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
|
||||
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
|
||||
*body_jaxpr.in_avals])
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
err_vals = map(get_shaped_aval, err_vals)
|
||||
flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]
|
||||
flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
|
||||
jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
|
||||
closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
|
||||
return jaxpr, out_tree, error_effects
|
||||
@ -844,14 +847,16 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)
|
||||
|
||||
_, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
|
||||
enabled_errors, error, c_consts)
|
||||
enabled_errors, error,
|
||||
cond_nconsts)
|
||||
# merged error!
|
||||
error = error._add_placeholder_effects(error_effects)
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
|
||||
cond_jaxpr, body_jaxpr, enabled_errors, error, c_consts)
|
||||
cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
|
||||
num_error_vals = len(err_vals)
|
||||
to_move = [False] * num_error_vals + [True] * body_nconsts + [False] * len(carry)
|
||||
to_move = ([False] * num_error_vals + [True] * cond_nconsts
|
||||
+ [True] * body_nconsts + [False] * len(carry))
|
||||
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
|
||||
|
||||
cond_in_flat = [*err_vals, *c_consts, *carry]
|
||||
@ -862,10 +867,10 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
|
||||
compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)
|
||||
|
||||
new_in_flat = [*c_consts, *b_consts, *err_vals, *carry]
|
||||
new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
|
||||
all_out_vals = lax.while_p.bind(
|
||||
*new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
|
||||
body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
|
||||
body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
|
||||
# body_out_tree will have all the metadata of cond because it executes a cond!
|
||||
error, out = tree_unflatten(body_out_tree, all_out_vals)
|
||||
return error, out
|
||||
|
@ -839,6 +839,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
partial(partial, jax.lax.map)(lambda x: jnp.sin(x * y))))))
|
||||
f(jnp.array([3.])) # don't crash
|
||||
|
||||
def test_while_loop_leaks(self):
|
||||
def f(x):
|
||||
n = jnp.minimum(1, 2)
|
||||
return jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, x)
|
||||
|
||||
jax.jit(checkify.checkify(f))(0) # Does not crash bc of leaked tracer.
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user