From f55de1893375538561c42abd59dd19537a46b524 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 10 May 2023 21:35:21 -0700 Subject: [PATCH] [checkify] fix closed_call_p handling Co-authored-by: Roy Frostig Co-authored-by: Sharad Vikram Co-authored-by: Yash Katariya --- jax/_src/checkify.py | 15 +++++++++++++-- jax/_src/linear_util.py | 5 ++--- tests/checkify_test.py | 7 +++++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 5c21aa076..273928d19 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -47,7 +47,7 @@ from jax._src.tree_util import tree_map from jax._src.tree_util import tree_unflatten from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, - unzip3, weakref_lru_cache) + unzip3, weakref_lru_cache, HashableWrapper) source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) @@ -353,8 +353,13 @@ def default_checkify_rule(primitive: core.Primitive, error: Error, # call_jaxpr handling call_jaxpr = params.pop('call_jaxpr') + if isinstance(call_jaxpr, core.ClosedJaxpr): # handle closed_call_p + jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts + else: + jaxpr, consts = call_jaxpr, () + consts_ = tuple(HashableWrapper(c) for c in consts) partial_checkify = lu.hashable_partial(lu.wrap_init( - checkify_jaxpr_flat), call_jaxpr, (), enabled_errors, err_tree) + checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree) partial_checkify, metadata = _flatten_and_get_error_metadata_thunk( partial_checkify) @@ -424,6 +429,11 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value], return error, map(read_env, jaxpr.outvars) +def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors, + err_tree, *args): + consts = tuple(c.x for c in hashable_consts) + return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args) + @lu.transformation_with_aux def flatten_fun_output(*args): ans = yield args, {} @@ -985,6 +995,7 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr, return out_err, out_vals error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule + def check_discharge_rule(error, enabled_errors, *args, err_tree, debug): del debug new_error = tree_unflatten(err_tree, args) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index eb3ca8ac8..c65c83737 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -373,9 +373,8 @@ def _copy_main_traces(x): @transformation -def hashable_partial(x, *args): - ans = yield (x,) + args, {} - yield ans +def hashable_partial(*args): + yield (yield args, {}) def merge_linear_aux(aux1, aux2): diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 26cec711c..ad58c0cbc 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -831,6 +831,13 @@ class CheckifyTransformTests(jtu.JaxTestCase): h_grad = jax.grad(h_out) h_grad(0.) # doesn't crash + def test_closed_call(self): + # lots of golfing went into this test + y = jnp.array([3.14]) + summify = lambda f: lambda x: f(x).sum() + f = checkify.checkify(jax.grad(summify(jax.remat( + partial(partial, jax.lax.map)(lambda x: jnp.sin(x * y)))))) + f(jnp.array([3.])) # don't crash @jtu.with_config(jax_check_tracer_leaks=True)