From 53e6382f4a167317840246069a84213acfc66df7 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 19 Apr 2023 15:08:21 -0700
Subject: [PATCH] Add arg_names to the aval-mismatch error raised during AOT
 compilation, for better error messages

PiperOrigin-RevId: 525561905
---
 jax/_src/api_util.py          |   2 +-
 jax/_src/interpreters/pxla.py | 115 ++++++++++++++++++++--------------
 tests/api_test.py             |   7 +--
 tests/pjit_test.py            |  14 ++---
 tests/pmap_test.py            |   7 +--
 tests/xmap_test.py            |   9 ++-
 6 files changed, 85 insertions(+), 69 deletions(-)

diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 375bf3922..dbdba2db5 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -643,7 +643,7 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
     result_paths = trace_debug.result_paths()  # type: ignore
   debug_info = core.JaxprDebugInfo(
       trace_debug.traced_for, trace_debug.func_src_info,
-      trace_debug.arg_names, result_paths)
+      trace_debug.arg_names, tuple(result_paths))
   return jaxpr.replace(debug_info=debug_info)
 
 def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 28efdd8b6..b56343f57 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -898,7 +898,8 @@ def lower_parallel_callable(
                          shards=shards, tuple_args=tuple_args,
                          unordered_effects=unordered_effects,
                          ordered_effects=ordered_effects,
-                         keepalive=keepalive, host_callbacks=host_callbacks)
+                         keepalive=keepalive, host_callbacks=host_callbacks,
+                         jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
 
 
 class PmapComputation(stages.XlaLowering):
@@ -951,6 +952,34 @@ class UnloadedPmapExecutable:
   ordered_effects: List[core.Effect]
   keepalive: Sequence[Any]
   host_callbacks: Sequence[Any]
+  jaxpr_debug_info: core.JaxprDebugInfo
+
+  def build_execute_fun(self):
+    input_indices = []
+    for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
+      assert isinstance(spec, sharding_impls.PmapSharding), spec
+      assert isinstance(aval, core.ShapedArray), aval
+      input_indices.append(
+          sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
+          if spec.sharding_spec is not None else None)
+    handle_outs = local_avals_to_results_handler(self.local_output_avals,
+                                                 self.output_shardings)
+    handle_args = InputsHandler(self.compiled.local_devices(),
+                                self.input_shardings, input_indices)
+    execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
+                                    self.backend, handle_args, handle_outs,
+                                    self.unordered_effects,
+                                    self.ordered_effects, self.keepalive,
+                                    bool(self.host_callbacks),
+                                    set(range(len(input_indices))))
+    return execute_fun
+
+  def load(self) -> PmapExecutable:
+    fingerprint = getattr(self.compiled, "fingerprint", None)
+
+    return PmapExecutable(
+        self.compiled, self.build_execute_fun, fingerprint,
+        self.local_input_avals, self.jaxpr_debug_info, self)
 
   @staticmethod
   def from_hlo(xla_computation,
@@ -962,6 +991,7 @@ class UnloadedPmapExecutable:
                ordered_effects: List[core.Effect],
                host_callbacks: List[Any],
                keepalive: Any,
+               jaxpr_debug_info: core.JaxprDebugInfo,
                compiler_options=None):
     devices = pci.devices
     if devices is None:
@@ -1056,7 +1086,7 @@ class UnloadedPmapExecutable:
       return _compile_replicated_pmap_executable_from_hlo(
          xla_computation, pci, input_indices, in_shardings, handle_outs,
          compile_options, host_callbacks, bool(unordered_effects),
-          ordered_effects)
+          ordered_effects, jaxpr_debug_info)
 
     with dispatch.log_elapsed_time(
         f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
@@ -1075,38 +1105,13 @@ class UnloadedPmapExecutable:
         ordered_effects=ordered_effects,
         keepalive=keepalive,
         host_callbacks=host_callbacks,
-    ).load()
-
-  def build_execute_fun(self):
-    input_indices = []
-    for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
-      assert isinstance(spec, sharding_impls.PmapSharding), spec
-      assert isinstance(aval, core.ShapedArray), aval
-      input_indices.append(
-          sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
-          if spec.sharding_spec is not None else None)
-    handle_outs = local_avals_to_results_handler(self.local_output_avals,
-                                                 self.output_shardings)
-    handle_args = InputsHandler(self.compiled.local_devices(),
-                                self.input_shardings, input_indices)
-    execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
-                                    self.backend, handle_args, handle_outs,
-                                    self.unordered_effects,
-                                    self.ordered_effects, self.keepalive,
-                                    bool(self.host_callbacks),
-                                    set(range(len(input_indices))))
-    return execute_fun
-
-  def load(self) -> PmapExecutable:
-    fingerprint = getattr(self.compiled, "fingerprint", None)
-
-    return PmapExecutable(self.compiled, self.build_execute_fun, fingerprint,
-                          self.local_input_avals, self)
+        jaxpr_debug_info=jaxpr_debug_info).load()
 
 
 def _compile_replicated_pmap_executable_from_hlo(
     xla_computation, pci, input_indices, in_shardings, handle_outs,
-    compile_options, host_callbacks, has_unordered_effects, ordered_effects):
+    compile_options, host_callbacks, has_unordered_effects, ordered_effects,
+    jaxpr_debug_info):
   # Use the standard out_handler.
   execute_fun = pci.backend.compile_replicated(
       is_trivial=False, name=pci.name, computation=xla_computation,
@@ -1116,20 +1121,23 @@ def _compile_replicated_pmap_executable_from_hlo(
       in_indices=input_indices, in_shardings=in_shardings,
       kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
   # TODO(frostig): need `compile_replicated` to give us the XLA executable
-  return PmapExecutable(None, lambda: execute_fun, None, pci.avals, None)
+  return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
+                        jaxpr_debug_info, None)
 
 
 class PmapExecutable(stages.XlaExecutable):
   __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
-               "fingerprint", "in_avals", "_unloaded_executable"]
+               "fingerprint", "in_avals", "_jaxpr_debug_info",
+               "_unloaded_executable"]
 
   def __init__(self, xla_executable, build_unsafe_call, fingerprint,
-               in_avals, unloaded_executable):
+               in_avals, jaxpr_debug_info, unloaded_executable):
     self.xla_executable = xla_executable
     self._unsafe_call = None
     self.build_unsafe_call = build_unsafe_call
     self.fingerprint = fingerprint
     self.in_avals = in_avals
+    self._jaxpr_debug_info = jaxpr_debug_info
     self._unloaded_executable = unloaded_executable
 
   @property
@@ -1147,7 +1155,7 @@ class PmapExecutable(stages.XlaExecutable):
 
   def call(self, *args):
     # TODO(frostig): do we need to check sharding and sharded avals?
     arg_avals = map(xla.abstractify, args)
-    check_arg_avals_for_call(self.in_avals, arg_avals)
+    check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable
 
@@ -2655,7 +2663,7 @@ class UnloadedMeshExecutable:
           semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
           compile_options, tuple(host_callbacks), bool(unordered_effects),
           tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
-          pmap_nreps)
+          pmap_nreps, jaxpr_debug_info)
 
     if auto_spmd_lowering:
       assert mesh is not None
@@ -2794,7 +2802,7 @@ class MeshExecutable(stages.XlaExecutable):
     kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
     arg_avals = map(xla.abstractify, kept_args)
     ref_avals = self.in_avals
-    check_arg_avals_for_call(ref_avals, arg_avals)
+    check_arg_avals_for_call(ref_avals, arg_avals, self._jaxpr_debug_info)
     # Check the GDA sharding and the input sharding.
     check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings,
                                           self._jaxpr_debug_info)
@@ -2832,18 +2840,28 @@ class MeshExecutable(stages.XlaExecutable):
     return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [])  # type: ignore
 
 
-def check_arg_avals_for_call(ref_avals, arg_avals):
+def check_arg_avals_for_call(ref_avals, arg_avals,
+                             jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None):
   if len(ref_avals) != len(arg_avals):
     raise TypeError(
         f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")
-  for ref_aval, arg_aval in zip(ref_avals, arg_avals):
+  arg_names = ([''] * len(ref_avals) if jaxpr_debug_info is None else
+               jaxpr_debug_info.arg_names)
+  errors = []
+  num_errors = 5
+  for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names):
     if not core.typematch(ref_aval, arg_aval):
-      raise TypeError(
-          "Computation was compiled for different input types and called with "
-          "different types. One of the mismatches is:\n"
-          f"Compiled with:\n  {ref_aval}\n"
-          f"called with:\n  {arg_aval}")
+      errors.append(f"Compiled with {ref_aval} and called with {arg_aval} for "
+                    f"arg {name}")
+  if errors:
+    str_errors = '\n'.join(errors[:num_errors])
+    num_mismatch_str = (
+        f'the {len(errors)} mismatches' if len(errors) < num_errors else
+        f"{num_errors} mismatches out of {len(errors)}")
+    raise TypeError(
+        "Computation was compiled for different input types and called with "
+        f"different types. Here are {num_mismatch_str}:\n{str_errors}")
 
 
 def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
@@ -2905,7 +2923,7 @@ def _compile_replicated_mesh_executable_from_hlo(
     computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
     semantics_out_shardings, auto_spmd_lowering, compile_options,
     host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
-    backend, da, committed, pmap_nreps):
+    backend, da, committed, pmap_nreps, jaxpr_debug_info):
   assert not auto_spmd_lowering
   assert isinstance(da, _DeviceAssignment)
   in_shardings = semantics_in_shardings.shardings
@@ -2930,7 +2948,7 @@ def _compile_replicated_mesh_executable_from_hlo(
     xla_executable = None
   return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, auto_spmd_lowering,
-                        kept_var_idx, None)
+                        kept_var_idx, jaxpr_debug_info, None)
 
 
 def _compile_replicated_mesh_executable_from_trivial_jaxpr(
@@ -2998,12 +3016,13 @@ def check_gda_or_array_xla_sharding_match(
           f"arg {name} with shape: {arg.aval.str_short()}")
 
   if errors:
-    str_errors = '\n'.join(errors)
-    num_mismatch_str = (f'{len(errors)} mismatches' if len(errors) < num_errors else
-                        f"{num_errors} mismatches out of {len(errors)}")
+    str_errors = '\n'.join(errors[:num_errors])
+    num_mismatch_str = (
+        f'the {len(errors)} mismatches' if len(errors) < num_errors else
+        f"{num_errors} mismatches out of {len(errors)}")
     raise ValueError(
         "Array(s) sharding does not match the input(s) sharding. "
-        f"Here are the {num_mismatch_str}:\n{str_errors}")
+        f"Here are {num_mismatch_str}:\n{str_errors}")
 
 
 def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
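
Illustration (not part of the diff): with arg_names threaded through, an
AOT-compiled executable called with mismatched avals now reports each
offending argument by name. A minimal sketch, assuming a JAX build that
contains this change; the function and variable names are hypothetical, and
the expected message mirrors the updated tests below:

    import jax
    import jax.numpy as jnp

    def f(x, y):  # arg_names is recorded as ('x', 'y') in the jaxpr debug info
      return x + y

    x_f32 = jnp.arange(4, dtype=jnp.float32)
    x_i32 = jnp.arange(4, dtype=jnp.int32)

    # Compile ahead of time for float32 inputs, then call with int32 inputs.
    exe = jax.jit(f).lower(x_f32, x_f32).compile()
    exe(x_i32, x_i32)
    # TypeError: Computation was compiled for different input types and called
    # with different types. Here are the 2 mismatches:
    # Compiled with float32[4] and called with int32[4] for arg x
    # Compiled with float32[4] and called with int32[4] for arg y
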
diff --git a/tests/api_test.py b/tests/api_test.py
index 878e7d3ca..d550c07c8 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -992,10 +992,9 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     f_exe = self.jit(f).lower(x_f32).compile()
     self.assertRaisesRegex(
         TypeError,
-        "Computation was compiled for different input types and called with "
-        "different types. One of the mismatches is:\n"
-        "Compiled with:\n.*float32.*\n"
-        "called with:\n.*int32.*",
+        r"Computation was compiled for different input types and called with "
+        r"different types. Here are the 1 mismatches:\n"
+        r"Compiled with.*float32.*and called with.*int32.*for arg x",
         lambda: f_exe(x_i32))
 
   def test_jit_lower_compile_multi_arg(self):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 4e7a84eae..50eb2b3c8 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -930,13 +930,13 @@ class PJitTest(jtu.BufferDonationTestCase):
     x_f32 = x.astype(jnp.float32)
     x_i32 = x.astype(jnp.int32)
     exe = f.lower(x_f32, x_f32).compile()
-    self.assertRaisesRegex(
+    with self.assertRaisesRegex(
         TypeError,
-        "Computation was compiled for different input types and called with "
-        "different types. One of the mismatches is:\n"
-        "Compiled with:\n.*float32.*\n"
-        "called with:\n.*int32.*",
-        lambda: exe(x_i32, x_i32))
+        r"Computation was compiled for different input types and called with "
+        r"different types. Here are the 2 mismatches:\n"
+        r"Compiled with.*float32.*and called with.*int32.*for arg x\n"
+        r"Compiled with.*float32.*and called with.*int32.*for arg y"):
+      exe(x_i32, x_i32)
 
   @jtu.with_mesh([('x', 2), ('y', 2)])
   def testLowerAsText(self):
@@ -1541,7 +1541,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         r"Array\(s\) sharding does not match the input\(s\) sharding. "
-        "Here are the 5 mismatches out of 6"):
+        "Here are 5 mismatches out of 6"):
       compiled(a2, a2, a2, a2, a2, a2)
 
     with global_mesh:
" - "Here are the 5 mismatches out of 6"): + "Here are 5 mismatches out of 6"): compiled(a2, a2, a2, a2, a2, a2) with global_mesh: diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 8bd77713c..ad314dc3d 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -240,10 +240,9 @@ class PythonPmapTest(jtu.JaxTestCase): f_exe = f.lower(x_f32).compile() self.assertRaisesRegex( TypeError, - "Computation was compiled for different input types and called with " - "different types. One of the mismatches is:\n" - "Compiled with:\n.*float32.*\n" - "called with:\n.*int32.*", + r"Computation was compiled for different input types and called with " + r"different types. Here are the 1 mismatches:\n" + r"Compiled with.*float32.*and called with.*int32.*for arg x", lambda: f_exe(x_i32)) def testLowerCompileMultiArg(self): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 9b6a62812..4c2fc79e8 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -707,11 +707,10 @@ class XMapTest(XMapTestCase): f_exe = f.lower(x_f32).compile() self.assertRaisesRegex( TypeError, - "Computation was compiled for different input types and called with " - "different types. One of the mismatches is:\n" - "Compiled with:\n.*float32.*\n" - "called with:\n.*int32.*", - lambda: f_exe(x_i32)) + r"Computation was compiled for different input types and called with " + r"different types. Here are the 1 mismatches:\n" + r"Compiled with.*float32.*and called with.*int32.*", + lambda: f_exe(x_i32)) def testLowerAsText(self): f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])