Merge pull request #9143 from mattjj:fix-jaxpr-checking-error-messages

PiperOrigin-RevId: 420866862
This commit is contained in:
jax authors 2022-01-10 15:09:04 -08:00
commit 67723da38b
3 changed files with 27 additions and 24 deletions

View File

@ -884,7 +884,7 @@ class AbstractValue:
raise NotImplementedError("must override")
def str_short(self, short_dtypes=False):
raise NotImplementedError("must override")
return str(self)
class Bot(AbstractValue): pass
@ -1999,34 +1999,38 @@ def check_jaxpr(jaxpr: Jaxpr):
Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
otherwise.
"""
ctx = JaxprPpContext()
try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names
except: pass
try:
_check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
_check_jaxpr(ctx, jaxpr, [v.aval for v in jaxpr.invars])
except JaxprTypeError as e:
if len(e.args) == 2:
msg, eqn_idx = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqn_idx - 10, eqn_idx + 10,
JaxprPpContext()))
msg, eqnidx = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx))
else:
msg, = e.args
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, JaxprPpContext()))
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, 0, 20, ctx))
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
raise JaxprTypeError(msg) from None
def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
def _check_jaxpr(ctx: 'JaxprPpContext', jaxpr: Jaxpr,
in_avals: Sequence[AbstractValue]) -> None:
def read(v: Atom) -> AbstractValue:
if isinstance(v, Literal):
return raise_to_shaped(get_aval(v.val))
else:
typecheck_assert(v in env, f"Variable '{v}' not defined")
typecheck_assert(v in env, f"Variable '{pp_var(v, ctx)}' not defined")
return env[v]
def write(v: Var, a: AbstractValue) -> None:
typecheck_assert(v not in env, f"Variable '{v}' already bound")
typecheck_assert(v not in env, f"Variable '{pp_var(v, ctx)}' already bound")
if not isinstance(v, DropVar):
typecheck_assert(typecompat(v.aval, a),
f"Variable '{v}' inconsistently typed as {a}, "
f"bound as {v.aval}")
f"Variable '{pp_var(v, ctx)}' inconsistently typed as "
f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}")
env[v] = a
env : Dict[Var, AbstractValue] = {}
@ -2046,17 +2050,16 @@ def _check_jaxpr(jaxpr: Jaxpr, in_avals: Sequence[AbstractValue]):
if out_avals is None:
out_avals = [v.aval for v in eqn.outvars]
elif prim.call_primitive:
out_avals = check_call(prim, in_avals, eqn.params)
out_avals = check_call(ctx, prim, in_avals, eqn.params)
elif prim.map_primitive:
out_avals = check_map(prim, in_avals, eqn.params)
out_avals = check_map(ctx, prim, in_avals, eqn.params)
else:
out_avals = check_eqn(prim, in_avals, eqn.params)
map(write, eqn.outvars, out_avals)
except JaxprTypeError as e:
msg, = e.args
src = source_info_util.summarize(eqn.source_info)
msg = "\n\n".join([msg, "in equation:",
str(pp.nest(2, pp_eqn(eqn, JaxprPpContext()))),
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx))),
f"from source: {src}"])
raise JaxprTypeError(msg, eqn_idx) from None
@ -2071,7 +2074,7 @@ def check_eqn(prim, in_avals, params):
out_avals = [out_avals]
return out_avals
def check_call(prim, in_avals, params):
def check_call(ctx, prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Call primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
@ -2087,12 +2090,12 @@ def check_call(prim, in_avals, params):
f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")
_check_jaxpr(call_jaxpr, in_avals)
_check_jaxpr(ctx, call_jaxpr, in_avals)
out_avals = [v.aval for v in call_jaxpr.outvars]
return out_avals
def check_map(prim, in_avals, params):
def check_map(ctx, prim, in_avals, params):
typecheck_assert("call_jaxpr" in params,
f"Map primitive {prim} missing 'call_jaxpr' parameter")
call_jaxpr = params["call_jaxpr"]
@ -2121,7 +2124,7 @@ def check_map(prim, in_avals, params):
if in_axis is not None else aval
for aval, in_axis in zip(in_avals, in_axes)]
with extend_axis_env(params['axis_name'], axis_size, None):
_check_jaxpr(call_jaxpr, mapped_avals)
_check_jaxpr(ctx, call_jaxpr, mapped_avals)
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval

View File

@ -958,7 +958,7 @@ def _typecheck_xmap(
mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
for a, a_in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(global_axis_sizes.items()):
core._check_jaxpr(call_jaxpr, mapped_in_avals)
core._check_jaxpr(core.JaxprPpContext(), call_jaxpr, mapped_in_avals)
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)

View File

@ -431,8 +431,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(jnp.int32(2))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.' inconsistently typed as ShapedArray(.*), "
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:i32\[\] = sin .",
r"Variable 'b' inconsistently typed as f32\[\], "
r"bound as i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a",
lambda: core.check_jaxpr(jaxpr))
jaxpr = new_jaxpr()
@ -440,8 +440,8 @@ class JaxprTypeChecks(jtu.JaxTestCase):
np.ones((2, 3), dtype=jnp.float32))
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.' inconsistently typed as ShapedArray(.*), "
r"bound as ShapedArray(.*)\n\nin equation:\n\n.:f32\[2,3\] = sin .",
r"Variable 'b' inconsistently typed as f32\[\], "
r"bound as f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a",
lambda: core.check_jaxpr(jaxpr))
def test_jaxpr_dropvar_from_jit_call(self):