Improve the error message raised from _get_and_check_device_assignment.

Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
Author: Yash Katariya (committed by jax authors), 2023-02-10 13:53:43 -08:00
Parent: 57900d7ef2
Commit: 1526c3e20c
7 changed files with 252 additions and 80 deletions
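Before the diff, for context: a minimal sketch of the kind of program that triggers the new message, adapted from the `test_jit_with_sharding_constraint_committed_inp_error` test added below. It assumes a host exposing at least four CPU devices (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=4`) and the jit-pjit merge enabled:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 2x2 mesh over the first four devices.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

@jax.jit
def sharded_inp(inp):
  # The constraint places the value across all four mesh devices...
  return jax.lax.with_sharding_constraint(
      inp, NamedSharding(mesh, P('x', 'y')))

# ...but the committed input lives only on device 0, so the call raises
# the "Received incompatible devices for jitted computation" ValueError.
committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
sharded_inp(committed_inp)
```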

@@ -689,13 +689,19 @@ def _cpp_jit(
         inline=inline,
         keep_unused=keep_unused))
     execute = None
-    if isinstance(top_trace, core.EvalTrace) and not (
-        jax.config.jax_debug_nans or jax.config.jax_debug_infs):
-      execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
-      out_flat = call_bind_continuation(execute(*args_flat))
-    else:
-      out_flat = call_bind_continuation(
-          top_trace.process_call(primitive, fun_, tracers, params))
+    try:
+      if isinstance(top_trace, core.EvalTrace) and not (
+          jax.config.jax_debug_nans or jax.config.jax_debug_infs):
+        execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
+        out_flat = call_bind_continuation(execute(*args_flat))
+      else:
+        out_flat = call_bind_continuation(
+            top_trace.process_call(primitive, fun_, tracers, params))
+    except pxla.DeviceAssignmentMismatchError as e:
+      fails, = e.args
+      msg = pjit._device_assignment_mismatch_error(
+          fun, fails, in_tree, args_flat, 'jit')
+      raise ValueError(msg) from None
     out_pytree_def = out_tree()
     out = tree_unflatten(out_pytree_def, out_flat)
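One detail worth noting in the handler above: `raise ValueError(msg) from None` suppresses exception chaining, so users see only the rewritten message instead of the internal `DeviceAssignmentMismatchError` traceback. A standalone sketch of the pattern (names here are illustrative, not from the diff):

```python
class _InternalMismatchError(Exception):
  """Stand-in for pxla.DeviceAssignmentMismatchError."""

def api_call():
  try:
    raise _InternalMismatchError(('mismatch details',))
  except _InternalMismatchError as e:
    details, = e.args
    # `from None` drops the "During handling of the above exception..."
    # context, so only this ValueError reaches the user.
    raise ValueError(f'Received incompatible devices: {details}') from None

api_call()  # raises a single, clean ValueError
```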

@@ -45,6 +45,7 @@ from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import path
 from jax._src import profiler
+from jax._src import source_info_util
 from jax._src import stages
 from jax._src import traceback_util
 from jax._src import util
@@ -553,20 +554,22 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
   return False

-def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
+def jaxpr_shardings(jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, str]]:
   from jax.experimental import pjit, shard_map
   for eqn in jaxpr.eqns:
     if eqn.primitive is pjit.sharding_constraint_p:
-      yield eqn.params['sharding']
+      yield (eqn.params['sharding'], source_info_util.summarize(eqn.source_info))
     elif eqn.primitive is pjit.pjit_p:
-      yield from eqn.params['in_shardings']
-      yield from eqn.params['out_shardings']
+      source_info = source_info_util.summarize(eqn.source_info)
+      yield from ((i, source_info) for i in eqn.params['in_shardings'])
+      yield from ((o, source_info) for o in eqn.params['out_shardings'])
     elif eqn.primitive is shard_map.shard_map_p:
+      source_info = source_info_util.summarize(eqn.source_info)
       def _names_to_pspec(names):
         ndmin = max(names) + 1 if names else 0
         return PartitionSpec(*(names.get(i) for i in range(ndmin)))
-      yield from (NamedSharding(eqn.params['mesh'], _names_to_pspec(names))
+      yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
                   for names in eqn.params['in_names'])
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_shardings(subjaxpr)
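The inline `_names_to_pspec` helper converts a `shard_map` names mapping (positional axis index to mesh axis name) into a `PartitionSpec`, padding unmentioned axes with `None`. A standalone sketch of that conversion, with the helper copied out for illustration (in the real caller the mappings come from `eqn.params['in_names']`):

```python
from jax.sharding import PartitionSpec

def _names_to_pspec(names):
  # The highest mentioned axis index determines the spec's length.
  ndmin = max(names) + 1 if names else 0
  return PartitionSpec(*(names.get(i) for i in range(ndmin)))

print(_names_to_pspec({0: 'x', 2: 'y'}))  # PartitionSpec('x', None, 'y')
print(_names_to_pspec({}))                # PartitionSpec()
```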

@@ -2715,55 +2715,99 @@ def check_if_any_auto(
       return True
   return False

+class MismatchType(enum.Enum):
+  ARG_SHARDING = 0
+  OUT_SHARDING = 1
+  SHARDING_INSIDE_COMPUTATION = 2
+  CONTEXT_DEVICES = 3
+  IN_SHARDING = 4
+
+  def __str__(self):
+    if self.name == 'IN_SHARDING':
+      return 'explicit input sharding'
+    elif self.name == 'OUT_SHARDING':
+      return 'explicit output sharding'
+    elif self.name == 'SHARDING_INSIDE_COMPUTATION':
+      return 'with_sharding_constraint or nested pjit or shard_map'
+    elif self.name == 'CONTEXT_DEVICES':
+      return 'devices'
+    return f'{self.name}'
+
+@dataclasses.dataclass
+class DeviceAssignmentMismatch:
+  da: Sequence[xc.Device]
+  m_type: MismatchType
+  source_info: Optional[str]
+
+  @property
+  def device_ids(self) -> Sequence[int]:
+    return [d.id for d in self.da]
+
+  @property
+  def platform(self) -> str:
+    return self.da[0].platform.upper()
+
+  def _maybe_api_name(self, api_name) -> str:
+    return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else ""
+
+  @property
+  def source_info_str(self):
+    return "" if self.source_info is None else f" at {self.source_info}"
+
+  @property
+  def _dev_ids_plat_str(self):
+    return f"device ids {self.device_ids} on platform {self.platform}"
+
+  def _str(self, api_name):
+    return (f"{self._maybe_api_name(api_name)} {self.m_type} with "
+            f"{self._dev_ids_plat_str}{self.source_info_str}")
+
+class DeviceAssignmentMismatchError(Exception):
+  pass
+
+ShardingInfo = Tuple[Union[sharding_internal.XLACompatibleSharding,
+                           UnspecifiedValue, AUTOAxisResource],
+                     MismatchType, Optional[str]]
+
 def _get_and_check_device_assignment(
-    shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
-                              UnspecifiedValue, AUTOAxisResource]],
-    devices: Optional[Sequence[xc.Device]]
+    shardings: Iterable[ShardingInfo],
+    devices: Optional[Sequence[xc.Device]],
 ) -> Tuple[xla.Backend, Sequence[xc.Device]]:
   from jax._src.api import local_devices
-  first_device_assignment = None
+  first_sharding_info = None
   if devices is None:
     devices = []
   else:
     devices = list(devices)

-  for i in shardings:
+  for i, s_type, source_info in shardings:
     if is_auto(i) or _is_unspecified(i):
       continue
-    # Assign `first_device_assignment` after `AUTO` and `UNSPECIFIED` have been
+    # Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
     # skipped.
-    if first_device_assignment is None:
-      first_device_assignment = list(i._device_assignment)  # type: ignore
+    if first_sharding_info is None:
+      first_sharding_info = (list(i._device_assignment), s_type, source_info)  # type: ignore
     arr_device_assignment = list(i._device_assignment)  # type: ignore
     if not devices:
-      if first_device_assignment != arr_device_assignment:
-        p1 = first_device_assignment[0].platform.upper()
-        fda_ids = [d.id for d in first_device_assignment]
-        a_ids = [d.id for d in arr_device_assignment]
-        p2 = arr_device_assignment[0].platform.upper()
-        raise ValueError(
-            "Devices of all `Array` inputs and outputs should be "
-            "the same. "
-            f"Got array device ids {fda_ids} on platform {p1} and "
-            f"another array's device ids {a_ids} on platform {p2}")
+      if first_sharding_info[0] != arr_device_assignment:
+        raise DeviceAssignmentMismatchError([
+            DeviceAssignmentMismatch(*first_sharding_info),
+            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
     else:
       if devices != arr_device_assignment:
-        p1 = devices[0].platform.upper()
-        dev_ids = [d.id for d in devices]
-        a_ids = [d.id for d in arr_device_assignment]
-        p2 = arr_device_assignment[0].platform.upper()
-        raise ValueError(
-            "Pjit's devices and Array's devices should be equal. "
-            f"Got Pjit's device ids {dev_ids} on platform {p1} and "
-            f"Array's device ids {a_ids} on platform {p2}")
-  if first_device_assignment is None and devices:
+        raise DeviceAssignmentMismatchError([
+            DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
+            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
+  if first_sharding_info is None and devices:
     final_device_assignment = devices
-  elif first_device_assignment is None:
+  elif first_sharding_info is None:
     final_device_assignment = [config.jax_default_device or local_devices()[0]]
   else:
-    final_device_assignment = first_device_assignment
+    final_device_assignment = first_sharding_info[0]
   return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
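To see how these pieces compose into the final text, here is a hedged sketch of `DeviceAssignmentMismatch._str`, with `xc.Device` replaced by a minimal stand-in (real device objects come from the runtime):

```python
import dataclasses

@dataclasses.dataclass
class FakeDevice:  # hypothetical stand-in for xc.Device
  id: int
  platform: str = 'cpu'

m = DeviceAssignmentMismatch(
    da=[FakeDevice(i) for i in range(4)],
    m_type=MismatchType.SHARDING_INSIDE_COMPUTATION,
    source_info='pjit_test.py:2509 (sharded_inp)')
print(m._str('jit'))
# " with_sharding_constraint or nested pjit or shard_map with device ids
#  [0, 1, 2, 3] on platform CPU at pjit_test.py:2509 (sharded_inp)"
```

Note the leading space: `_maybe_api_name` contributes text only for `CONTEXT_DEVICES`, which is why the callers in pjit.py below concatenate with `f"and{left._str(api_name)}"` rather than `and {...}`.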
@@ -2807,8 +2851,12 @@ def lower_sharding_computation(
   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
   jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
-  backend, device_assignment = _get_and_check_device_assignment(it.chain(
-      in_shardings, out_shardings, jaxpr_sharding), devices_from_context)  # type: ignore
+  backend, device_assignment = _get_and_check_device_assignment(
+      it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
+               [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],  # type: ignore
+               [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)  # type: ignore
+                for js, source_info in jaxpr_sharding]),
+      devices_from_context)

   # TODO(yashkatariya): Make this logic work after DCE because there can be
   # equations inside the jaxpr that don't affect the output so whether the
@@ -2817,7 +2865,7 @@ def lower_sharding_computation(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not _is_unspecified(i) for i in in_shardings) or
-      any(not _is_unspecified(js) for js in jaxpr_sharding) or
+      any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or  # type: ignore
       any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
   in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)

@@ -14,6 +14,7 @@
 import dataclasses
 from enum import IntEnum
+import inspect
 import numpy as np
 from collections import OrderedDict, Counter
 from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
@@ -59,7 +60,7 @@ from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.traceback_util import api_boundary
-from jax._src.tree_util import prefix_errors
+from jax._src.tree_util import (prefix_errors, _generate_key_paths)
 from jax._src.util import (
     HashableFunction, safe_map, safe_zip, wraps,
     distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
@@ -111,14 +112,87 @@ def _check_all_or_none_unspecified(axis_resources, name):
         '`pjit._UNSPECIFIED`.')
   return unspecified

-def _python_pjit_helper(infer_params_fn, *args, **kwargs):
-  args_flat, _, params, _, out_tree, _ = infer_params_fn(*args, **kwargs)
+def _try_infer_args(f, tree):
+  dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
+  try:
+    return inspect.signature(f).bind(*dummy_args)
+  except (TypeError, ValueError):
+    return None
+
+def _find_arg_mismatch(arg_list, fails, fun_name):
+  first_err, second_err = fails
+  mismatched_args_msg = []
+  for name, inp_da, aval in arg_list:
+    if first_err.m_type == pxla.MismatchType.ARG_SHARDING:
+      if first_err.da == inp_da:
+        mismatched_args_msg.append(
+            (f"argument {name} of {fun_name} with {aval.str_short()} and "
+             f"{first_err._dev_ids_plat_str}"))
+        break
+
+  for name, inp_da, aval in arg_list:
+    if second_err.m_type == pxla.MismatchType.ARG_SHARDING:
+      if second_err.da == inp_da:
+        mismatched_args_msg.append(
+            (f"argument {name} of {fun_name} with {aval.str_short()} and "
+             f"{second_err._dev_ids_plat_str}"))
+        break
+  return mismatched_args_msg
+
+def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name):
+  sig = _try_infer_args(fun, in_tree)
+  args = tree_unflatten(in_tree, args_flat)
+  args_aug = _generate_key_paths(args)
+
+  arg_list = []
+  for arg_key, val in args_aug:
+    ak, *rem_keys = arg_key.keys
+    if sig is not None:
+      loc = ''.join(k.pprint() for k in rem_keys)
+      arg_name = f'{list(sig.arguments.keys())[ak.key]}{loc}'
+    else:
+      arg_name = ''
+    da = val.sharding._device_assignment if hasattr(val, 'sharding') else None
+    arg_list.append((arg_name, da, shaped_abstractify(val)))
+
+  fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
+  mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name)
+
+  if len(mismatched_args_msg) == 2:
+    first, second = mismatched_args_msg  # pylint: disable=unbalanced-tuple-unpacking
+    extra_msg = f" Got {first} and {second}"
+  elif len(mismatched_args_msg) == 1:
+    first, second = fails
+    # Choose the failure left which is not already covered by ARG_SHARDING.
+    left = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first
+    extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}"
+  else:
+    first, second = fails
+    extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}"
+  msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}")
+  return msg
+
+def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
+  args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
+      *args, **kwargs)
   for arg in args_flat:
     dispatch.check_arg(arg)
-  out_flat = pjit_p.bind(*args_flat, **params)
+  try:
+    out_flat = pjit_p.bind(*args_flat, **params)
+  except pxla.DeviceAssignmentMismatchError as e:
+    fails, = e.args
+    api_name = 'jit' if params['resource_env'] is None else 'pjit'
+    msg = _device_assignment_mismatch_error(
+        fun, fails, in_tree, args_flat, api_name)
+    raise ValueError(msg) from None
   outs = tree_unflatten(out_tree, out_flat)
   return outs, out_flat, out_tree, args_flat

 def _python_pjit(fun: Callable, infer_params_fn):

   @wraps(fun)
@@ -126,7 +200,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
   def wrapped(*args, **kwargs):
     if config.jax_disable_jit:
       return fun(*args, **kwargs)
-    return _python_pjit_helper(infer_params_fn, *args, **kwargs)[0]
+    return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]

   return wrapped
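The argument names in the new message (`inp`, `tuple_inp[0]`, `dict_inp['a']['b']['c']`) come from pairing `inspect.signature` with per-leaf key paths, as `_device_assignment_mismatch_error` does above. A hedged sketch of that naming step, using the same internal `_generate_key_paths` helper the commit imports (internal API; the exact `pprint` output is approximate):

```python
import inspect
from jax._src.tree_util import _generate_key_paths  # internal helper

def name_leaves(fun, args):
  # Bind dummy values to recover parameter names, as _try_infer_args does.
  names = list(inspect.signature(fun).bind(*args).arguments.keys())
  for key_path, leaf in _generate_key_paths(args):
    first, *rest = key_path.keys
    loc = ''.join(k.pprint() for k in rest)
    print(f'{names[first.key]}{loc} -> {leaf}')

name_leaves(lambda dict_inp: dict_inp, ({'a': {'b': 1}, 'd': 2},))
# dict_inp['a']['b'] -> 1
# dict_inp['d'] -> 2
```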
@@ -153,7 +227,7 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
   @api_boundary
   def cache_miss(*args, **kwargs):
     outs, out_flat, out_tree, args_flat = _python_pjit_helper(
-        infer_params_fn, *args, **kwargs)
+        fun, infer_params_fn, *args, **kwargs)

     executable = _read_most_recent_pjit_call_executable()
@@ -1081,13 +1155,15 @@ def _resolve_in_shardings(
       if isinstance(arg_s, PmapSharding):
         continue
       if getattr(a, '_committed', True):
-        committed_arg_shardings.append(arg_s)
+        committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))

   # Check if the device_assignment across inputs, outputs and arguments is the
   # same.
   pxla._get_and_check_device_assignment(
       it.chain(
-          committed_arg_shardings, pjit_in_shardings, out_shardings),
+          committed_arg_shardings,
+          [(i, pxla.MismatchType.IN_SHARDING, None) for i in pjit_in_shardings],
+          [(o, pxla.MismatchType.OUT_SHARDING, None) for o in out_shardings]),
       (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))

   resolved_in_shardings = []

@@ -766,7 +766,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       return jax.jit(lambda x: x + 1, backend="cpu")(x)

     if jax.config.jax_jit_pjit_api_merge:
-      msg = 'Devices of all `Array` inputs and outputs should be the same'
+      msg = 'Received incompatible devices for jitted computation'
     else:
       msg = ("Outer-jit backend specification .* must match explicit inner-jit "
              "backend specification cpu.")

@@ -104,7 +104,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
     # in an error
     if jax.config.jax_jit_pjit_api_merge:
-      err_msg = "Devices of all `Array` inputs and outputs should be the same"
+      err_msg = "Received incompatible devices for jitted computation"
     else:
       err_msg = "primitive arguments must be colocated on the same device"

@@ -1977,7 +1977,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with jax_array(True):
       with global_mesh:
         with self.assertRaisesRegex(
-            ValueError, "Pjit's devices and Array's devices should be equal"):
+            ValueError, "Received incompatible devices for pjitted computation"):
           pjit(lambda x: x)(input_array)

   def test_array_lower_compile(self):
@@ -2061,15 +2061,15 @@
     m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
     spec = P('x', 'y')

-    a1, _ = create_array(input_shape, m1, spec)
+    a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
     with jax_array(True):
       with m1:
         with self.assertRaisesRegex(
-            ValueError, "Pjit's devices and Array's devices should be equal"):
+            ValueError, "Received incompatible devices for pjitted computation"):
           pjit(lambda x, y: (x, y),
-               out_axis_resources=(NamedSharding(m1, spec),
-                                   NamedSharding(m2, spec)))(a1, a1)
+               out_axis_resources=(NamedSharding(m1, spec),
+                                   NamedSharding(m2, spec)))(a1, a1)

   def test_array_device_assignment_mismatch_in_and_out_shardings(self):
     input_shape = (8, 2)
@@ -2077,15 +2077,15 @@
     m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
     spec = P('x', 'y')

-    a1, _ = create_array(input_shape, m2, spec)
+    a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
     with jax_array(True):
       with m1:
         with self.assertRaisesRegex(
-            ValueError, "Pjit's devices and Array's devices should be equal"):
+            ValueError, "Received incompatible devices for pjitted computation"):
           pjit(lambda x, y: (x, y),
-               in_axis_resources=NamedSharding(m2, spec),
-               out_axis_resources=NamedSharding(m1, spec))(a1, a1)
+               in_axis_resources=NamedSharding(m2, spec),
+               out_axis_resources=NamedSharding(m1, spec))(a1, a1)

   def test_mixed_inputs(self):
     input_shape = (8, 2)
@@ -2329,7 +2329,7 @@
     with jtu.create_global_mesh((2, 2), ('x', 'y')):
       with self.assertRaisesRegex(
           ValueError,
-          "Pjit's devices and Array's devices should be equal."):
+          "Received incompatible devices for pjitted computation"):
         pjit(lambda x, y: (x, y))(uarr, carr)

   @jax_array(True)
@@ -2353,11 +2353,33 @@
     b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
     with self.assertRaisesRegex(
         ValueError,
-        "Devices of all `Array` inputs and outputs should be the same. "
-        r"Got array device ids \[0\] on platform.*and "
-        r"another array's device ids \[1\] on platform"):
+        "Received incompatible devices for pjitted computation. Got argument "
+        r"x of.*\<lambda\> with int.*\[3\] and device ids \[0\].*and argument "
+        r"y of.*\<lambda\> with int.*\[3\] and device ids \[1\].*"):
       pjit(lambda x, y: (x, y))(a, b)

+  def test_pjit_pytree_inp_device_assignment_mismatch(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
+    b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
+    c = jax.device_put(np.arange(16).reshape(8, 2),
+                       NamedSharding(mesh, P('x', 'y')))
+
+    msg = ("Received incompatible devices for pjitted computation. Got "
+           r"argument {} of.*<lambda> with int.*\[3\] and device ids \[0\].*and "
+           r"argument {} of.*<lambda> with int.*\[8,2\] and device ids "
+           r"\[0, 1, 2, 3\].*")
+
+    with self.assertRaisesRegex(
+        ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')):
+      pjit(lambda tuple_inp: tuple_inp)((a, (c, (b))))
+
+    with self.assertRaisesRegex(
+        ValueError, msg.format(r"dict_inp\['a'\]\['b'\]\['c'\]",
+                               r"dict_inp\['a'\]\['b'\]\['g'\]")):
+      inp = {'d': a, 'z': a, 'a': {'f': a, 'y': b, 'b': {'g': c, 'c': a}}}
+      pjit(lambda dict_inp: dict_inp)(inp)
+
   @jax_array(True)
   def test_same_out_sharding_id(self):
     shape = (8, 2)
@@ -2475,8 +2497,13 @@
   @jax_array(True)
   def test_jit_with_sharding_constraint_committed_inp_error(self):
+    if not jax.config.jax_jit_pjit_api_merge or not jax.config.jax_array:
+      self.skipTest('Requires jit-pjit merge and jax.Array')
+
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))

     @jax.jit
     def sharded_inp(inp):
       return jax.lax.with_sharding_constraint(
@@ -2485,11 +2512,30 @@
     committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
     with self.assertRaisesRegex(
         ValueError,
-        "Devices of all `Array` inputs and outputs should be the same"):
+        "Received incompatible devices for jitted computation. Got argument "
+        r"inp of.*sharded_inp with bfloat16\[8,2\] and device ids \[0\].*"
+        r"with_sharding_constraint.*with device ids \[0, 1, 2, 3\].*"):
       sharded_inp(committed_inp)

+    @pjit
+    def my_nested_pjit(inp1, inp2, inp3):
+      @partial(pjit, in_axis_resources=(s, s, s),
+               out_axis_resources=(s, s, s))
+      def f(x, y, z):
+        return x * 2, y * 2, z * 2
+      return f(inp1, inp2, inp3)
+    with self.assertRaisesRegex(
+        ValueError,
+        "Received incompatible devices for pjitted computation. Got argument "
+        r"inp1 of.*my_nested_pjit with bfloat16\[8,2\] and device ids \[0\].*"
+        r"nested pjit.*with device ids \[0, 1, 2, 3\].*"):
+      my_nested_pjit(committed_inp, committed_inp, committed_inp)

   @jax_array(True)
   def test_jit_device_with_sharding_constraint_error(self):
+    if not jax.config.jax_jit_pjit_api_merge or not jax.config.jax_array:
+      self.skipTest('Requires jax.Array and jit-pjit merge.')
+
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))

     @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0])
@@ -2497,18 +2543,11 @@
       out = jnp.zeros(shape, jnp.bfloat16)
       return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))

-    # This split is needed because original `jit` adds `device` as a
-    # `devices_from_context` whereas `pjit` passes it as an in_sharding.
-    if jax.config.jax_jit_pjit_api_merge:
-      error_msg = ("Devices of all `Array` inputs and outputs should be the same. "
-                   r"Got array device ids \[0\] on platform.*and "
-                   r"another array's device ids \[0, 1, 2, 3\] on platform")
-    else:
-      error_msg = ("Pjit's devices and Array's devices should be equal. "
-                   r"Got Pjit's device ids \[0\] on platform.*and "
-                   r"Array's device ids \[0, 1, 2, 3\] on platform")
-    with self.assertRaisesRegex(ValueError, error_msg):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Received incompatible devices for jitted computation. Got explicit "
+        r"output sharding with device ids \[0\].*with_sharding_constraint.*with "
+        r"device ids \[0, 1, 2, 3\].*"):
       sharded_zeros((4096, 3072), P('x', 'y'))

   @jax_array(True)