Merge pull request #19536 from jakevdp:key-reuse-cond

PiperOrigin-RevId: 601900128
This commit is contained in:
jax authors 2024-01-26 16:43:44 -08:00
commit c42305a0a9
3 changed files with 15 additions and 2 deletions

View File

@ -152,6 +152,8 @@ def get_jaxpr_type_signature(
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
for snk in signature.sinks:
if not 0 <= snk.idx < len(eqn.invars):
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
if sink(eqn.invars[snk.idx], snk.mask):
raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n"
f"eqn: {eqn}\njaxpr:\n{jaxpr}")
@ -159,6 +161,8 @@ def get_jaxpr_type_signature(
if not isinstance(var, core.Literal) and var not in forwards:
source(var, True) # consumed unless in a Source.
for src in signature.sources:
if not 0 <= src.idx < len(eqn.outvars):
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
source(eqn.outvars[src.idx])
return KeyReuseSignatureWithForwards(
@ -232,7 +236,7 @@ def _cond_key_type_signature(eqn, args_consumed):
sources[source.idx].append(source.mask)
combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()]
combined_sources = [Source(i + 1, reduce(np.logical_and, m)) for i, m in sources.items()]
combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()]
combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in
set.intersection(*(set(sig.forwards) for sig in signatures))]
return KeyReuseSignatureWithForwards(combined_sinks, combined_sources, combined_forwards)

View File

@ -207,7 +207,7 @@ def _cond_key_type_signature(eqn, args_consumed):
sources[source.idx].append(source.mask)
combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()]
combined_sources = [Source(i + 1, reduce(np.logical_and, m)) for i, m in sources.items()]
combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()]
return KeyReuseSignature(combined_sinks, combined_sources)
key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature

View File

@ -253,6 +253,15 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase):
assert_consumed(key2)
self.check_key_reuse(f, jax.random.key(0))
def test_cond_source(self):
@jax.jit
def f(flag, key):
f1 = lambda seed, _: jax.random.key(seed)
f2 = lambda _, key: key
key_out = jax.lax.cond(flag, f1, f2, 0, key)
assert_unconsumed(key_out)
self.check_key_reuse(f, True, jax.random.key(0))
def test_cond_both_consumed(self):
@jax.jit
def f(flag, key):