mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[key reuse] improve error message using source_info_util
This commit is contained in:
parent
32fec820ed
commit
735ec63dd1
@ -29,6 +29,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.ad_checkpoint import remat_p
|
||||
from jax._src.debugging import debug_callback_p
|
||||
@ -232,40 +233,35 @@ def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
|
||||
return consumed.get(var, False)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive in key_reuse_signatures:
|
||||
signature = key_reuse_signatures[eqn.primitive]
|
||||
elif eqn.primitive in key_reuse_signatures_dynamic:
|
||||
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn)
|
||||
else:
|
||||
signature = unknown_signature(eqn)
|
||||
traceback = eqn.source_info.traceback
|
||||
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
with source_info_util.user_context(traceback, name_stack=name_stack):
|
||||
if eqn.primitive in key_reuse_signatures:
|
||||
signature = key_reuse_signatures[eqn.primitive]
|
||||
elif eqn.primitive in key_reuse_signatures_dynamic:
|
||||
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn)
|
||||
else:
|
||||
signature = unknown_signature(eqn)
|
||||
|
||||
if eqn.primitive == assert_consumed_value_p:
|
||||
# This is a special case that goes beyond normal key reuse logic.
|
||||
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
|
||||
if eqn.primitive == assert_consumed_value_p:
|
||||
# This is a special case that goes beyond normal key reuse logic.
|
||||
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
|
||||
|
||||
for in_idx, out_idx in signature.forwards:
|
||||
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
|
||||
for in_idx, out_idx in signature.forwards:
|
||||
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
|
||||
|
||||
for snk in signature.sinks:
|
||||
if not 0 <= snk.idx < len(eqn.invars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
|
||||
if sink(eqn.invars[snk.idx], snk.mask):
|
||||
context = core.JaxprPpContext()
|
||||
settings = core.JaxprPpSettings()
|
||||
jaxpr_str = core.pp_jaxpr(jaxpr, context, settings).format()
|
||||
eqn_str = core.pp_eqn(eqn, context, settings).format()
|
||||
key_vals_str = core.pp_var(eqn.invars[snk.idx], context).format()
|
||||
raise KeyReuseError(f"In {eqn.primitive}, key values {key_vals_str} are already consumed.\n"
|
||||
f" signature: {signature}\n"
|
||||
f" eqn: {eqn_str}\n"
|
||||
f" jaxpr:\n{jaxpr_str}")
|
||||
for var in eqn.outvars:
|
||||
if not isinstance(var, core.Literal) and var not in forwards:
|
||||
source(var, True) # consumed unless in a Source.
|
||||
for src in signature.sources:
|
||||
if not 0 <= src.idx < len(eqn.outvars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
|
||||
source(eqn.outvars[src.idx])
|
||||
for snk in signature.sinks:
|
||||
if not 0 <= snk.idx < len(eqn.invars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
|
||||
if sink(eqn.invars[snk.idx], snk.mask):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, argument {snk.idx} is already consumed.")
|
||||
for var in eqn.outvars:
|
||||
if not isinstance(var, core.Literal) and var not in forwards:
|
||||
source(var, True) # consumed unless in a Source.
|
||||
for src in signature.sources:
|
||||
if not 0 <= src.idx < len(eqn.outvars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
|
||||
source(eqn.outvars[src.idx])
|
||||
|
||||
return KeyReuseSignature(
|
||||
sinks=[Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars)
|
||||
|
@ -334,10 +334,10 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=False)
|
||||
class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
random_bits_error = "In random_bits, key values .+ are already consumed.*"
|
||||
random_split_error = "In random_split, key values .+ are already consumed.*"
|
||||
generic_error = ".*key values .+ are already consumed.*"
|
||||
pjit_error = "In pjit, key values a are already consumed."
|
||||
random_bits_error = "In random_bits, argument [0-9]+ is already consumed.*"
|
||||
random_split_error = "In random_split, argument [0-9]+ is already consumed.*"
|
||||
generic_error = ".*argument [0-9]+ is already consumed.*"
|
||||
pjit_error = "In pjit, argument 0 is already consumed."
|
||||
|
||||
def check_key_reuse(self, f, *args):
|
||||
return _core.check_key_reuse(f, *args)
|
||||
@ -589,7 +589,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
class KeyReuseEager(jtu.JaxTestCase):
|
||||
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
|
||||
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
|
||||
traced_bits_msg = "In random_bits, key values a are already consumed."
|
||||
traced_bits_msg = "In random_bits, argument 0 is already consumed."
|
||||
|
||||
def test_simple_reuse_nojit(self):
|
||||
key = jax.random.key(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user