[key reuse] handle reuse of closed-over constants

This commit is contained in:
Jake VanderPlas 2024-04-11 12:23:01 -07:00
parent 2c8051eb52
commit d5405bd92f
3 changed files with 26 additions and 15 deletions

View File

@ -223,7 +223,7 @@ def _get_states(attrs_tracked):
def _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
abstracted_axes
consts, abstracted_axes,
) -> Optional[pxla.MeshExecutableFastpathData]:
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
@ -244,7 +244,7 @@ def _get_fastpath_data(
# no prng reuse checking
and not (config.debug_key_reuse.value and any(
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
for arg in (*args_flat, *out_flat)))
for arg in (*args_flat, *out_flat, *consts)))
)
if use_fastpath:
@ -306,7 +306,7 @@ def _cpp_pjit(jit_info: PjitInfo):
executable = _read_most_recent_pjit_call_executable(jaxpr)
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
jit_info.abstracted_axes)
jaxpr.consts, jit_info.abstracted_axes)
return outs, maybe_fastpath_data
fun = jit_info.fun
@ -1557,7 +1557,7 @@ def _pjit_call_impl(*args, jaxpr,
inline=inline)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
None)
jaxpr.consts, None)
return out_flat, fastpath_data
f = _get_jaxpr_as_fun(

View File

@ -356,14 +356,15 @@ def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
source(eqn.outvars[src.idx])
all_inputs = [*jaxpr.invars, *jaxpr.constvars]
return KeyReuseSignature(
*(Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars)
*(Sink(i, consumed[v]) for i, v in enumerate(all_inputs)
if is_key(v) and np.any(consumed.get(v, False))),
*(Source(i) for i, v in enumerate(jaxpr.outvars)
if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)),
*(Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type]
if is_key(v) and resolve_forwards(v) not in all_inputs and not consumed.get(v, False)),
*(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type]
for idx_out, outvar in enumerate(jaxpr.outvars)
if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars)
if is_key(outvar) and resolve_forwards(outvar) in all_inputs)
)
@ -531,23 +532,24 @@ def key_reuse_impl_rule(prim, original_rule):
@wraps(original_rule)
def key_reuse_impl(*args, **kwargs):
if config.debug_key_reuse.value:
funcname = str(prim)
jaxpr = None
consts = []
if prim == pjit.pjit_p:
funcname = "jit-compiled function"
jaxpr = kwargs['jaxpr'].jaxpr
consts = kwargs['jaxpr'].consts
signature = jaxpr_type_signature(jaxpr)
elif prim in key_reuse_signatures:
funcname = str(prim)
jaxpr = None
signature = key_reuse_signatures[prim]
elif prim in key_reuse_signatures_dynamic:
funcname = str(prim)
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
signature = jaxpr_type_signature(jaxpr)
else:
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
signature.check_signature(*args, funcname=funcname)
signature.check_signature(*args, *consts, funcname=funcname)
result = original_rule(*args, **kwargs)
signature.update_consumption(args, result if prim.multiple_results else [result])
signature.update_consumption([*args, *consts], result if prim.multiple_results else [result])
return result
else:
return original_rule(*args, **kwargs)

View File

@ -623,17 +623,26 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
def test_simple_reuse_nojit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with jax.disable_jit():
_ = jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg):
_ = jax.random.bits(key)
def test_simple_key_reuse_jit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
_ = jax.jit(jax.random.bits)(key)
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = jax.jit(jax.random.bits)(key)
def test_closed_over_key_reuse_jit(self):
key = jax.random.key(0)
@jax.jit
def f():
return jax.random.uniform(key)
_ = f()
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = f()
def test_key_reuse_within_jit(self):
@jax.jit
def f():