Add arg_names to the aval mismatch error raised during AOT compilation to produce better error messages

PiperOrigin-RevId: 525561905
Yash Katariya 2023-04-19 15:08:21 -07:00 committed by jax authors
parent 968dbaf8f3
commit 53e6382f4a
6 changed files with 85 additions and 69 deletions
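For context (this is not part of the diff): a minimal sketch of how the improved error surfaces from the AOT APIs. The function f, the arrays, and the exact message text in the comments are illustrative assumptions, not taken from this commit.

import jax
import jax.numpy as jnp

def f(x, y):
  return x + y

x_f32 = jnp.arange(4, dtype=jnp.float32)
x_i32 = x_f32.astype(jnp.int32)

# AOT-compile for float32 inputs, then call with int32 inputs.
compiled = jax.jit(f).lower(x_f32, x_f32).compile()
try:
  compiled(x_i32, x_i32)
except TypeError as e:
  # With this change the message names each mismatched argument, roughly:
  #   Computation was compiled for different input types and called with
  #   different types. Here are the 2 mismatches:
  #   Compiled with float32[4] and called with int32[4] for arg x
  #   Compiled with float32[4] and called with int32[4] for arg y
  print(e)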


@@ -643,7 +643,7 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
result_paths = trace_debug.result_paths() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, result_paths)
trace_debug.arg_names, tuple(result_paths))
return jaxpr.replace(debug_info=debug_info)
def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],


@@ -898,7 +898,8 @@ def lower_parallel_callable(
shards=shards, tuple_args=tuple_args,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
keepalive=keepalive, host_callbacks=host_callbacks)
keepalive=keepalive, host_callbacks=host_callbacks,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
class PmapComputation(stages.XlaLowering):
@@ -951,6 +952,34 @@ class UnloadedPmapExecutable:
ordered_effects: List[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
jaxpr_debug_info: core.JaxprDebugInfo
def build_execute_fun(self):
input_indices = []
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
assert isinstance(spec, sharding_impls.PmapSharding), spec
assert isinstance(aval, core.ShapedArray), aval
input_indices.append(
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
if spec.sharding_spec is not None else None)
handle_outs = local_avals_to_results_handler(self.local_output_avals,
self.output_shardings)
handle_args = InputsHandler(self.compiled.local_devices(),
self.input_shardings, input_indices)
execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
self.backend, handle_args, handle_outs,
self.unordered_effects,
self.ordered_effects, self.keepalive,
bool(self.host_callbacks),
set(range(len(input_indices))))
return execute_fun
def load(self) -> PmapExecutable:
fingerprint = getattr(self.compiled, "fingerprint", None)
return PmapExecutable(
self.compiled, self.build_execute_fun, fingerprint,
self.local_input_avals, self.jaxpr_debug_info, self)
@staticmethod
def from_hlo(xla_computation,
@@ -962,6 +991,7 @@ class UnloadedPmapExecutable:
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo,
compiler_options=None):
devices = pci.devices
if devices is None:
@@ -1056,7 +1086,7 @@ class UnloadedPmapExecutable:
return _compile_replicated_pmap_executable_from_hlo(
xla_computation, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, bool(unordered_effects),
ordered_effects)
ordered_effects, jaxpr_debug_info)
with dispatch.log_elapsed_time(
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
@@ -1075,38 +1105,13 @@ class UnloadedPmapExecutable:
ordered_effects=ordered_effects,
keepalive=keepalive,
host_callbacks=host_callbacks,
).load()
def build_execute_fun(self):
input_indices = []
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
assert isinstance(spec, sharding_impls.PmapSharding), spec
assert isinstance(aval, core.ShapedArray), aval
input_indices.append(
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
if spec.sharding_spec is not None else None)
handle_outs = local_avals_to_results_handler(self.local_output_avals,
self.output_shardings)
handle_args = InputsHandler(self.compiled.local_devices(),
self.input_shardings, input_indices)
execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
self.backend, handle_args, handle_outs,
self.unordered_effects,
self.ordered_effects, self.keepalive,
bool(self.host_callbacks),
set(range(len(input_indices))))
return execute_fun
def load(self) -> PmapExecutable:
fingerprint = getattr(self.compiled, "fingerprint", None)
return PmapExecutable(self.compiled, self.build_execute_fun, fingerprint,
self.local_input_avals, self)
jaxpr_debug_info=jaxpr_debug_info).load()
def _compile_replicated_pmap_executable_from_hlo(
xla_computation, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, has_unordered_effects, ordered_effects):
compile_options, host_callbacks, has_unordered_effects, ordered_effects,
jaxpr_debug_info):
# Use the standard out_handler.
execute_fun = pci.backend.compile_replicated(
is_trivial=False, name=pci.name, computation=xla_computation,
@@ -1116,20 +1121,23 @@ def _compile_replicated_pmap_executable_from_hlo(
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, lambda: execute_fun, None, pci.avals, None)
return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
jaxpr_debug_info, None)
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
"fingerprint", "in_avals", "_unloaded_executable"]
"fingerprint", "in_avals", "_jaxpr_debug_info",
"_unloaded_executable"]
def __init__(self, xla_executable, build_unsafe_call, fingerprint,
in_avals, unloaded_executable):
in_avals, jaxpr_debug_info, unloaded_executable):
self.xla_executable = xla_executable
self._unsafe_call = None
self.build_unsafe_call = build_unsafe_call
self.fingerprint = fingerprint
self.in_avals = in_avals
self._jaxpr_debug_info = jaxpr_debug_info
self._unloaded_executable = unloaded_executable
@property
@@ -1147,7 +1155,7 @@ class PmapExecutable(stages.XlaExecutable):
def call(self, *args):
# TODO(frostig): do we need to check sharding and sharded avals?
arg_avals = map(xla.abstractify, args)
check_arg_avals_for_call(self.in_avals, arg_avals)
check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable
@@ -2655,7 +2663,7 @@ class UnloadedMeshExecutable:
semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
compile_options, tuple(host_callbacks), bool(unordered_effects),
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
pmap_nreps)
pmap_nreps, jaxpr_debug_info)
if auto_spmd_lowering:
assert mesh is not None
@@ -2794,7 +2802,7 @@ class MeshExecutable(stages.XlaExecutable):
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
arg_avals = map(xla.abstractify, kept_args)
ref_avals = self.in_avals
check_arg_avals_for_call(ref_avals, arg_avals)
check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
# Check the GDA sharding and the input sharding.
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
self._jaxpr_debug_info)
@@ -2832,18 +2840,28 @@ class MeshExecutable(stages.XlaExecutable):
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], []) # type: ignore
def check_arg_avals_for_call(ref_avals, arg_avals):
def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None):
if len(ref_avals) != len(arg_avals):
raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs "
f"but called with {len(arg_avals)}")
for ref_aval, arg_aval in zip(ref_avals, arg_avals):
arg_names = ([''] * len(ref_avals) if jaxpr_debug_info is None else
jaxpr_debug_info.arg_names)
errors = []
num_errors = 5
for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
if not core.typematch(ref_aval, arg_aval):
raise TypeError(
errors.append(f"Compiled with {ref_aval} and called with {arg_aval} for "
f"arg {name}")
if errors:
str_errors = '\n'.join(errors[:num_errors])
num_mismatch_str = (
f'the {len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
raise TypeError(
"Computation was compiled for different input types and called with "
"different types. One of the mismatches is:\n"
f"Compiled with:\n {ref_aval}\n"
f"called with:\n {arg_aval}")
f"different types. Here are {num_mismatch_str}:\n{str_errors}")
def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
@@ -2905,7 +2923,7 @@ def _compile_replicated_mesh_executable_from_hlo(
computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
semantics_out_shardings, auto_spmd_lowering, compile_options,
host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
backend, da, committed, pmap_nreps):
backend, da, committed, pmap_nreps, jaxpr_debug_info):
assert not auto_spmd_lowering
assert isinstance(da, _DeviceAssignment)
in_shardings = semantics_in_shardings.shardings
@@ -2930,7 +2948,7 @@ def _compile_replicated_mesh_executable_from_hlo(
xla_executable = None
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, None)
kept_var_idx, jaxpr_debug_info, None)
def _compile_replicated_mesh_executable_from_trivial_jaxpr(
@@ -2998,12 +3016,13 @@ def check_gda_or_array_xla_sharding_match(
f"arg {name} with shape: {arg.aval.str_short()}")
if errors:
str_errors = '\n'.join(errors)
num_mismatch_str = (f'{len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
str_errors = '\n'.join(errors[:num_errors])
num_mismatch_str = (
f'the {len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
raise ValueError(
"Array(s) sharding does not match the input(s) sharding. "
f"Here are the {num_mismatch_str}:\n{str_errors}")
f"Here are {num_mismatch_str}:\n{str_errors}")
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
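As a standalone sketch of the new reporting logic in check_arg_avals_for_call (the helper name _format_aval_mismatches and the string-valued avals are assumptions for illustration, not code from this commit): every mismatch is collected, at most five are spelled out, and the count phrasing switches once that cap is hit.

from typing import Optional, Sequence

def _format_aval_mismatches(ref_avals: Sequence[str],
                            arg_avals: Sequence[str],
                            arg_names: Optional[Sequence[str]] = None) -> Optional[str]:
  # Fall back to blank names when no debug info is available.
  names = arg_names if arg_names is not None else [''] * len(ref_avals)
  errors = [f"Compiled with {ref} and called with {arg} for arg {name}"
            for ref, arg, name in zip(ref_avals, arg_avals, names)
            if ref != arg]
  if not errors:
    return None
  num_errors = 5  # cap on how many mismatches are spelled out
  shown = '\n'.join(errors[:num_errors])
  count = (f"the {len(errors)} mismatches" if len(errors) < num_errors
           else f"{num_errors} mismatches out of {len(errors)}")
  return ("Computation was compiled for different input types and called "
          f"with different types. Here are {count}:\n{shown}")

# Example: two of the three arguments mismatch.
print(_format_aval_mismatches(["float32[4]", "float32[4]", "int32[]"],
                              ["int32[4]", "float32[4]", "float32[]"],
                              ["x", "y", "z"]))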


@@ -992,10 +992,9 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f_exe = self.jit(f).lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
"Computation was compiled for different input types and called with "
"different types. One of the mismatches is:\n"
"Compiled with:\n.*float32.*\n"
"called with:\n.*int32.*",
r"Computation was compiled for different input types and called with "
r"different types. Here are the 1 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*for arg x",
lambda: f_exe(x_i32))
def test_jit_lower_compile_multi_arg(self):


@@ -930,13 +930,13 @@ class PJitTest(jtu.BufferDonationTestCase):
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
exe = f.lower(x_f32, x_f32).compile()
self.assertRaisesRegex(
with self.assertRaisesRegex(
TypeError,
"Computation was compiled for different input types and called with "
"different types. One of the mismatches is:\n"
"Compiled with:\n.*float32.*\n"
"called with:\n.*int32.*",
lambda: exe(x_i32, x_i32))
r"Computation was compiled for different input types and called with "
r"different types. Here are the 2 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*for arg x\n"
r"Compiled with.*float32.*and called with.*int32.*for arg y"):
exe(x_i32, x_i32)
@jtu.with_mesh([('x', 2), ('y', 2)])
def testLowerAsText(self):
@@ -1541,7 +1541,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
r"Array\(s\) sharding does not match the input\(s\) sharding. "
"Here are the 5 mismatches out of 6"):
"Here are 5 mismatches out of 6"):
compiled(a2, a2, a2, a2, a2, a2)
with global_mesh:


@@ -240,10 +240,9 @@ class PythonPmapTest(jtu.JaxTestCase):
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
"Computation was compiled for different input types and called with "
"different types. One of the mismatches is:\n"
"Compiled with:\n.*float32.*\n"
"called with:\n.*int32.*",
r"Computation was compiled for different input types and called with "
r"different types. Here are the 1 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*for arg x",
lambda: f_exe(x_i32))
def testLowerCompileMultiArg(self):


@@ -707,11 +707,10 @@ class XMapTest(XMapTestCase):
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
"Computation was compiled for different input types and called with "
"different types. One of the mismatches is:\n"
"Compiled with:\n.*float32.*\n"
"called with:\n.*int32.*",
lambda: f_exe(x_i32))
r"Computation was compiled for different input types and called with "
r"different types. Here are the 1 mismatches:\n"
r"Compiled with.*float32.*and called with.*int32.*",
lambda: f_exe(x_i32))
def testLowerAsText(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])