[key reuse] improve error message using source_info_util

This commit is contained in:
Jake VanderPlas 2024-03-05 11:02:39 -08:00
parent 32fec820ed
commit 735ec63dd1
2 changed files with 32 additions and 36 deletions

View File

@ -29,6 +29,7 @@ from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
from jax._src import random
from jax._src import source_info_util
from jax._src import util
from jax._src.ad_checkpoint import remat_p
from jax._src.debugging import debug_callback_p
@ -232,40 +233,35 @@ def get_jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
return consumed.get(var, False)
for eqn in jaxpr.eqns:
if eqn.primitive in key_reuse_signatures:
signature = key_reuse_signatures[eqn.primitive]
elif eqn.primitive in key_reuse_signatures_dynamic:
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn)
else:
signature = unknown_signature(eqn)
traceback = eqn.source_info.traceback
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
with source_info_util.user_context(traceback, name_stack=name_stack):
if eqn.primitive in key_reuse_signatures:
signature = key_reuse_signatures[eqn.primitive]
elif eqn.primitive in key_reuse_signatures_dynamic:
signature = key_reuse_signatures_dynamic[eqn.primitive](eqn)
else:
signature = unknown_signature(eqn)
if eqn.primitive == assert_consumed_value_p:
# This is a special case that goes beyond normal key reuse logic.
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
if eqn.primitive == assert_consumed_value_p:
# This is a special case that goes beyond normal key reuse logic.
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
for in_idx, out_idx in signature.forwards:
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
for in_idx, out_idx in signature.forwards:
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
for snk in signature.sinks:
if not 0 <= snk.idx < len(eqn.invars):
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
if sink(eqn.invars[snk.idx], snk.mask):
context = core.JaxprPpContext()
settings = core.JaxprPpSettings()
jaxpr_str = core.pp_jaxpr(jaxpr, context, settings).format()
eqn_str = core.pp_eqn(eqn, context, settings).format()
key_vals_str = core.pp_var(eqn.invars[snk.idx], context).format()
raise KeyReuseError(f"In {eqn.primitive}, key values {key_vals_str} are already consumed.\n"
f" signature: {signature}\n"
f" eqn: {eqn_str}\n"
f" jaxpr:\n{jaxpr_str}")
for var in eqn.outvars:
if not isinstance(var, core.Literal) and var not in forwards:
source(var, True) # consumed unless in a Source.
for src in signature.sources:
if not 0 <= src.idx < len(eqn.outvars):
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
source(eqn.outvars[src.idx])
for snk in signature.sinks:
if not 0 <= snk.idx < len(eqn.invars):
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
if sink(eqn.invars[snk.idx], snk.mask):
raise KeyReuseError(f"In {eqn.primitive}, argument {snk.idx} is already consumed.")
for var in eqn.outvars:
if not isinstance(var, core.Literal) and var not in forwards:
source(var, True) # consumed unless in a Source.
for src in signature.sources:
if not 0 <= src.idx < len(eqn.outvars):
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
source(eqn.outvars[src.idx])
return KeyReuseSignature(
sinks=[Sink(i, consumed[v]) for i, v in enumerate(jaxpr.invars)

View File

@ -334,10 +334,10 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
@jtu.with_config(jax_enable_key_reuse_checks=False)
class KeyReuseIntegrationTest(jtu.JaxTestCase):
random_bits_error = "In random_bits, key values .+ are already consumed.*"
random_split_error = "In random_split, key values .+ are already consumed.*"
generic_error = ".*key values .+ are already consumed.*"
pjit_error = "In pjit, key values a are already consumed."
random_bits_error = "In random_bits, argument [0-9]+ is already consumed.*"
random_split_error = "In random_split, argument [0-9]+ is already consumed.*"
generic_error = ".*argument [0-9]+ is already consumed.*"
pjit_error = "In pjit, argument 0 is already consumed."
def check_key_reuse(self, f, *args):
return _core.check_key_reuse(f, *args)
@ -589,7 +589,7 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
class KeyReuseEager(jtu.JaxTestCase):
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
traced_bits_msg = "In random_bits, key values a are already consumed."
traced_bits_msg = "In random_bits, argument 0 is already consumed."
def test_simple_reuse_nojit(self):
key = jax.random.key(0)