mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] handle reuse of closed-over constants
This commit is contained in:
parent
2c8051eb52
commit
d5405bd92f
@ -223,7 +223,7 @@ def _get_states(attrs_tracked):
|
||||
|
||||
def _get_fastpath_data(
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
|
||||
abstracted_axes
|
||||
consts, abstracted_axes,
|
||||
) -> Optional[pxla.MeshExecutableFastpathData]:
|
||||
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
|
||||
|
||||
@ -244,7 +244,7 @@ def _get_fastpath_data(
|
||||
# no prng reuse checking
|
||||
and not (config.debug_key_reuse.value and any(
|
||||
hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
|
||||
for arg in (*args_flat, *out_flat)))
|
||||
for arg in (*args_flat, *out_flat, *consts)))
|
||||
)
|
||||
|
||||
if use_fastpath:
|
||||
@ -306,7 +306,7 @@ def _cpp_pjit(jit_info: PjitInfo):
|
||||
executable = _read_most_recent_pjit_call_executable(jaxpr)
|
||||
maybe_fastpath_data = _get_fastpath_data(
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
|
||||
jit_info.abstracted_axes)
|
||||
jaxpr.consts, jit_info.abstracted_axes)
|
||||
return outs, maybe_fastpath_data
|
||||
|
||||
fun = jit_info.fun
|
||||
@ -1557,7 +1557,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
inline=inline)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
|
||||
None)
|
||||
jaxpr.consts, None)
|
||||
return out_flat, fastpath_data
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
|
@ -356,14 +356,15 @@ def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
|
||||
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
|
||||
source(eqn.outvars[src.idx])
|
||||
|
||||
all_inputs = [*jaxpr.invars, *jaxpr.constvars]
|
||||
return KeyReuseSignature(
|
||||
*(Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars)
|
||||
*(Sink(i, consumed[v]) for i, v in enumerate(all_inputs)
|
||||
if is_key(v) and np.any(consumed.get(v, False))),
|
||||
*(Source(i) for i, v in enumerate(jaxpr.outvars)
|
||||
if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)),
|
||||
*(Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type]
|
||||
if is_key(v) and resolve_forwards(v) not in all_inputs and not consumed.get(v, False)),
|
||||
*(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type]
|
||||
for idx_out, outvar in enumerate(jaxpr.outvars)
|
||||
if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars)
|
||||
if is_key(outvar) and resolve_forwards(outvar) in all_inputs)
|
||||
)
|
||||
|
||||
|
||||
@ -531,23 +532,24 @@ def key_reuse_impl_rule(prim, original_rule):
|
||||
@wraps(original_rule)
|
||||
def key_reuse_impl(*args, **kwargs):
|
||||
if config.debug_key_reuse.value:
|
||||
funcname = str(prim)
|
||||
jaxpr = None
|
||||
consts = []
|
||||
if prim == pjit.pjit_p:
|
||||
funcname = "jit-compiled function"
|
||||
jaxpr = kwargs['jaxpr'].jaxpr
|
||||
consts = kwargs['jaxpr'].consts
|
||||
signature = jaxpr_type_signature(jaxpr)
|
||||
elif prim in key_reuse_signatures:
|
||||
funcname = str(prim)
|
||||
jaxpr = None
|
||||
signature = key_reuse_signatures[prim]
|
||||
elif prim in key_reuse_signatures_dynamic:
|
||||
funcname = str(prim)
|
||||
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
|
||||
signature = jaxpr_type_signature(jaxpr)
|
||||
else:
|
||||
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
|
||||
signature.check_signature(*args, funcname=funcname)
|
||||
signature.check_signature(*args, *consts, funcname=funcname)
|
||||
result = original_rule(*args, **kwargs)
|
||||
signature.update_consumption(args, result if prim.multiple_results else [result])
|
||||
signature.update_consumption([*args, *consts], result if prim.multiple_results else [result])
|
||||
return result
|
||||
else:
|
||||
return original_rule(*args, **kwargs)
|
||||
|
@ -623,17 +623,26 @@ class KeyReuseEagerTest(jtu.JaxTestCase):
|
||||
|
||||
def test_simple_reuse_nojit(self):
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.bits(key)
|
||||
with jax.disable_jit():
|
||||
_ = jax.random.bits(key)
|
||||
with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg):
|
||||
_ = jax.random.bits(key)
|
||||
|
||||
def test_simple_key_reuse_jit(self):
|
||||
key = jax.random.key(0)
|
||||
_ = jax.random.bits(key)
|
||||
_ = jax.jit(jax.random.bits)(key)
|
||||
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
|
||||
_ = jax.jit(jax.random.bits)(key)
|
||||
|
||||
def test_closed_over_key_reuse_jit(self):
|
||||
key = jax.random.key(0)
|
||||
@jax.jit
|
||||
def f():
|
||||
return jax.random.uniform(key)
|
||||
_ = f()
|
||||
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
|
||||
_ = f()
|
||||
|
||||
def test_key_reuse_within_jit(self):
|
||||
@jax.jit
|
||||
def f():
|
||||
|
Loading…
x
Reference in New Issue
Block a user