diff --git a/jax/_src/api.py b/jax/_src/api.py
index 3625f13ac..8f9f29262 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -689,13 +689,19 @@ def _cpp_jit(
             inline=inline, keep_unused=keep_unused))
     execute = None
-    if isinstance(top_trace, core.EvalTrace) and not (
-        jax.config.jax_debug_nans or jax.config.jax_debug_infs):
-      execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
-      out_flat = call_bind_continuation(execute(*args_flat))
-    else:
-      out_flat = call_bind_continuation(
-          top_trace.process_call(primitive, fun_, tracers, params))
+    try:
+      if isinstance(top_trace, core.EvalTrace) and not (
+          jax.config.jax_debug_nans or jax.config.jax_debug_infs):
+        execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
+        out_flat = call_bind_continuation(execute(*args_flat))
+      else:
+        out_flat = call_bind_continuation(
+            top_trace.process_call(primitive, fun_, tracers, params))
+    except pxla.DeviceAssignmentMismatchError as e:
+      fails, = e.args
+      msg = pjit._device_assignment_mismatch_error(
+          fun, fails, in_tree, args_flat, 'jit')
+      raise ValueError(msg) from None
     out_pytree_def = out_tree()
     out = tree_unflatten(out_pytree_def, out_flat)
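
Note: for context, a minimal sketch of the failure mode the new try/except converts into a readable ValueError: two arrays committed to different devices reaching a jitted computation. The repro is illustrative only (not part of the patch) and assumes a runtime with at least two devices, jax.Array enabled, and the jit-pjit merge (jax.config.jax_jit_pjit_api_merge) turned on.

# Illustrative repro; device indices and the lambda are assumptions.
import jax
import numpy as np

a = jax.device_put(np.arange(3), jax.devices()[0])  # committed to device 0
b = jax.device_put(np.arange(3), jax.devices()[1])  # committed to device 1

# Binding raises pxla.DeviceAssignmentMismatchError internally; _cpp_jit now
# rewraps it as: ValueError("Received incompatible devices for jitted
# computation. ...") instead of the old generic device-mismatch message.
jax.jit(lambda x, y: x + y)(a, b)
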
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 4692cb54b..c487fe443 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -45,6 +45,7 @@ from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import path
 from jax._src import profiler
+from jax._src import source_info_util
 from jax._src import stages
 from jax._src import traceback_util
 from jax._src import util
@@ -553,20 +554,22 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
   return False
 
 
-def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
+def jaxpr_shardings(jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, str]]:
   from jax.experimental import pjit, shard_map
 
   for eqn in jaxpr.eqns:
     if eqn.primitive is pjit.sharding_constraint_p:
-      yield eqn.params['sharding']
+      yield (eqn.params['sharding'], source_info_util.summarize(eqn.source_info))
     elif eqn.primitive is pjit.pjit_p:
-      yield from eqn.params['in_shardings']
-      yield from eqn.params['out_shardings']
+      source_info = source_info_util.summarize(eqn.source_info)
+      yield from ((i, source_info) for i in eqn.params['in_shardings'])
+      yield from ((o, source_info) for o in eqn.params['out_shardings'])
    elif eqn.primitive is shard_map.shard_map_p:
+      source_info = source_info_util.summarize(eqn.source_info)
       def _names_to_pspec(names):
         ndmin = max(names) + 1 if names else 0
         return PartitionSpec(*(names.get(i) for i in range(ndmin)))
-      yield from (NamedSharding(eqn.params['mesh'], _names_to_pspec(names))
+      yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
                   for names in eqn.params['in_names'])
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_shardings(subjaxpr)
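
Note: jaxpr_shardings now yields (sharding, source_info) pairs so that error paths can report where inside the computation a sharding was introduced. A rough consumption sketch follows; dispatch is a private module, the mesh and function names are assumptions, and it assumes jax.config.jax_array is enabled.

# Sketch: consuming the new (sharding, source_info) pairs.
import jax
import numpy as np
from jax._src import dispatch
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

def f(x):
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x')))

closed_jaxpr = jax.make_jaxpr(f)(np.zeros((8,)))
for sharding, where in dispatch.jaxpr_shardings(closed_jaxpr.jaxpr):
  # `where` is a summarized source location, e.g. "f at example.py:9 (f)"
  print(sharding, 'introduced at', where)
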
" - f"Got Pjit's device ids {dev_ids} on platform {p1} and " - f"Array's device ids {a_ids} on platform {p2}") - if first_device_assignment is None and devices: + raise DeviceAssignmentMismatchError([ + DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None), + DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) + if first_sharding_info is None and devices: final_device_assignment = devices - elif first_device_assignment is None: + elif first_sharding_info is None: final_device_assignment = [config.jax_default_device or local_devices()[0]] else: - final_device_assignment = first_device_assignment + final_device_assignment = first_sharding_info[0] return xb.get_device_backend(final_device_assignment[0]), final_device_assignment @@ -2807,8 +2851,12 @@ def lower_sharding_computation( # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr)) - backend, device_assignment = _get_and_check_device_assignment(it.chain( - in_shardings, out_shardings, jaxpr_sharding), devices_from_context) # type: ignore + backend, device_assignment = _get_and_check_device_assignment( + it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings], + [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings], # type: ignore + [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) # type: ignore + for js, source_info in jaxpr_sharding]), + devices_from_context) # TODO(yashkatariya): Make this logic work after DCE because there can be # equations inside the jaxpr that don't affect the output so whether the @@ -2817,7 +2865,7 @@ def lower_sharding_computation( devices_from_context or len(device_assignment) > 1 or any(not _is_unspecified(i) for i in in_shardings) or - any(not _is_unspecified(js) for js in jaxpr_sharding) or + any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or # type: ignore any(not _is_unspecified(o) for o in out_shardings)) # type: ignore in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d1300452d..7cdfc99e5 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -14,6 +14,7 @@ import dataclasses from enum import IntEnum +import inspect import numpy as np from collections import OrderedDict, Counter from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional, @@ -59,7 +60,7 @@ from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src.traceback_util import api_boundary -from jax._src.tree_util import prefix_errors +from jax._src.tree_util import (prefix_errors, _generate_key_paths) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, tuple_insert, weakref_lru_cache, @@ -111,14 +112,87 @@ def _check_all_or_none_unspecified(axis_resources, name): '`pjit._UNSPECIFIED`.') return unspecified -def _python_pjit_helper(infer_params_fn, *args, **kwargs): - args_flat, _, params, _, out_tree, _ = infer_params_fn(*args, **kwargs) + +def _try_infer_args(f, tree): + dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) + try: + return inspect.signature(f).bind(*dummy_args) + except (TypeError, ValueError): + return None + + +def _find_arg_mismatch(arg_list, fails, fun_name): + first_err, second_err = fails + mismatched_args_msg = [] + for name, inp_da, aval in arg_list: + if first_err.m_type == 
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index d1300452d..7cdfc99e5 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -14,6 +14,7 @@
 
 import dataclasses
 from enum import IntEnum
+import inspect
 import numpy as np
 from collections import OrderedDict, Counter
 from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
@@ -59,7 +60,7 @@ from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.traceback_util import api_boundary
-from jax._src.tree_util import prefix_errors
+from jax._src.tree_util import (prefix_errors, _generate_key_paths)
 from jax._src.util import (
     HashableFunction, safe_map, safe_zip, wraps,
     distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
@@ -111,14 +112,87 @@ def _check_all_or_none_unspecified(axis_resources, name):
         '`pjit._UNSPECIFIED`.')
   return unspecified
 
-def _python_pjit_helper(infer_params_fn, *args, **kwargs):
-  args_flat, _, params, _, out_tree, _ = infer_params_fn(*args, **kwargs)
+
+def _try_infer_args(f, tree):
+  dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
+  try:
+    return inspect.signature(f).bind(*dummy_args)
+  except (TypeError, ValueError):
+    return None
+
+
+def _find_arg_mismatch(arg_list, fails, fun_name):
+  first_err, second_err = fails
+  mismatched_args_msg = []
+  for name, inp_da, aval in arg_list:
+    if first_err.m_type == pxla.MismatchType.ARG_SHARDING:
+      if first_err.da == inp_da:
+        mismatched_args_msg.append(
+            (f"argument {name} of {fun_name} with {aval.str_short()} and "
+             f"{first_err._dev_ids_plat_str}"))
+        break
+
+  for name, inp_da, aval in arg_list:
+    if second_err.m_type == pxla.MismatchType.ARG_SHARDING:
+      if second_err.da == inp_da:
+        mismatched_args_msg.append(
+            (f"argument {name} of {fun_name} with {aval.str_short()} and "
+             f"{second_err._dev_ids_plat_str}"))
+        break
+  return mismatched_args_msg
+
+
+def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name):
+  sig = _try_infer_args(fun, in_tree)
+  args = tree_unflatten(in_tree, args_flat)
+  args_aug = _generate_key_paths(args)
+
+  arg_list = []
+  for arg_key, val in args_aug:
+    ak, *rem_keys = arg_key.keys
+    if sig is not None:
+      loc = ''.join(k.pprint() for k in rem_keys)
+      arg_name = f'{list(sig.arguments.keys())[ak.key]}{loc}'
+    else:
+      arg_name = ''
+    da = val.sharding._device_assignment if hasattr(val, 'sharding') else None
+    arg_list.append((arg_name, da, shaped_abstractify(val)))
+
+  fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
+  mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name)
+
+  if len(mismatched_args_msg) == 2:
+    first, second = mismatched_args_msg  # pylint: disable=unbalanced-tuple-unpacking
+    extra_msg = f" Got {first} and {second}"
+  elif len(mismatched_args_msg) == 1:
+    first, second = fails
+    # Choose the remaining failure, i.e. the one not already covered by the
+    # ARG_SHARDING message.
+    left = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first
+    extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}"
+  else:
+    first, second = fails
+    extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}"
+  msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}")
+  return msg
+
+
+def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
+  args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
+      *args, **kwargs)
   for arg in args_flat:
     dispatch.check_arg(arg)
-  out_flat = pjit_p.bind(*args_flat, **params)
+  try:
+    out_flat = pjit_p.bind(*args_flat, **params)
+  except pxla.DeviceAssignmentMismatchError as e:
+    fails, = e.args
+    api_name = 'jit' if params['resource_env'] is None else 'pjit'
+    msg = _device_assignment_mismatch_error(
+        fun, fails, in_tree, args_flat, api_name)
+    raise ValueError(msg) from None
   outs = tree_unflatten(out_tree, out_flat)
   return outs, out_flat, out_tree, args_flat
 
+
 def _python_pjit(fun: Callable, infer_params_fn):
 
   @wraps(fun)
@@ -126,7 +200,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
   def wrapped(*args, **kwargs):
     if config.jax_disable_jit:
       return fun(*args, **kwargs)
-    return _python_pjit_helper(infer_params_fn, *args, **kwargs)[0]
+    return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
 
   return wrapped
 
@@ -153,7 +227,7 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
   @api_boundary
   def cache_miss(*args, **kwargs):
     outs, out_flat, out_tree, args_flat = _python_pjit_helper(
-        infer_params_fn, *args, **kwargs)
+        fun, infer_params_fn, *args, **kwargs)
 
     executable = _read_most_recent_pjit_call_executable()
 
@@ -1081,13 +1155,15 @@ def _resolve_in_shardings(
       if isinstance(arg_s, PmapSharding):
         continue
      if getattr(a, '_committed', True):
-        committed_arg_shardings.append(arg_s)
+        committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
 
   # Check if the device_assignment across inputs, outputs and arguments is the
   # same.
   pxla._get_and_check_device_assignment(
       it.chain(
-          committed_arg_shardings, pjit_in_shardings, out_shardings),
+          committed_arg_shardings,
+          [(i, pxla.MismatchType.IN_SHARDING, None) for i in pjit_in_shardings],
+          [(o, pxla.MismatchType.OUT_SHARDING, None) for o in out_shardings]),
       (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
 
   resolved_in_shardings = []
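
Note: end to end, _device_assignment_mismatch_error yields messages like the one sketched below. The repro is illustrative only, assumes at least two devices and jax.config.jax_array enabled, and the exact platform and dtype in the message will vary by runtime.

# Sketch: the user-facing error after this change.
import jax
import numpy as np
from jax.experimental.pjit import pjit

a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
pjit(lambda x, y: (x, y))(a, b)
# ValueError: Received incompatible devices for pjitted computation. Got
# argument x of <lambda> with int32[3] and device ids [0] on platform ... and
# argument y of <lambda> with int32[3] and device ids [1] on platform ...
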
diff --git a/tests/api_test.py b/tests/api_test.py
index 84b33e096..66ba5f39a 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -766,7 +766,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       return jax.jit(lambda x: x + 1, backend="cpu")(x)
 
     if jax.config.jax_jit_pjit_api_merge:
-      msg = 'Devices of all `Array` inputs and outputs should be the same'
+      msg = 'Received incompatible devices for jitted computation'
     else:
       msg = ("Outer-jit backend specification .* must match explicit inner-jit "
              "backend specification cpu.")
diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py
index 41403df72..0adbf7fb9 100644
--- a/tests/multi_device_test.py
+++ b/tests/multi_device_test.py
@@ -104,7 +104,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
     # in an error
 
     if jax.config.jax_jit_pjit_api_merge:
-      err_msg = "Devices of all `Array` inputs and outputs should be the same"
+      err_msg = "Received incompatible devices for jitted computation"
     else:
       err_msg = "primitive arguments must be colocated on the same device"
Got argument " + r"inp1 of.*my_nested_pjit with bfloat16\[8,2\] and device ids \[0\].*" + r"nested pjit.*with device ids \[0, 1, 2, 3\].*"): + my_nested_pjit(committed_inp, committed_inp, committed_inp) + @jax_array(True) def test_jit_device_with_sharding_constraint_error(self): + if not jax.config.jax_jit_pjit_api_merge or not jax.config.jax_array: + self.skipTest('Requires jax.Array and jit-pjit merge.') + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0]) @@ -2497,18 +2543,11 @@ class ArrayPjitTest(jtu.JaxTestCase): out = jnp.zeros(shape, jnp.bfloat16) return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec)) - # This split is needed because original `jit` adds `device` as a - # `devices_from_context` whereas `pjit` passes it as an in_sharding. - if jax.config.jax_jit_pjit_api_merge: - error_msg = ("Devices of all `Array` inputs and outputs should be the same. " - r"Got array device ids \[0\] on platform.*and " - r"another array's device ids \[0, 1, 2, 3\] on platform") - else: - error_msg = ("Pjit's devices and Array's devices should be equal. " - r"Got Pjit's device ids \[0\] on platform.*and " - r"Array's device ids \[0, 1, 2, 3\] on platform") - - with self.assertRaisesRegex(ValueError, error_msg): + with self.assertRaisesRegex( + ValueError, + "Received incompatible devices for jitted computation. Got explicit " + r"output sharding with device ids \[0\].*with_sharding_constraint.*with " + r"device ids \[0, 1, 2, 3\].*"): sharded_zeros((4096, 3072), P('x', 'y')) @jax_array(True)