mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[key reuse] print signature on failure
This commit is contained in:
parent
54ba49d333
commit
a56e8e87e5
@ -156,7 +156,9 @@ def get_jaxpr_type_signature(
|
||||
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
|
||||
if sink(eqn.invars[snk.idx], snk.mask):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n"
|
||||
f"eqn: {eqn}\njaxpr:\n{jaxpr}")
|
||||
f" signature: {signature}\n"
|
||||
f" eqn: {eqn}\n"
|
||||
f" jaxpr:\n{jaxpr}")
|
||||
for var in eqn.outvars:
|
||||
if not isinstance(var, core.Literal) and var not in forwards:
|
||||
source(var, True) # consumed unless in a Source.
|
||||
|
@ -122,7 +122,9 @@ def get_jaxpr_type_signature(
|
||||
for snk in signature.sinks:
|
||||
if sink(eqn.invars[snk.idx], snk.mask):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, key values {eqn.invars[snk.idx]} are already consumed.\n"
|
||||
f"eqn: {eqn}\njaxpr:\n{jaxpr}")
|
||||
f" signature: {signature}\n"
|
||||
f" eqn: {eqn}\n"
|
||||
f" jaxpr:\n{jaxpr}")
|
||||
for var in eqn.outvars:
|
||||
if not isinstance(var, core.Literal):
|
||||
source(var, True) # consumed unless in a Source.
|
||||
|
Loading…
x
Reference in New Issue
Block a user