Improve the error message which is raised from _get_and_check_device_assignment.

Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```

PiperOrigin-RevId: 508746961
parent 57900d7ef2
commit 1526c3e20c
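For context, a minimal repro sketch of the pattern that triggers the improved message. This is not part of the commit; it mirrors `test_jit_with_sharding_constraint_committed_inp_error` from the diff below, and assumes a 4-device CPU runtime (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=4`) with the jit-pjit merge enabled:

```python
# Repro sketch (illustrative, mirrors the test added in this commit).
# Assumes 4 CPU devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 2x2 mesh over the first four devices.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))

@jax.jit
def sharded_inp(inp):
  # The constraint places the value across all four mesh devices.
  return jax.lax.with_sharding_constraint(inp, NamedSharding(mesh, P('x', 'y')))

# device_put commits the argument to device 0 only, so its device assignment
# conflicts with the 4-device constraint and dispatch raises the new ValueError.
committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
sharded_inp(committed_inp)  # ValueError: Received incompatible devices ...
```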
@@ -689,13 +689,19 @@ def _cpp_jit(
         inline=inline,
         keep_unused=keep_unused))
     execute = None
-    if isinstance(top_trace, core.EvalTrace) and not (
-        jax.config.jax_debug_nans or jax.config.jax_debug_infs):
-      execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
-      out_flat = call_bind_continuation(execute(*args_flat))
-    else:
-      out_flat = call_bind_continuation(
-          top_trace.process_call(primitive, fun_, tracers, params))
+    try:
+      if isinstance(top_trace, core.EvalTrace) and not (
+          jax.config.jax_debug_nans or jax.config.jax_debug_infs):
+        execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
+        out_flat = call_bind_continuation(execute(*args_flat))
+      else:
+        out_flat = call_bind_continuation(
+            top_trace.process_call(primitive, fun_, tracers, params))
+    except pxla.DeviceAssignmentMismatchError as e:
+      fails, = e.args
+      msg = pjit._device_assignment_mismatch_error(
+          fun, fails, in_tree, args_flat, 'jit')
+      raise ValueError(msg) from None
     out_pytree_def = out_tree()
     out = tree_unflatten(out_pytree_def, out_flat)
@@ -45,6 +45,7 @@ from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import path
 from jax._src import profiler
+from jax._src import source_info_util
 from jax._src import stages
 from jax._src import traceback_util
 from jax._src import util
@@ -553,20 +554,22 @@ def jaxpr_has_primitive(jaxpr, prim_name: str):
   return False
 
 
-def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
+def jaxpr_shardings(jaxpr) -> Iterator[Tuple[jax.sharding.XLACompatibleSharding, str]]:
   from jax.experimental import pjit, shard_map
 
   for eqn in jaxpr.eqns:
     if eqn.primitive is pjit.sharding_constraint_p:
-      yield eqn.params['sharding']
+      yield (eqn.params['sharding'], source_info_util.summarize(eqn.source_info))
     elif eqn.primitive is pjit.pjit_p:
-      yield from eqn.params['in_shardings']
-      yield from eqn.params['out_shardings']
+      source_info = source_info_util.summarize(eqn.source_info)
+      yield from ((i, source_info) for i in eqn.params['in_shardings'])
+      yield from ((o, source_info) for o in eqn.params['out_shardings'])
     elif eqn.primitive is shard_map.shard_map_p:
+      source_info = source_info_util.summarize(eqn.source_info)
       def _names_to_pspec(names):
         ndmin = max(names) + 1 if names else 0
         return PartitionSpec(*(names.get(i) for i in range(ndmin)))
-      yield from (NamedSharding(eqn.params['mesh'], _names_to_pspec(names))
+      yield from ((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
                   for names in eqn.params['in_names'])
   for subjaxpr in core.subjaxprs(jaxpr):
     yield from jaxpr_shardings(subjaxpr)
@@ -2715,55 +2715,99 @@ def check_if_any_auto(
       return True
   return False
 
+
+class MismatchType(enum.Enum):
+  ARG_SHARDING = 0
+  OUT_SHARDING = 1
+  SHARDING_INSIDE_COMPUTATION = 2
+  CONTEXT_DEVICES = 3
+  IN_SHARDING = 4
+
+  def __str__(self):
+    if self.name == 'IN_SHARDING':
+      return 'explicit input sharding'
+    elif self.name == 'OUT_SHARDING':
+      return 'explicit output sharding'
+    elif self.name == 'SHARDING_INSIDE_COMPUTATION':
+      return 'with_sharding_constraint or nested pjit or shard_map'
+    elif self.name == 'CONTEXT_DEVICES':
+      return 'devices'
+    return f'{self.name}'
+
+
+@dataclasses.dataclass
+class DeviceAssignmentMismatch:
+  da: Sequence[xc.Device]
+  m_type: MismatchType
+  source_info: Optional[str]
+
+  @property
+  def device_ids(self) -> Sequence[int]:
+    return [d.id for d in self.da]
+
+  @property
+  def platform(self) -> str:
+    return self.da[0].platform.upper()
+
+  def _maybe_api_name(self, api_name) -> str:
+    return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else ""
+
+  @property
+  def source_info_str(self):
+    return "" if self.source_info is None else f" at {self.source_info}"
+
+  @property
+  def _dev_ids_plat_str(self):
+    return f"device ids {self.device_ids} on platform {self.platform}"
+
+  def _str(self, api_name):
+    return (f"{self._maybe_api_name(api_name)} {self.m_type} with "
+            f"{self._dev_ids_plat_str}{self.source_info_str}")
+
+
+class DeviceAssignmentMismatchError(Exception):
+  pass
+
+
+ShardingInfo = Tuple[Union[sharding_internal.XLACompatibleSharding,
+                           UnspecifiedValue, AUTOAxisResource],
+                     MismatchType, Optional[str]]
+
+
 def _get_and_check_device_assignment(
-    shardings: Iterable[Union[sharding_internal.XLACompatibleSharding,
-                              UnspecifiedValue, AUTOAxisResource]],
-    devices: Optional[Sequence[xc.Device]]
+    shardings: Iterable[ShardingInfo],
+    devices: Optional[Sequence[xc.Device]],
 ) -> Tuple[xla.Backend, Sequence[xc.Device]]:
   from jax._src.api import local_devices
 
-  first_device_assignment = None
+  first_sharding_info = None
   if devices is None:
     devices = []
   else:
     devices = list(devices)
 
-  for i in shardings:
+  for i, s_type, source_info in shardings:
     if is_auto(i) or _is_unspecified(i):
       continue
-    # Assign `first_device_assignment` after `AUTO` and `UNSPECIFIED` have been
+    # Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
     # skipped.
-    if first_device_assignment is None:
-      first_device_assignment = list(i._device_assignment)  # type: ignore
+    if first_sharding_info is None:
+      first_sharding_info = (list(i._device_assignment), s_type, source_info)  # type: ignore
     arr_device_assignment = list(i._device_assignment)  # type: ignore
     if not devices:
-      if first_device_assignment != arr_device_assignment:
-        p1 = first_device_assignment[0].platform.upper()
-        fda_ids = [d.id for d in first_device_assignment]
-        a_ids = [d.id for d in arr_device_assignment]
-        p2 = arr_device_assignment[0].platform.upper()
-        raise ValueError(
-            "Devices of all `Array` inputs and outputs should be "
-            "the same. "
-            f"Got array device ids {fda_ids} on platform {p1} and "
-            f"another array's device ids {a_ids} on platform {p2}")
+      if first_sharding_info[0] != arr_device_assignment:
+        raise DeviceAssignmentMismatchError([
+            DeviceAssignmentMismatch(*first_sharding_info),
+            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
     else:
       if devices != arr_device_assignment:
-        p1 = devices[0].platform.upper()
-        dev_ids = [d.id for d in devices]
-        a_ids = [d.id for d in arr_device_assignment]
-        p2 = arr_device_assignment[0].platform.upper()
-        raise ValueError(
-            "Pjit's devices and Array's devices should be equal. "
-            f"Got Pjit's device ids {dev_ids} on platform {p1} and "
-            f"Array's device ids {a_ids} on platform {p2}")
-  if first_device_assignment is None and devices:
+        raise DeviceAssignmentMismatchError([
+            DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None),
+            DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)])
+  if first_sharding_info is None and devices:
     final_device_assignment = devices
-  elif first_device_assignment is None:
+  elif first_sharding_info is None:
     final_device_assignment = [config.jax_default_device or local_devices()[0]]
   else:
-    final_device_assignment = first_device_assignment
+    final_device_assignment = first_sharding_info[0]
   return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
@@ -2807,8 +2851,12 @@ def lower_sharding_computation(
   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
   jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
-  backend, device_assignment = _get_and_check_device_assignment(it.chain(
-      in_shardings, out_shardings, jaxpr_sharding), devices_from_context)  # type: ignore
+  backend, device_assignment = _get_and_check_device_assignment(
+      it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
+               [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],  # type: ignore
+               [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)  # type: ignore
+                for js, source_info in jaxpr_sharding]),
+      devices_from_context)
 
   # TODO(yashkatariya): Make this logic work after DCE because there can be
   # equations inside the jaxpr that don't affect the output so whether the
@@ -2817,7 +2865,7 @@ def lower_sharding_computation(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not _is_unspecified(i) for i in in_shardings) or
-      any(not _is_unspecified(js) for js in jaxpr_sharding) or
+      any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or  # type: ignore
       any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
 
   in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
@@ -14,6 +14,7 @@
 
 import dataclasses
 from enum import IntEnum
+import inspect
 import numpy as np
 from collections import OrderedDict, Counter
 from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
@@ -59,7 +60,7 @@ from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.traceback_util import api_boundary
-from jax._src.tree_util import prefix_errors
+from jax._src.tree_util import (prefix_errors, _generate_key_paths)
 from jax._src.util import (
     HashableFunction, safe_map, safe_zip, wraps,
     distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
@@ -111,14 +112,87 @@ def _check_all_or_none_unspecified(axis_resources, name):
         '`pjit._UNSPECIFIED`.')
   return unspecified
 
-def _python_pjit_helper(infer_params_fn, *args, **kwargs):
-  args_flat, _, params, _, out_tree, _ = infer_params_fn(*args, **kwargs)
+
+def _try_infer_args(f, tree):
+  dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
+  try:
+    return inspect.signature(f).bind(*dummy_args)
+  except (TypeError, ValueError):
+    return None
+
+
+def _find_arg_mismatch(arg_list, fails, fun_name):
+  first_err, second_err = fails
+  mismatched_args_msg = []
+  for name, inp_da, aval in arg_list:
+    if first_err.m_type == pxla.MismatchType.ARG_SHARDING:
+      if first_err.da == inp_da:
+        mismatched_args_msg.append(
+            (f"argument {name} of {fun_name} with {aval.str_short()} and "
+             f"{first_err._dev_ids_plat_str}"))
+        break
+
+  for name, inp_da, aval in arg_list:
+    if second_err.m_type == pxla.MismatchType.ARG_SHARDING:
+      if second_err.da == inp_da:
+        mismatched_args_msg.append(
+            (f"argument {name} of {fun_name} with {aval.str_short()} and "
+             f"{second_err._dev_ids_plat_str}"))
+        break
+  return mismatched_args_msg
+
+
+def _device_assignment_mismatch_error(fun, fails, in_tree, args_flat, api_name):
+  sig = _try_infer_args(fun, in_tree)
+  args = tree_unflatten(in_tree, args_flat)
+  args_aug = _generate_key_paths(args)
+
+  arg_list = []
+  for arg_key, val in args_aug:
+    ak, *rem_keys = arg_key.keys
+    if sig is not None:
+      loc = ''.join(k.pprint() for k in rem_keys)
+      arg_name = f'{list(sig.arguments.keys())[ak.key]}{loc}'
+    else:
+      arg_name = ''
+    da = val.sharding._device_assignment if hasattr(val, 'sharding') else None
+    arg_list.append((arg_name, da, shaped_abstractify(val)))
+
+  fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
+  mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name)
+
+  if len(mismatched_args_msg) == 2:
+    first, second = mismatched_args_msg  # pylint: disable=unbalanced-tuple-unpacking
+    extra_msg = f" Got {first} and {second}"
+  elif len(mismatched_args_msg) == 1:
+    first, second = fails
+    # Choose the failure left which is not already covered by ARG_SHARDING.
+    left = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first
+    extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}"
+  else:
+    first, second = fails
+    extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}"
+  msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}")
+  return msg
+
+
+def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
+  args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
+      *args, **kwargs)
   for arg in args_flat:
     dispatch.check_arg(arg)
-  out_flat = pjit_p.bind(*args_flat, **params)
+  try:
+    out_flat = pjit_p.bind(*args_flat, **params)
+  except pxla.DeviceAssignmentMismatchError as e:
+    fails, = e.args
+    api_name = 'jit' if params['resource_env'] is None else 'pjit'
+    msg = _device_assignment_mismatch_error(
+        fun, fails, in_tree, args_flat, api_name)
+    raise ValueError(msg) from None
   outs = tree_unflatten(out_tree, out_flat)
   return outs, out_flat, out_tree, args_flat
 
 
 def _python_pjit(fun: Callable, infer_params_fn):
 
   @wraps(fun)
@@ -126,7 +200,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
   def wrapped(*args, **kwargs):
     if config.jax_disable_jit:
       return fun(*args, **kwargs)
-    return _python_pjit_helper(infer_params_fn, *args, **kwargs)[0]
+    return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
 
   return wrapped
 
@@ -153,7 +227,7 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
   @api_boundary
   def cache_miss(*args, **kwargs):
     outs, out_flat, out_tree, args_flat = _python_pjit_helper(
-        infer_params_fn, *args, **kwargs)
+        fun, infer_params_fn, *args, **kwargs)
 
     executable = _read_most_recent_pjit_call_executable()
 
@@ -1081,13 +1155,15 @@ def _resolve_in_shardings(
       if isinstance(arg_s, PmapSharding):
         continue
       if getattr(a, '_committed', True):
-        committed_arg_shardings.append(arg_s)
+        committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
 
   # Check if the device_assignment across inputs, outputs and arguments is the
   # same.
   pxla._get_and_check_device_assignment(
       it.chain(
-          committed_arg_shardings, pjit_in_shardings, out_shardings),
+          committed_arg_shardings,
+          [(i, pxla.MismatchType.IN_SHARDING, None) for i in pjit_in_shardings],
+          [(o, pxla.MismatchType.OUT_SHARDING, None) for o in out_shardings]),
       (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
 
   resolved_in_shardings = []
@@ -766,7 +766,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
       return jax.jit(lambda x: x + 1, backend="cpu")(x)

    if jax.config.jax_jit_pjit_api_merge:
-      msg = 'Devices of all `Array` inputs and outputs should be the same'
+      msg = 'Received incompatible devices for jitted computation'
    else:
      msg = ("Outer-jit backend specification .* must match explicit inner-jit "
             "backend specification cpu.")
@@ -104,7 +104,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
     # in an error
 
     if jax.config.jax_jit_pjit_api_merge:
-      err_msg = "Devices of all `Array` inputs and outputs should be the same"
+      err_msg = "Received incompatible devices for jitted computation"
     else:
       err_msg = "primitive arguments must be colocated on the same device"
 
@@ -1977,7 +1977,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with jax_array(True):
       with global_mesh:
         with self.assertRaisesRegex(
-            ValueError, "Pjit's devices and Array's devices should be equal"):
+            ValueError, "Received incompatible devices for pjitted computation"):
           pjit(lambda x: x)(input_array)
 
   def test_array_lower_compile(self):
@@ -2061,15 +2061,15 @@ class ArrayPjitTest(jtu.JaxTestCase):
     m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
     spec = P('x', 'y')
 
-    a1, _ = create_array(input_shape, m1, spec)
+    a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
 
     with jax_array(True):
       with m1:
         with self.assertRaisesRegex(
-            ValueError, "Pjit's devices and Array's devices should be equal"):
+            ValueError, "Received incompatible devices for pjitted computation"):
           pjit(lambda x, y: (x, y),
                out_axis_resources=(NamedSharding(m1, spec),
                                    NamedSharding(m2, spec)))(a1, a1)
 
   def test_array_device_assignment_mismatch_in_and_out_shardings(self):
     input_shape = (8, 2)
@@ -2077,15 +2077,15 @@ class ArrayPjitTest(jtu.JaxTestCase):
     m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
     spec = P('x', 'y')
 
-    a1, _ = create_array(input_shape, m2, spec)
+    a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
 
     with jax_array(True):
       with m1:
         with self.assertRaisesRegex(
-            ValueError, "Pjit's devices and Array's devices should be equal"):
+            ValueError, "Received incompatible devices for pjitted computation"):
           pjit(lambda x, y: (x, y),
                in_axis_resources=NamedSharding(m2, spec),
                out_axis_resources=NamedSharding(m1, spec))(a1, a1)
 
   def test_mixed_inputs(self):
     input_shape = (8, 2)
@@ -2329,7 +2329,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     with jtu.create_global_mesh((2, 2), ('x', 'y')):
       with self.assertRaisesRegex(
           ValueError,
-          "Pjit's devices and Array's devices should be equal."):
+          "Received incompatible devices for pjitted computation"):
        pjit(lambda x, y: (x, y))(uarr, carr)
 
   @jax_array(True)
@@ -2353,11 +2353,33 @@ class ArrayPjitTest(jtu.JaxTestCase):
     b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
     with self.assertRaisesRegex(
         ValueError,
-        "Devices of all `Array` inputs and outputs should be the same. "
-        r"Got array device ids \[0\] on platform.*and "
-        r"another array's device ids \[1\] on platform"):
+        "Received incompatible devices for pjitted computation. Got argument "
+        r"x of.*\<lambda\> with int.*\[3\] and device ids \[0\].*and argument "
+        r"y of.*\<lambda\> with int.*\[3\] and device ids \[1\].*"):
       pjit(lambda x, y: (x, y))(a, b)
 
+  def test_pjit_pytree_inp_device_assignment_mismatch(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
+    b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1])
+    c = jax.device_put(np.arange(16).reshape(8, 2),
+                       NamedSharding(mesh, P('x', 'y')))
+
+    msg = ("Received incompatible devices for pjitted computation. Got "
+           r"argument {} of.*<lambda> with int.*\[3\] and device ids \[0\].*and "
+           r"argument {} of.*<lambda> with int.*\[8,2\] and device ids "
+           r"\[0, 1, 2, 3\].*")
+
+    with self.assertRaisesRegex(
+        ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')):
+      pjit(lambda tuple_inp: tuple_inp)((a, (c, (b))))
+
+    with self.assertRaisesRegex(
+        ValueError, msg.format(r"dict_inp\['a'\]\['b'\]\['c'\]",
+                               r"dict_inp\['a'\]\['b'\]\['g'\]")):
+      inp = {'d': a, 'z': a, 'a': {'f': a, 'y': b, 'b': {'g': c, 'c': a}}}
+      pjit(lambda dict_inp: dict_inp)(inp)
+
   @jax_array(True)
   def test_same_out_sharding_id(self):
     shape = (8, 2)
@@ -2475,8 +2497,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
   @jax_array(True)
   def test_jit_with_sharding_constraint_committed_inp_error(self):
+    if not jax.config.jax_jit_pjit_api_merge or not jax.config.jax_array:
+      self.skipTest('Requires jit-pjit merge and jax.Array')
+
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
 
+    s = NamedSharding(mesh, P('x', 'y'))
+
     @jax.jit
     def sharded_inp(inp):
       return jax.lax.with_sharding_constraint(
@@ -2485,11 +2512,30 @@ class ArrayPjitTest(jtu.JaxTestCase):
     committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), jax.devices()[0])
     with self.assertRaisesRegex(
         ValueError,
-        "Devices of all `Array` inputs and outputs should be the same"):
+        "Received incompatible devices for jitted computation. Got argument "
+        r"inp of.*sharded_inp with bfloat16\[8,2\] and device ids \[0\].*"
+        r"with_sharding_constraint.*with device ids \[0, 1, 2, 3\].*"):
       sharded_inp(committed_inp)
 
+    @pjit
+    def my_nested_pjit(inp1, inp2, inp3):
+      @partial(pjit, in_axis_resources=(s, s, s),
+               out_axis_resources=(s, s, s))
+      def f(x, y, z):
+        return x * 2, y * 2, z * 2
+      return f(inp1, inp2, inp3)
+    with self.assertRaisesRegex(
+        ValueError,
+        "Received incompatible devices for pjitted computation. Got argument "
+        r"inp1 of.*my_nested_pjit with bfloat16\[8,2\] and device ids \[0\].*"
+        r"nested pjit.*with device ids \[0, 1, 2, 3\].*"):
+      my_nested_pjit(committed_inp, committed_inp, committed_inp)
+
   @jax_array(True)
   def test_jit_device_with_sharding_constraint_error(self):
+    if not jax.config.jax_jit_pjit_api_merge or not jax.config.jax_array:
+      self.skipTest('Requires jax.Array and jit-pjit merge.')
+
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
 
     @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0])
@@ -2497,18 +2543,11 @@ class ArrayPjitTest(jtu.JaxTestCase):
       out = jnp.zeros(shape, jnp.bfloat16)
       return jax.lax.with_sharding_constraint(out, NamedSharding(mesh, pspec))
 
-    # This split is needed because original `jit` adds `device` as a
-    # `devices_from_context` whereas `pjit` passes it as an in_sharding.
-    if jax.config.jax_jit_pjit_api_merge:
-      error_msg = ("Devices of all `Array` inputs and outputs should be the same. "
-                   r"Got array device ids \[0\] on platform.*and "
-                   r"another array's device ids \[0, 1, 2, 3\] on platform")
-    else:
-      error_msg = ("Pjit's devices and Array's devices should be equal. "
-                   r"Got Pjit's device ids \[0\] on platform.*and "
-                   r"Array's device ids \[0, 1, 2, 3\] on platform")
-
-    with self.assertRaisesRegex(ValueError, error_msg):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Received incompatible devices for jitted computation. Got explicit "
+        r"output sharding with device ids \[0\].*with_sharding_constraint.*with "
+        r"device ids \[0, 1, 2, 3\].*"):
       sharded_zeros((4096, 3072), P('x', 'y'))
 
   @jax_array(True)
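To see how the added pieces compose the final message, here is an illustrative stand-alone sketch, not part of the commit. It assumes the `DeviceAssignmentMismatch` and `MismatchType` definitions from the pxla hunk above are in scope, and uses a hypothetical `FakeDevice` in place of real `xc.Device` objects:

```python
# Illustration only: compose an error string with the classes added above.
# FakeDevice is hypothetical; real code gets xc.Device objects from a backend.
import dataclasses

@dataclasses.dataclass
class FakeDevice:
  id: int
  platform: str = 'cpu'

# Two mismatch records, like the ones _get_and_check_device_assignment
# attaches to a DeviceAssignmentMismatchError.
arg = DeviceAssignmentMismatch([FakeDevice(0)], MismatchType.ARG_SHARDING, None)
constraint = DeviceAssignmentMismatch(
    [FakeDevice(i) for i in range(4)],
    MismatchType.SHARDING_INSIDE_COMPUTATION,
    'jax/tests/pjit_test.py:2509 (sharded_inp)')

# _str() stitches MismatchType.__str__, the device ids/platform, and the
# source info; for `constraint` it renders
#   " with_sharding_constraint or nested pjit or shard_map with device ids
#    [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)"
# which is exactly the tail of the "After" message in the commit description.
print("Received incompatible devices for jitted computation."
      f" Got{arg._str('jit')} and{constraint._str('jit')}")
```

When argument names can be recovered, `_device_assignment_mismatch_error` instead uses `_find_arg_mismatch` to render the richer "argument inp of ... with bfloat16[8,2] ..." form shown above.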