From 0a19638490b2874f24e54fb898a23397f84f7059 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Apr 2023 12:35:15 -0700 Subject: [PATCH] Plumb debug_info to meshExecutable as an optional arg to raise better error messages. PiperOrigin-RevId: 525521694 --- jax/_src/interpreters/pxla.py | 46 +++++++++++++++++++-------- jax/_src/pjit.py | 3 +- jax/experimental/jax2tf/jax_export.py | 6 ++-- tests/pjit_test.py | 28 ++++++++++++---- 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index a529dda49..28efdd8b6 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2110,7 +2110,8 @@ def lower_sharding_computation( backend=backend, device_assignment=da_object, committed=committed, - pmap_nreps=nreps) + pmap_nreps=nreps, + jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info) def _to_logical_sharding( @@ -2291,7 +2292,8 @@ def lower_mesh_computation( kept_var_idx=set(range(len(global_in_avals))), backend=backend, device_assignment=_create_da_object(tuple(mesh.devices.flat)), - committed=True) + committed=True, + jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info) class MeshComputation(stages.XlaLowering): _hlo: Optional[ir.Module] @@ -2580,6 +2582,7 @@ class UnloadedMeshExecutable: host_callbacks: Sequence[Any] kept_var_idx: Set[int] auto_spmd_lowering: bool + jaxpr_debug_info: Optional[core.JaxprDebugInfo] def build_unsafe_call(self): input_indices = _get_input_indices(self.input_avals, self.input_shardings, @@ -2601,7 +2604,7 @@ class UnloadedMeshExecutable: self.input_avals, self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, - self) + self.jaxpr_debug_info, self) # May return a MeshExecutable in the compile_replicated case. 
@staticmethod @@ -2627,6 +2630,7 @@ class UnloadedMeshExecutable: device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]], committed: bool, pmap_nreps: int = 1, + jaxpr_debug_info: Optional[core.JaxprDebugInfo] = None, compiler_options=None ) -> MeshExecutable: compiler_options_keys = tuple( @@ -2714,7 +2718,8 @@ class UnloadedMeshExecutable: keepalive=keepalive, host_callbacks=host_callbacks, kept_var_idx=kept_var_idx, - auto_spmd_lowering=auto_spmd_lowering).load() + auto_spmd_lowering=auto_spmd_lowering, + jaxpr_debug_info=jaxpr_debug_info).load() class MeshExecutableFastpathData(NamedTuple): @@ -2731,12 +2736,12 @@ class MeshExecutable(stages.XlaExecutable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx", - "_unloaded_executable", + "_jaxpr_debug_info", "_unloaded_executable", ] def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, - unloaded_executable=None): + jaxpr_debug_info=None, unloaded_executable=None): self.xla_executable = xla_executable self.build_unsafe_call = build_unsafe_call # in_avals is a list of global and local avals. Aval is global if input @@ -2747,6 +2752,7 @@ class MeshExecutable(stages.XlaExecutable): self._out_shardings = out_shardings self._auto_spmd_lowering = auto_spmd_lowering self._kept_var_idx = kept_var_idx + self._jaxpr_debug_info = jaxpr_debug_info self._unloaded_executable = unloaded_executable @property @@ -2790,7 +2796,8 @@ class MeshExecutable(stages.XlaExecutable): ref_avals = self.in_avals check_arg_avals_for_call(ref_avals, arg_avals) # Check the GDA sharding and the input sharding. 
- check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings) + check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings, + self._jaxpr_debug_info) return self.unsafe_call(*args) # pylint: disable=not-callable def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: @@ -2968,10 +2975,14 @@ def check_device_backend_on_shardings(shardings) -> bool: def check_gda_or_array_xla_sharding_match( - args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> None: + args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding], + jaxpr_debug_info: Optional[core.JaxprDebugInfo]) -> None: from jax._src.array import ArrayImpl - - for arg, xs in safe_zip(args, in_xla_shardings): + arg_names = ([''] * len(args) if jaxpr_debug_info is None else + jaxpr_debug_info.arg_names) + errors = [] + num_errors = 5 + for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names): if not isinstance(arg, ArrayImpl): continue @@ -2982,10 +2993,17 @@ def check_gda_or_array_xla_sharding_match( not op_shardings.are_op_shardings_equal( arg.sharding._to_xla_op_sharding(arg.ndim), xs._to_xla_op_sharding(arg.ndim))): - raise ValueError( - f"Array sharding does not match the input sharding. " - f"Got Array sharding: {arg.sharding} and xla sharding: {xs} for " - f"arg shape: {arg.shape}, arg value: {arg}") + errors.append( + f"Got Array sharding: {arg.sharding} and input sharding: {xs} for " + f"arg {name} with shape: {arg.aval.str_short()}") + + if errors: + str_errors = '\n'.join(errors) + num_mismatch_str = (f'{len(errors)} mismatches' if len(errors) < num_errors else + f"{num_errors} mismatches out of {len(errors)}") + raise ValueError( + "Array(s) sharding does not match the input(s) sharding. 
" + f"Here are the {num_mismatch_str}:\n{str_errors}") def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ca241646d..927487cf5 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1076,7 +1076,8 @@ def _pjit_call_impl(*args, jaxpr, _most_recent_pjit_call_executable.value = compiled # This check is expensive so only do it if enable_checks is on. if compiled._auto_spmd_lowering and config.jax_enable_checks: - pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings) + pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings, + jaxpr.jaxpr.debug_info) if config.jax_distributed_debug: # Defensively only perform fingerprint logic if debug logging is enabled # NOTE(skyewm): I didn't benchmark this diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 42e2d1eda..4ef81d0e7 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -329,11 +329,13 @@ def _check_lowered(lowered) -> None: # Check that we do not see new compile_args. When we add a compile_args it is # safe to add it to the allowed_compile_args if it does not change the semantics # or the calling convention of the lowered module. 
- allowed_compile_args = ["backend", "mesh", "global_in_avals", + allowed_compile_args = [ + "backend", "mesh", "global_in_avals", "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", "spmd_lowering", "auto_spmd_lowering", "tuple_args", "ordered_effects", "unordered_effects", - "keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment"] + "keepalive", "host_callbacks", "pmap_nreps", "committed", + "device_assignment", "jaxpr_debug_info"] for compile_arg in lowered.compile_args.keys(): if compile_arg not in allowed_compile_args: raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e9d6bcb3f..4e7a84eae 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1229,7 +1229,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase): input_data) with self.assertRaisesRegex( ValueError, - "Array sharding does not match the input sharding."): + r"Array\(s\) sharding does not match the input\(s\) " + r"sharding.*\n.*for arg x"): compiled(arr) def test_gda_auto_shardings_len(self): @@ -1530,16 +1531,31 @@ class ArrayPjitTest(jtu.JaxTestCase): with global_mesh: f = pjit( - lambda x, y: x @ y.T, + lambda x, y, z, a, b, c: (x @ y.T, y, z, a, b, c), in_shardings=NamedSharding(global_mesh, P('x' ,'y'))) - compiled = f.lower(aval, aval).compile() - out = compiled(a1, a1) + compiled = f.lower(aval, aval, aval, aval, aval, aval).compile() + out, *_ = compiled(a1, a1, a1, a1, a1, a1) self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out._value, input_data @ input_data.T) with self.assertRaisesRegex( - ValueError, 'Array sharding does not match the input sharding'): - compiled(a2, a2) + ValueError, + r"Array\(s\) sharding does not match the input\(s\) sharding. 
" + "Here are the 5 mismatches out of 6"): + compiled(a2, a2, a2, a2, a2, a2) + + with global_mesh: + f = pjit(lambda a: a, in_shardings=NamedSharding(global_mesh, P('x' ,'y'))) + abstract_inp = {'x': aval, 'y': {'y1': aval}} + inp1 = {'x': a1, 'y': {'y1': a1}} + compiled = f.lower(abstract_inp).compile() + compiled(inp1) + inp2 = {'x': a2, 'y': {'y1': a2}} + with self.assertRaisesRegex( + ValueError, + r"Array\(s\) sharding does not match the input\(s\) sharding. " + "Here are the 2 mismatches"): + compiled(inp2) def test_globally_sharded_key_array_result_8x4_single_device(self): input_shape = (8, 4)