checkify: cache jaxpr formation so we don't always retrace

This commit is contained in:
Matthew Johnson 2023-02-01 10:19:47 -08:00
parent fcb9dfb080
commit 684846bd0f
2 changed files with 28 additions and 9 deletions

View File

@ -31,7 +31,7 @@ from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3)
unzip3, weakref_lru_cache)
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
@ -383,6 +383,20 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
def get_shaped_aval(val):
return core.raise_to_shaped(core.get_aval(val))
def initial_style_jaxpr(
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))
@weakref_lru_cache
def _initial_style_jaxpr(fun, in_tree, in_avals):
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun_, in_tree, False, 'checkify')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree()
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
error: Error, *args) -> Tuple[Error, List[core.Value]]:
err_vals, err_tree = jtu.tree_flatten(error)
@ -1065,16 +1079,14 @@ def checkify(f: Callable[..., Out],
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
# stage:
fun = lu.wrap_init(f)
flat_args, in_tree = jtu.tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(fun, in_tree)
flat_avals = map(get_shaped_aval, flat_args)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree()
flat_args, in_tree = tree_flatten((args, kwargs))
in_avals = map(get_shaped_aval, flat_args)
jaxpr_, consts, out_tree = initial_style_jaxpr(f, in_tree, in_avals)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
flat_args = jtu.tree_leaves((args, kwargs))
error, out_flat = checkify_jaxpr(core.ClosedJaxpr(jaxpr, consts), errors,
init_error, *flat_args)
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error,
*consts, *flat_args)
return error, jtu.tree_unflatten(out_tree, out_flat)
return checked_fun

View File

@ -773,6 +773,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, _ = checked_f(jnp.ones((2, 4)))
self.assertIsNone(err.get())
def test_retracing(self):
f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
_ = f(3.)
with jtu.count_primitive_compiles() as count:
_ = f(3.)
self.assertEqual(count[0], 0)
@jtu.with_config(jax_check_tracer_leaks=True)
class AssertPrimitiveTests(jtu.JaxTestCase):