mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[checkify] fix closed_call_p handling
Co-authored-by: Roy Frostig <frostig@google.com> Co-authored-by: Sharad Vikram <sharadmv@google.com> Co-authored-by: Yash Katariya <yashkatariya@google.com>
This commit is contained in:
parent
261ff9e9ed
commit
f55de18933
@ -47,7 +47,7 @@ from jax._src.tree_util import tree_map
|
||||
from jax._src.tree_util import tree_unflatten
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
||||
unzip3, weakref_lru_cache)
|
||||
unzip3, weakref_lru_cache, HashableWrapper)
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
traceback_util.register_exclusion(__file__)
|
||||
@ -353,8 +353,13 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
|
||||
|
||||
# call_jaxpr handling
|
||||
call_jaxpr = params.pop('call_jaxpr')
|
||||
if isinstance(call_jaxpr, core.ClosedJaxpr): # handle closed_call_p
|
||||
jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
|
||||
else:
|
||||
jaxpr, consts = call_jaxpr, ()
|
||||
consts_ = tuple(HashableWrapper(c) for c in consts)
|
||||
partial_checkify = lu.hashable_partial(lu.wrap_init(
|
||||
checkify_jaxpr_flat), call_jaxpr, (), enabled_errors, err_tree)
|
||||
checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
|
||||
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
|
||||
partial_checkify)
|
||||
|
||||
@ -424,6 +429,11 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
|
||||
|
||||
return error, map(read_env, jaxpr.outvars)
|
||||
|
||||
def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
|
||||
err_tree, *args):
|
||||
consts = tuple(c.x for c in hashable_consts)
|
||||
return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_output(*args):
|
||||
ans = yield args, {}
|
||||
@ -985,6 +995,7 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
|
||||
return out_err, out_vals
|
||||
error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule
|
||||
|
||||
|
||||
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
|
||||
del debug
|
||||
new_error = tree_unflatten(err_tree, args)
|
||||
|
@ -373,9 +373,8 @@ def _copy_main_traces(x):
|
||||
|
||||
|
||||
@transformation
|
||||
def hashable_partial(x, *args):
|
||||
ans = yield (x,) + args, {}
|
||||
yield ans
|
||||
def hashable_partial(*args):
|
||||
yield (yield args, {})
|
||||
|
||||
|
||||
def merge_linear_aux(aux1, aux2):
|
||||
|
@ -831,6 +831,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
h_grad = jax.grad(h_out)
|
||||
h_grad(0.) # doesn't crash
|
||||
|
||||
def test_closed_call(self):
|
||||
# lots of golfing went into this test
|
||||
y = jnp.array([3.14])
|
||||
summify = lambda f: lambda x: f(x).sum()
|
||||
f = checkify.checkify(jax.grad(summify(jax.remat(
|
||||
partial(partial, jax.lax.map)(lambda x: jnp.sin(x * y))))))
|
||||
f(jnp.array([3.])) # don't crash
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user