Fix checkify caching with nested call primitives

This commit is contained in:
Sharad Vikram 2023-02-03 22:51:28 -08:00
parent f445c84ba4
commit c231171fb6
2 changed files with 4 additions and 4 deletions

View File

@ -351,9 +351,8 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
# call_jaxpr handling
call_jaxpr = params.pop('call_jaxpr')
partial_checkify = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, call_jaxpr, (), enabled_errors,
err_tree))
partial_checkify = lu.hashable_partial(lu.wrap_init(
checkify_jaxpr_flat), call_jaxpr, (), enabled_errors, err_tree)
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
partial_checkify)
@ -688,6 +687,7 @@ error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
# HOP error check rules
@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
*flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:

View File

@ -775,7 +775,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def test_retracing(self):
f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
_ = f(3.)
with jtu.count_primitive_compiles() as count:
with jtu.count_jit_and_pmap_compiles() as count:
_ = f(3.)
self.assertEqual(count[0], 0)