From c231171fb6616bdd88175460ab19701bfac8a80e Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 3 Feb 2023 22:51:28 -0800 Subject: [PATCH] Fix checkify caching with nested call primitives --- jax/_src/checkify.py | 6 +++--- tests/checkify_test.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 5b77625b4..4ed67669c 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -351,9 +351,8 @@ def default_checkify_rule(primitive: core.Primitive, error: Error, # call_jaxpr handling call_jaxpr = params.pop('call_jaxpr') - partial_checkify = lu.wrap_init( - functools.partial(checkify_jaxpr_flat, call_jaxpr, (), enabled_errors, - err_tree)) + partial_checkify = lu.hashable_partial(lu.wrap_init( + checkify_jaxpr_flat), call_jaxpr, (), enabled_errors, err_tree) partial_checkify, metadata = _flatten_and_get_error_metadata_thunk( partial_checkify) @@ -688,6 +687,7 @@ error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check, # HOP error check rules +@weakref_lru_cache def jaxpr_to_checkify_jaxpr( jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef, *flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index d9affbc7d..2448ad9d4 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -775,7 +775,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): def test_retracing(self): f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2)) _ = f(3.) - with jtu.count_primitive_compiles() as count: + with jtu.count_jit_and_pmap_compiles() as count: _ = f(3.) self.assertEqual(count[0], 0)