mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
checkify: cache jaxpr formation so we don't always retrace
This commit is contained in:
parent
fcb9dfb080
commit
684846bd0f
@ -31,7 +31,7 @@ from jax._src.lax import control_flow as cf
|
||||
from jax._src.sharding import OpShardingSharding
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
||||
unzip3)
|
||||
unzip3, weakref_lru_cache)
|
||||
from jax.api_util import flatten_fun
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
@ -383,6 +383,20 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
|
||||
def get_shaped_aval(val):
|
||||
return core.raise_to_shaped(core.get_aval(val))
|
||||
|
||||
def initial_style_jaxpr(
|
||||
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
|
||||
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
|
||||
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))
|
||||
|
||||
@weakref_lru_cache
|
||||
def _initial_style_jaxpr(fun, in_tree, in_avals):
|
||||
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
|
||||
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
debug = pe.debug_info(fun_, in_tree, False, 'checkify')
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
|
||||
return jaxpr, consts, out_tree()
|
||||
|
||||
|
||||
def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
|
||||
error: Error, *args) -> Tuple[Error, List[core.Value]]:
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
@ -1065,16 +1079,14 @@ def checkify(f: Callable[..., Out],
|
||||
@traceback_util.api_boundary
|
||||
def checked_fun(*args, **kwargs):
|
||||
# stage:
|
||||
fun = lu.wrap_init(f)
|
||||
flat_args, in_tree = jtu.tree_flatten((args, kwargs))
|
||||
flat_fun, out_tree = flatten_fun(fun, in_tree)
|
||||
flat_avals = map(get_shaped_aval, flat_args)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
out_tree = out_tree()
|
||||
flat_args, in_tree = tree_flatten((args, kwargs))
|
||||
in_avals = map(get_shaped_aval, flat_args)
|
||||
jaxpr_, consts, out_tree = initial_style_jaxpr(f, in_tree, in_avals)
|
||||
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
|
||||
# checkify:
|
||||
flat_args = jtu.tree_leaves((args, kwargs))
|
||||
error, out_flat = checkify_jaxpr(core.ClosedJaxpr(jaxpr, consts), errors,
|
||||
init_error, *flat_args)
|
||||
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error,
|
||||
*consts, *flat_args)
|
||||
return error, jtu.tree_unflatten(out_tree, out_flat)
|
||||
return checked_fun
|
||||
|
||||
|
@ -773,6 +773,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, _ = checked_f(jnp.ones((2, 4)))
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
def test_retracing(self):
|
||||
f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
|
||||
_ = f(3.)
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
_ = f(3.)
|
||||
self.assertEqual(count[0], 0)
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user