mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix jaxpr type checking error messages
The pretty-printing changes a few months ago defined variable names based on the state in JaxprPpContext instances. But that meant incorrect variable names could be printed in jaxpr type checking error messages. This commit correctly threads through the context so as to provide error messages with coherent variable names.
This commit is contained in:
parent
933971dbff
commit
3548e023ec
41
jax/core.py
41
jax/core.py
@ -920,7 +920,7 @@ class AbstractValue:
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def str_short(self, short_dtypes=False):
|
||||
raise NotImplementedError("must override")
|
||||
return str(self)
|
||||
|
||||
class Bot(AbstractValue): pass
|
||||
|
||||
@ -2022,34 +2022,38 @@ def check_jaxpr(jaxpr: Jaxpr):
|
||||
Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
|
||||
otherwise.
|
||||
"""
|
||||
ctx = JaxprPpContext()
|
||||
try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names
|
||||
except: pass
|
||||
|
||||
try:
|
||||
_check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
|
||||
_check_jaxpr(ctx, jaxpr, [v.aval for v in jaxpr.invars])
|
||||
except JaxprTypeError as e:
|
||||
if len(e.args) == 2:
|
||||
msg, eqn_idx = e.args
|
||||
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqn_idx - 10, eqn_idx + 10,
|
||||
JaxprPpContext()))
|
||||
msg, eqnidx = e.args
|
||||
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx))
|
||||
else:
|
||||
msg, = e.args
|
||||
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, JaxprPpContext()))
|
||||
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, ctx))
|
||||
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
|
||||
raise JaxprTypeError(msg) from None
|
||||
|
||||
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
|
||||
def _check_jaxpr(ctx: 'JaxprPpContext', jaxpr: Jaxpr,
|
||||
in_avals: Sequence[AbstractValue]) -> None:
|
||||
|
||||
def read(v: Atom) -> AbstractValue:
|
||||
if isinstance(v, Literal):
|
||||
return raise_to_shaped(get_aval(v.val))
|
||||
else:
|
||||
typecheck_assert(v in env, f"Variable '{v}' not defined")
|
||||
typecheck_assert(v in env, f"Variable '{pp_var(v, ctx)}' not defined")
|
||||
return env[v]
|
||||
|
||||
def write(v: Var, a: AbstractValue) -> None:
|
||||
typecheck_assert(v not in env, f"Variable '{v}' already bound")
|
||||
typecheck_assert(v not in env, f"Variable '{pp_var(v, ctx)}' already bound")
|
||||
if not isinstance(v, DropVar):
|
||||
typecheck_assert(typecompat(v.aval, a),
|
||||
f"Variable '{v}' inconsistently typed as {a}, "
|
||||
f"bound as {v.aval}")
|
||||
f"Variable '{pp_var(v, ctx)}' inconsistently typed as "
|
||||
f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}")
|
||||
env[v] = a
|
||||
|
||||
env : Dict[Var, AbstractValue] = {}
|
||||
@ -2069,17 +2073,16 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
|
||||
if out_avals is None:
|
||||
out_avals = [v.aval for v in eqn.outvars]
|
||||
elif prim.call_primitive:
|
||||
out_avals = check_call(prim, in_avals, eqn.params)
|
||||
out_avals = check_call(ctx, prim, in_avals, eqn.params)
|
||||
elif prim.map_primitive:
|
||||
out_avals = check_map(prim, in_avals, eqn.params)
|
||||
out_avals = check_map(ctx, prim, in_avals, eqn.params)
|
||||
else:
|
||||
out_avals = check_eqn(prim, in_avals, eqn.params)
|
||||
map(write, eqn.outvars, out_avals)
|
||||
except JaxprTypeError as e:
|
||||
msg, = e.args
|
||||
src = source_info_util.summarize(eqn.source_info)
|
||||
msg = "\n\n".join([msg, "in equation:",
|
||||
str(pp.nest(2, pp_eqn(eqn, JaxprPpContext()))),
|
||||
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx))),
|
||||
f"from source: {src}"])
|
||||
raise JaxprTypeError(msg, eqn_idx) from None
|
||||
|
||||
@ -2094,7 +2097,7 @@ def check_eqn(prim, in_avals, params):
|
||||
out_avals = [out_avals]
|
||||
return out_avals
|
||||
|
||||
def check_call(prim, in_avals, params):
|
||||
def check_call(ctx, prim, in_avals, params):
|
||||
typecheck_assert("call_jaxpr" in params,
|
||||
f"Call primitive {prim} missing 'call_jaxpr' parameter")
|
||||
call_jaxpr = params["call_jaxpr"]
|
||||
@ -2110,12 +2113,12 @@ def check_call(prim, in_avals, params):
|
||||
f"Call primitive {prim} passes operand {in_aval} "
|
||||
f"to jaxpr expecting {binder_aval}")
|
||||
|
||||
_check_jaxpr(call_jaxpr, in_avals)
|
||||
_check_jaxpr(ctx, call_jaxpr, in_avals)
|
||||
|
||||
out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
return out_avals
|
||||
|
||||
def check_map(prim, in_avals, params):
|
||||
def check_map(ctx, prim, in_avals, params):
|
||||
typecheck_assert("call_jaxpr" in params,
|
||||
f"Map primitive {prim} missing 'call_jaxpr' parameter")
|
||||
call_jaxpr = params["call_jaxpr"]
|
||||
@ -2144,7 +2147,7 @@ def check_map(prim, in_avals, params):
|
||||
if in_axis is not None else aval
|
||||
for aval, in_axis in zip(in_avals, in_axes)]
|
||||
with extend_axis_env(params['axis_name'], axis_size, None):
|
||||
_check_jaxpr(call_jaxpr, mapped_avals)
|
||||
_check_jaxpr(ctx, call_jaxpr, mapped_avals)
|
||||
|
||||
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval
|
||||
|
@ -945,7 +945,7 @@ def _typecheck_xmap(
|
||||
mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
|
||||
for a, a_in_axes in zip(in_avals, in_axes)]
|
||||
with core.extend_axis_env_nd(global_axis_sizes.items()):
|
||||
core._check_jaxpr(call_jaxpr, mapped_in_avals)
|
||||
core._check_jaxpr(core.JaxprPpContext(), call_jaxpr, mapped_in_avals)
|
||||
|
||||
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
|
||||
|
@ -431,8 +431,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
|
||||
self.assertRaisesRegex(
|
||||
core.JaxprTypeError,
|
||||
r"Variable '.' inconsistently typed as ShapedArray(.*), "
|
||||
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:i32\[\] = sin .",
|
||||
r"Variable 'b' inconsistently typed as f32\[\], "
|
||||
r"bound as i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
|
||||
lambda: core.check_jaxpr(jaxpr))
|
||||
|
||||
jaxpr = new_jaxpr()
|
||||
@ -440,8 +440,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
np.ones((2, 3), dtype=jnp.float32))
|
||||
self.assertRaisesRegex(
|
||||
core.JaxprTypeError,
|
||||
r"Variable '.' inconsistently typed as ShapedArray(.*), "
|
||||
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:f32\[2,3\] = sin .",
|
||||
r"Variable 'b' inconsistently typed as f32\[\], "
|
||||
r"bound as f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
|
||||
lambda: core.check_jaxpr(jaxpr))
|
||||
|
||||
def test_jaxpr_dropvar_from_jit_call(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user