[checkify] fix closed_call_p handling

Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
This commit is contained in:
Matthew Johnson 2023-05-10 21:35:21 -07:00
parent 261ff9e9ed
commit f55de18933
3 changed files with 22 additions and 5 deletions

View File

@ -47,7 +47,7 @@ from jax._src.tree_util import tree_map
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache)
unzip3, weakref_lru_cache, HashableWrapper)
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
@ -353,8 +353,13 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
# call_jaxpr handling
call_jaxpr = params.pop('call_jaxpr')
if isinstance(call_jaxpr, core.ClosedJaxpr): # handle closed_call_p
jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
else:
jaxpr, consts = call_jaxpr, ()
consts_ = tuple(HashableWrapper(c) for c in consts)
partial_checkify = lu.hashable_partial(lu.wrap_init(
checkify_jaxpr_flat), call_jaxpr, (), enabled_errors, err_tree)
checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)
@ -424,6 +429,11 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
return error, map(read_env, jaxpr.outvars)
def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
err_tree, *args):
consts = tuple(c.x for c in hashable_consts)
return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)
@lu.transformation_with_aux
def flatten_fun_output(*args):
ans = yield args, {}
@ -985,6 +995,7 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
return out_err, out_vals
error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
del debug
new_error = tree_unflatten(err_tree, args)

View File

@ -373,9 +373,8 @@ def _copy_main_traces(x):
@transformation
def hashable_partial(x, *args):
ans = yield (x,) + args, {}
yield ans
def hashable_partial(*args):
yield (yield args, {})
def merge_linear_aux(aux1, aux2):

View File

@ -831,6 +831,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
h_grad = jax.grad(h_out)
h_grad(0.) # doesn't crash
def test_closed_call(self):
# lots of golfing went into this test
y = jnp.array([3.14])
summify = lambda f: lambda x: f(x).sum()
f = checkify.checkify(jax.grad(summify(jax.remat(
partial(partial, jax.lax.map)(lambda x: jnp.sin(x * y))))))
f(jnp.array([3.])) # don't crash
@jtu.with_config(jax_check_tracer_leaks=True)