From 684846bd0fb01af64a168d0209a110ebca941a67 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 1 Feb 2023 10:19:47 -0800 Subject: [PATCH] checkify: cache jaxpr formation so we don't always retrace --- jax/_src/checkify.py | 30 +++++++++++++++++++++--------- tests/checkify_test.py | 7 +++++++ 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 57999cfbb..be3283e59 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -31,7 +31,7 @@ from jax._src.lax import control_flow as cf from jax._src.sharding import OpShardingSharding from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, - unzip3) + unzip3, weakref_lru_cache) from jax.api_util import flatten_fun from jax.experimental import maps from jax.experimental import pjit @@ -383,6 +383,20 @@ def default_checkify_rule(primitive: core.Primitive, error: Error, def get_shaped_aval(val): return core.raise_to_shaped(core.get_aval(val)) +def initial_style_jaxpr( + fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue] + ) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]: + return _initial_style_jaxpr(fun, in_tree, tuple(in_avals)) + +@weakref_lru_cache +def _initial_style_jaxpr(fun, in_tree, in_avals): + # like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs + fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) + debug = pe.debug_info(fun_, in_tree, False, 'checkify') + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug) + return jaxpr, consts, out_tree() + + def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors, error: Error, *args) -> Tuple[Error, List[core.Value]]: err_vals, err_tree = jtu.tree_flatten(error) @@ -1065,16 +1079,14 @@ def checkify(f: Callable[..., Out], @traceback_util.api_boundary def checked_fun(*args, **kwargs): # stage: - fun = lu.wrap_init(f) - flat_args, in_tree = jtu.tree_flatten((args, kwargs)) - flat_fun, out_tree = flatten_fun(fun, in_tree) - flat_avals = map(get_shaped_aval, flat_args) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree() + flat_args, in_tree = tree_flatten((args, kwargs)) + in_avals = map(get_shaped_aval, flat_args) + jaxpr_, consts, out_tree = initial_style_jaxpr(f, in_tree, in_avals) + jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_)) # checkify: flat_args = jtu.tree_leaves((args, kwargs)) - error, out_flat = checkify_jaxpr(core.ClosedJaxpr(jaxpr, consts), errors, - init_error, *flat_args) + error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, + *consts, *flat_args) return error, jtu.tree_unflatten(out_tree, out_flat) return checked_fun diff --git a/tests/checkify_test.py b/tests/checkify_test.py index a9aac55c1..02be822cf 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -773,6 +773,13 @@ class CheckifyTransformTests(jtu.JaxTestCase): err, _ = checked_f(jnp.ones((2, 4))) self.assertIsNone(err.get()) + def test_retracing(self): + f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2)) + _ = f(3.) + with jtu.count_primitive_compiles() as count: + _ = f(3.) + self.assertEqual(count[0], 0) + @jtu.with_config(jax_check_tracer_leaks=True) class AssertPrimitiveTests(jtu.JaxTestCase):