[key reuse] print signature on failure

This commit is contained in:
Jake VanderPlas 2024-01-30 12:46:15 -08:00
parent 54ba49d333
commit a56e8e87e5
2 changed files with 6 additions and 2 deletions

View File

@ -156,7 +156,9 @@ def get_jaxpr_type_signature(
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
if sink(eqn.invars[snk.idx], snk.mask):
raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n"
f"eqn: {eqn}\njaxpr:\n{jaxpr}")
f" signature: {signature}\n"
f" eqn: {eqn}\n"
f" jaxpr:\n{jaxpr}")
for var in eqn.outvars:
if not isinstance(var, core.Literal) and var not in forwards:
source(var, True) # consumed unless in a Source.

View File

@ -122,7 +122,9 @@ def get_jaxpr_type_signature(
for snk in signature.sinks:
if sink(eqn.invars[snk.idx], snk.mask):
raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n"
f"eqn: {eqn}\njaxpr:\n{jaxpr}")
f" signature: {signature}\n"
f" eqn: {eqn}\n"
f" jaxpr:\n{jaxpr}")
for var in eqn.outvars:
if not isinstance(var, core.Literal):
source(var, True) # consumed unless in a Source.