diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index faf1e9980..4c138fa85 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -223,7 +223,7 @@ def _get_states(attrs_tracked): def _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, effects, - abstracted_axes + consts, abstracted_axes, ) -> Optional[pxla.MeshExecutableFastpathData]: out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) @@ -244,7 +244,7 @@ def _get_fastpath_data( # no prng reuse checking and not (config.debug_key_reuse.value and any( hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key) - for arg in (*args_flat, *out_flat))) + for arg in (*args_flat, *out_flat, *consts))) ) if use_fastpath: @@ -306,7 +306,7 @@ def _cpp_pjit(jit_info: PjitInfo): executable = _read_most_recent_pjit_call_executable(jaxpr) maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, - jit_info.abstracted_axes) + jaxpr.consts, jit_info.abstracted_axes) return outs, maybe_fastpath_data fun = jit_info.fun @@ -1557,7 +1557,7 @@ def _pjit_call_impl(*args, jaxpr, inline=inline) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, - None) + jaxpr.consts, None) return out_flat, fastpath_data f = _get_jaxpr_as_fun( diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 94ef166e6..5287a646d 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -356,14 +356,15 @@ def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature: raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]") source(eqn.outvars[src.idx]) + all_inputs = [*jaxpr.invars, *jaxpr.constvars] return KeyReuseSignature( - *(Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars) + *(Sink(i, consumed[v]) for i, v in enumerate(all_inputs) if is_key(v) and np.any(consumed.get(v, False))), *(Source(i) for i, v in enumerate(jaxpr.outvars) - if is_key(v) and resolve_forwards(v) not in jaxpr.invars and not consumed.get(v, False)), - *(Forward(jaxpr.invars.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] + if is_key(v) and resolve_forwards(v) not in all_inputs and not consumed.get(v, False)), + *(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] for idx_out, outvar in enumerate(jaxpr.outvars) - if is_key(outvar) and resolve_forwards(outvar) in jaxpr.invars) + if is_key(outvar) and resolve_forwards(outvar) in all_inputs) ) @@ -531,23 +532,24 @@ def key_reuse_impl_rule(prim, original_rule): @wraps(original_rule) def key_reuse_impl(*args, **kwargs): if config.debug_key_reuse.value: + funcname = str(prim) + jaxpr = None + consts = [] if prim == pjit.pjit_p: funcname = "jit-compiled function" jaxpr = kwargs['jaxpr'].jaxpr + consts = kwargs['jaxpr'].consts signature = jaxpr_type_signature(jaxpr) elif prim in key_reuse_signatures: - funcname = str(prim) - jaxpr = None signature = key_reuse_signatures[prim] elif prim in key_reuse_signatures_dynamic: - funcname = str(prim) jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr signature = jaxpr_type_signature(jaxpr) else: raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}") - signature.check_signature(*args, funcname=funcname) + signature.check_signature(*args, *consts, funcname=funcname) result = original_rule(*args, **kwargs) - signature.update_consumption(args, result if prim.multiple_results else [result]) + signature.update_consumption([*args, *consts], result if prim.multiple_results else [result]) return result else: return original_rule(*args, **kwargs) diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index d98984be5..286088eeb 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -623,17 +623,26 @@ class KeyReuseEagerTest(jtu.JaxTestCase): def test_simple_reuse_nojit(self): key = jax.random.key(0) - _ = jax.random.bits(key) with jax.disable_jit(): + _ = jax.random.bits(key) with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg): _ = jax.random.bits(key) def test_simple_key_reuse_jit(self): key = jax.random.key(0) - _ = jax.random.bits(key) + _ = jax.jit(jax.random.bits)(key) with self.assertRaisesRegex(KeyReuseError, self.jit_msg): _ = jax.jit(jax.random.bits)(key) + def test_closed_over_key_reuse_jit(self): + key = jax.random.key(0) + @jax.jit + def f(): + return jax.random.uniform(key) + _ = f() + with self.assertRaisesRegex(KeyReuseError, self.jit_msg): + _ = f() + def test_key_reuse_within_jit(self): @jax.jit def f():