fix jaxpr type checking error messages

The pretty-printing changes a few months ago defined variable names
based on the state in JaxprPpContext instances. But that meant incorrect
variable names could be printed in jaxpr type checking error messages.

This commit correctly threads through the context so as to provide
error messages with coherent variable names.
This commit is contained in:
Matthew Johnson 2022-01-08 22:10:18 -08:00
parent 933971dbff
commit 3548e023ec
3 changed files with 27 additions and 24 deletions

View File

@ -920,7 +920,7 @@ class AbstractValue:
raise NotImplementedError("must override")
def str_short(self, short_dtypes=False):
raise NotImplementedError("must override")
return str(self)
class Bot(AbstractValue): pass
@ -2022,34 +2022,38 @@ def check_jaxpr(jaxpr: Jaxpr):
Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
otherwise.
"""
ctx = JaxprPpContext()
try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names
except: pass
try:
_check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
_check_jaxpr(ctx, jaxpr, [v.aval for v in jaxpr.invars])
except JaxprTypeError as e:
if len(e.args) == 2:
msg, eqn_idx = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqn_idx - 10, eqn_idx + 10,
JaxprPpContext()))
msg, eqnidx = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx))
else:
msg, = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, JaxprPpContext()))
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, ctx))
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
raise JaxprTypeError(msg) from None
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
def _check_jaxpr(ctx: 'JaxprPpContext', jaxpr: Jaxpr,
in_avals: Sequence[AbstractValue]) -> None:
def read(v: Atom) -> AbstractValue:
if isinstance(v, Literal):
return raise_to_shaped(get_aval(v.val))
else:
typecheck_assert(v in env, f"Variable '{v}' not defined")
typecheck_assert(v in env, f"Variable '{pp_var(v, ctx)}' not defined")
return env[v]
def write(v: Var, a: AbstractValue) -> None:
typecheck_assert(v not in env, f"Variable '{v}' already bound")
typecheck_assert(v not in env, f"Variable '{pp_var(v, ctx)}' already bound")
if not isinstance(v, DropVar):
typecheck_assert(typecompat(v.aval, a),
f"Variable '{v}' inconsistently typed as {a}, "
f"bound as {v.aval}")
f"Variable '{pp_var(v, ctx)}' inconsistently typed as "
f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}")
env[v] = a
env : Dict[Var, AbstractValue] = {}
@ -2069,17 +2073,16 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
if out_avals is None:
out_avals = [v.aval for v in eqn.outvars]
elif prim.call_primitive:
out_avals = check_call(prim, in_avals, eqn.params)
out_avals = check_call(ctx, prim, in_avals, eqn.params)
elif prim.map_primitive:
out_avals = check_map(prim, in_avals, eqn.params)
out_avals = check_map(ctx, prim, in_avals, eqn.params)
else:
out_avals = check_eqn(prim, in_avals, eqn.params)
map(write, eqn.outvars, out_avals)
except JaxprTypeError as e:
msg, = e.args
src = source_info_util.summarize(eqn.source_info)
msg = "\n\n".join([msg, "in equation:",
str(pp.nest(2, pp_eqn(eqn, JaxprPpContext()))),
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx))),
f"from source: {src}"])
raise JaxprTypeError(msg, eqn_idx) from None
@ -2094,7 +2097,7 @@ def check_eqn(prim, in_avals, params):
out_avals = [out_avals]
return out_avals
def check_call(prim, in_avals, params):
def check_call(ctx, prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Call primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
@ -2110,12 +2113,12 @@ def check_call(prim, in_avals, params):
f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")
_check_jaxpr(call_jaxpr, in_avals)
_check_jaxpr(ctx, call_jaxpr, in_avals)
out_avals = [v.aval for v in call_jaxpr.outvars]
return out_avals
def check_map(prim, in_avals, params):
def check_map(ctx, prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Map primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
@ -2144,7 +2147,7 @@ def check_map(prim, in_avals, params):
if in_axis is not None else aval
for aval, in_axis in zip(in_avals, in_axes)]
with extend_axis_env(params['axis_name'], axis_size, None):
_check_jaxpr(call_jaxpr, mapped_avals)
_check_jaxpr(ctx, call_jaxpr, mapped_avals)
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval

View File

@ -945,7 +945,7 @@ def _typecheck_xmap(
mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
for a, a_in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(global_axis_sizes.items()):
core._check_jaxpr(call_jaxpr, mapped_in_avals)
core._check_jaxpr(core.JaxprPpContext(), call_jaxpr, mapped_in_avals)
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)

View File

@ -431,8 +431,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.' inconsistently typed as ShapedArray(.*), "
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:i32\[\] = sin .",
r"Variable 'b' inconsistently typed as f32\[\], "
r"bound as i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
lambda: core.check_jaxpr(jaxpr))
jaxpr = new_jaxpr()
@ -440,8 +440,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
np.ones((2, 3), dtype=jnp.float32))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.' inconsistently typed as ShapedArray(.*), "
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:f32\[2,3\] = sin .",
r"Variable 'b' inconsistently typed as f32\[\], "
r"bound as f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
lambda: core.check_jaxpr(jaxpr))
def test_jaxpr_dropvar_from_jit_call(self):