mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19536 from jakevdp:key-reuse-cond
PiperOrigin-RevId: 601900128
This commit is contained in:
commit
c42305a0a9
@ -152,6 +152,8 @@ def get_jaxpr_type_signature(
|
||||
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
|
||||
|
||||
for snk in signature.sinks:
|
||||
if not 0 <= snk.idx < len(eqn.invars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
|
||||
if sink(eqn.invars[snk.idx], snk.mask):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n"
|
||||
f"eqn: {eqn}\njaxpr:\n{jaxpr}")
|
||||
@ -159,6 +161,8 @@ def get_jaxpr_type_signature(
|
||||
if not isinstance(var, core.Literal) and var not in forwards:
|
||||
source(var, True) # consumed unless in a Source.
|
||||
for src in signature.sources:
|
||||
if not 0 <= src.idx < len(eqn.outvars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
|
||||
source(eqn.outvars[src.idx])
|
||||
|
||||
return KeyReuseSignatureWithForwards(
|
||||
@ -232,7 +236,7 @@ def _cond_key_type_signature(eqn, args_consumed):
|
||||
sources[source.idx].append(source.mask)
|
||||
|
||||
combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()]
|
||||
combined_sources = [Source(i + 1, reduce(np.logical_and, m)) for i, m in sources.items()]
|
||||
combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()]
|
||||
combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in
|
||||
set.intersection(*(set(sig.forwards) for sig in signatures))]
|
||||
return KeyReuseSignatureWithForwards(combined_sinks, combined_sources, combined_forwards)
|
||||
|
@ -207,7 +207,7 @@ def _cond_key_type_signature(eqn, args_consumed):
|
||||
sources[source.idx].append(source.mask)
|
||||
|
||||
combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()]
|
||||
combined_sources = [Source(i + 1, reduce(np.logical_and, m)) for i, m in sources.items()]
|
||||
combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()]
|
||||
return KeyReuseSignature(combined_sinks, combined_sources)
|
||||
|
||||
key_reuse_signatures_dynamic[lax.cond_p] = _cond_key_type_signature
|
||||
|
@ -253,6 +253,15 @@ class KeyReuseUnitTestSimple(jtu.JaxTestCase):
|
||||
assert_consumed(key2)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
def test_cond_source(self):
|
||||
@jax.jit
|
||||
def f(flag, key):
|
||||
f1 = lambda seed, _: jax.random.key(seed)
|
||||
f2 = lambda _, key: key
|
||||
key_out = jax.lax.cond(flag, f1, f2, 0, key)
|
||||
assert_unconsumed(key_out)
|
||||
self.check_key_reuse(f, True, jax.random.key(0))
|
||||
|
||||
def test_cond_both_consumed(self):
|
||||
@jax.jit
|
||||
def f(flag, key):
|
||||
|
Loading…
x
Reference in New Issue
Block a user