If input layouts are specified via in_shardings to jit and the jitted function is called with an uncommitted array, reshard the input array to the user-specified layout.

Skipping the reshard leads to incorrect outputs on GPU and a crash on TPU.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
Author: Yash Katariya, 2024-08-19 15:10:00 -07:00 (committed by jax authors)
Parent: 292161ab4d
Commit: 6e1c23610d
13 changed files with 156 additions and 73 deletions
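
For context, here is a minimal sketch of the scenario this change addresses, modeled on the test_in_layouts_jit_jnp_input test added at the end of this commit. It assumes a single-device setup and that Layout and DeviceLocalLayout come from jax.experimental.layout (import path assumed, not shown in the diff):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.layout import DeviceLocalLayout as DLL, Layout

# in_shardings carries an explicit device-local layout.
major_last_layout = DLL(major_to_minor=(1, 0))
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
f = jax.jit(lambda x: x + 1, in_shardings=Layout(major_last_layout, sharding))

# Both inputs below are uncommitted, so jit must reshard them to the
# requested layout; without the fix this gave wrong outputs on GPU and
# crashed on TPU.
arr = jnp.arange(8 * 128).reshape(8, 128)
np.testing.assert_array_equal(f(arr), arr + 1)

np_inp = np.arange(8 * 128).reshape(8, 128)
np.testing.assert_array_equal(f(np_inp), np_inp + 1)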


@ -1827,7 +1827,7 @@ def _cpp_pmap(
cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple,
lambda x, s: pxla.shard_args([s], [x])[0],
lambda x, s: pxla.shard_args([s], [None], [x])[0],
pytree_registry=tree_util.default_registry)
_pmap_cache_clears.add(cpp_mapped_f)


@ -1086,9 +1086,8 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
# Look up all buffers that contain the correct slice of the logical array.
candidates_list = candidates[hashed_index(idx)]
if not candidates_list:
# This array isn't sharded correctly. Reshard it via host roundtrip.
# TODO(skye): more efficient reshard?
return pxla.shard_args([sharding], [x._value], canonicalize=False)[0]
return pxla.shard_args([sharding], [None], [x._value],
canonicalize=False)[0]
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
@ -1097,7 +1096,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
break
else:
bufs.append(buf)
return pxla.batched_device_put(x.aval, sharding, bufs, devices)
@ -1107,24 +1105,30 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
return dst_indices, tuple(src_indices) == tuple(dst_indices)
def _layout_eq(x, dst_layout, sharding):
if pxla.is_default_layout(dst_layout, sharding, x.aval):
return True
return x.layout.device_local_layout == dst_layout
def _array_shard_arg(xs, shardings):
def _array_shard_arg(xs, shardings, layouts):
results = []
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
for i, (x, sharding) in enumerate(safe_zip(xs, shardings)):
x._check_if_deleted()
indices, same_indices = _sharding_indices_and_eq(
x.sharding, x.shape, sharding)
for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
x._check_if_deleted()
indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
same_layout = _layout_eq(x, layout, sharding)
if not x.is_fully_addressable:
if same_indices:
if same_indices and same_layout:
results.append(x)
else:
raise NotImplementedError(
"Cannot reshard an input that is not fully addressable")
else:
devices = sharding._addressable_device_assignment
if same_indices:
if same_indices and same_layout:
# Add a placeholder result that will be filled in later.
results.append(None)
# Accumulate arguments to `batched_copy_array_to_devices_with_sharding`.
@ -1133,6 +1137,8 @@ def _array_shard_arg(xs, shardings):
batch_shardings.append(sharding)
batch_indices.append(i)
# Resharding starts here:
elif not same_layout:
results.append(api.device_put(x, Layout(layout, sharding)))
elif dispatch.is_single_device_sharding(x.sharding):
results.append(shard_device_array(x, devices, indices, sharding))
else:
@ -1145,8 +1151,6 @@ def _array_shard_arg(xs, shardings):
assert results[i] is None
results[i] = copy_out
return results
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
@ -1178,8 +1182,8 @@ pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
# Token handlers
def _token_shard_arg(xs, shardings):
return _array_shard_arg([x._buf for x in xs], shardings)
def _token_shard_arg(xs, shardings, layouts):
return _array_shard_arg([x._buf for x in xs], shardings, layouts)
pxla.shard_arg_handlers[core.Token] = _token_shard_arg
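
A hedged sketch of what the new layout check amounts to for a committed single-device input, assuming jax.experimental.layout and a backend that supports custom layouts (e.g. TPU); it mirrors the api.device_put(x, Layout(layout, sharding)) branch added above rather than reproducing it:

import jax
import jax.numpy as jnp
from jax.experimental.layout import DeviceLocalLayout as DLL, Layout

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
requested = DLL(major_to_minor=(1, 0))  # some non-default device-local layout

x = jax.device_put(jnp.arange(8 * 128).reshape(8, 128), sharding)
if x.layout.device_local_layout != requested:
  # Same effect as the resharding branch in _array_shard_arg: copy the
  # array into the requested layout instead of passing it through.
  x = jax.device_put(x, Layout(requested, sharding))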


@ -134,7 +134,7 @@ class RuntimeTokenSet(threading.local):
# We only use replicated sharding for the first time when the token for the
# order effect hasn't been created.
s = jax.sharding.GSPMDSharding.get_replicated(devices)
sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0])
self.current_tokens[eff] = sharded_tok
return sharded_tok
@ -515,7 +515,10 @@ def _batched_device_put_impl(
if shard_arg_xs:
# Batch shard_arg calls. Helps improve efficiency for backends that support
# efficient batch transfer.
shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
# device_put handles `Layout` via a different path, so just pass `None` as
# the layout here.
shard_arg_results = pxla.shard_args(
shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs)
for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
assert isinstance(ys[i], _DeferredShardArg)
ys[i] = ys[i].result_handler(shard_arg_result)


@ -104,11 +104,12 @@ class EArray(basearray.Array):
# TODO(mattjj): _set_array_base_attributes
def _earray_shard_arg_handler(xs, shardings):
def _earray_shard_arg_handler(xs, shardings, layouts):
arrs = [x._data for x in xs]
phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding)
for x, sharding in zip(xs, shardings)]
return pxla.shard_args(phys_shardings, arrs)
# TODO(yashkatariya): `layouts` should be converted to physical layouts.
return pxla.shard_args(phys_shardings, layouts, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval


@ -1000,7 +1000,7 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
return "auto"
if aval is core.abstract_token:
return "default"
return layout._to_xla_layout(aval.dtype) # type: ignore
return str(layout._to_xla_layout(aval.dtype)) # type: ignore
def _get_mem_kind(s: JSharding | None) -> str | None:


@ -32,6 +32,7 @@ import numpy as np
import jax
from jax._src import api
from jax._src import api_util
from jax._src import compiler
from jax._src import config
@ -60,6 +61,7 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@ -106,39 +108,67 @@ ShardingSpec = sharding_specs.ShardingSpec
def identity(x): return x
@profiler.annotate_function
def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]:
def shard_args(shardings: Sequence[JSharding], layouts, args,
canonicalize=True) -> Sequence[xc.ArrayImpl]:
# Fast path for one argument.
if len(args) == 1:
arg = args[0]
if canonicalize:
arg = xla.canonicalize_dtype(arg)
return shard_arg_handlers[type(arg)]([arg], shardings)
return shard_arg_handlers[type(arg)]([arg], shardings, layouts)
# type(arg) -> (indices, args, shardings)
batches = collections.defaultdict(lambda: ([], [], [])) # type: ignore
for i, (arg, sharding) in enumerate(safe_zip(args, shardings)):
# type(arg) -> (list[indices], list[args], list[shardings], list[layouts])
batches = collections.defaultdict(lambda: ([], [], [], [])) # type: ignore
for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)):
if canonicalize:
arg = xla.canonicalize_dtype(arg)
batch = batches[type(arg)]
batch[0].append(i)
batch[1].append(arg)
batch[2].append(sharding)
batch[3].append(layout)
# Call `shard_arg_handlers` per batch and build a flat list of arrays returned
# from each call in the same order as `args`. Since `batches` is grouped by
# types, we cannot simply flatten the results and we have to use the original
# indices to put each array back to its original position.
results: list[jax.Array | None] = [None] * len(args)
for t, (indices, a, s) in batches.items():
outs = shard_arg_handlers[t](a, s)
for t, (indices, a, s, l) in batches.items():
outs = shard_arg_handlers[t](a, s, l)
for i, out in safe_zip(indices, outs):
results[i] = out
assert all(result is not None for result in results)
return results
shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {}
shard_arg_handlers: dict[
Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]]
] = {}
def is_default_layout(curr_layout, sharding, aval):
if curr_layout is None or sharding is None:
return True
if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
dtypes.issubdtype(aval.dtype, dtypes.extended)):
return True
if isinstance(curr_layout, AutoLayout):
return False
d = sharding._device_assignment[0]
shard_shape = sharding.shard_shape(aval.shape)
try:
# TODO(yashkatariya): Replace this with normal `==` check once CPU supports
# int4.
return is_user_xla_layout_equal(
curr_layout,
DeviceLocalLayout.from_pjrt_layout(
d.client.get_default_layout(aval.dtype, shard_shape, d)))
except xe.XlaRuntimeError as e:
msg, *_ = e.args
if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"):
return True
else:
raise
@lru_cache(maxsize=1024)
@ -146,18 +176,21 @@ def _get_replicated_slices(num_addressable_devices: int):
return ((slice(None),),) * num_addressable_devices
def _masked_array_error(xs, shardings):
def _masked_array_error(xs, shardings, layouts):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
def _shard_array(xs, shardings):
def _shard_np_array(xs, shardings, layouts):
results = []
for x, sharding in safe_zip(xs, shardings):
for x, sharding, layout in safe_zip(xs, shardings, layouts):
devices = sharding._addressable_device_assignment
if x.dtype == dtypes.float0:
x = np.zeros(x.shape, dtype=np.dtype(bool))
aval = api_util.shaped_abstractify(x)
if not is_default_layout(layout, sharding, aval):
results.append(api.device_put(x, Layout(layout, sharding)))
else:
if sharding.is_fully_replicated:
shards = [x] * len(devices)
else:
@ -166,14 +199,14 @@ def _shard_array(xs, shardings):
results.append(batched_device_put(aval, sharding, shards, devices))
return results
for _t in array_types:
shard_arg_handlers[_t] = _shard_array
shard_arg_handlers[_t] = _shard_np_array
def _shard_darray(xs, shardings):
return shard_args(shardings, [x._data for x in xs])
def _shard_darray(xs, shardings, layouts):
return shard_args(shardings, layouts, [x._data for x in xs])
shard_arg_handlers[core.DArray] = _shard_darray
def _shard_mutable_array(xs, shardings):
return shard_args(shardings, [x._buf for x in xs])
def _shard_mutable_array(xs, shardings, layouts):
return shard_args(shardings, layouts, [x._buf for x in xs])
shard_arg_handlers[core.MutableArray] = _shard_mutable_array
def batched_device_put(aval: core.ShapedArray,
@ -931,6 +964,7 @@ class UnloadedPmapExecutable:
handle_outs = local_avals_to_results_handler(self.local_output_avals,
self.output_shardings)
handle_args = InputsHandler(self.input_shardings,
[None] * len(self.input_shardings),
self.compiled.local_devices(), input_indices)
execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
self.backend, handle_args, handle_outs,
@ -1109,12 +1143,15 @@ def _get_pmap_sharding(devices, specs):
class InputsHandler:
__slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
__slots__ = ("handler", "in_shardings", "in_layouts", "local_devices",
"input_indices")
def __init__(self, in_shardings, local_devices=None, input_indices=None):
self.handler = partial(shard_args, in_shardings)
self.local_devices = local_devices
def __init__(self, in_shardings, in_layouts, local_devices=None,
input_indices=None):
self.handler = partial(shard_args, in_shardings, in_layouts)
self.in_shardings = in_shardings
self.in_layouts = in_layouts
self.local_devices = local_devices
self.input_indices = input_indices
def __call__(self, input_buffers):
@ -1122,8 +1159,9 @@ class InputsHandler:
def __str__(self):
return ("InputsHandler(\n"
f"local_devices={self.local_devices},\n"
f"in_shardings={self.in_shardings},\n"
f"in_layouts={self.in_layouts},\n"
f"local_devices={self.local_devices},\n"
f"input_indices={self.input_indices})")
@ -1849,7 +1887,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
if is_unspecified_or_auto(sharding):
return None
# TODO(yashkatariya): Figure out how layouts work with extended dtypes.
if dtypes.issubdtype(aval.dtype, dtypes.extended):
if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended):
return None
if not core.is_constant_shape(aval.shape):
return None
@ -2505,7 +2543,7 @@ def maybe_recover_user_shardings(
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
xl: DeviceLocalLayout) -> bool:
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
if isinstance(ul, DeviceLocalLayout) and not ul._tiling:
return ul.major_to_minor == xl.major_to_minor
else:
return ul == xl
@ -2742,7 +2780,7 @@ class UnloadedMeshExecutable:
pgle_profiler: profiler.PGLEProfiler | None
def build_unsafe_call(self):
handle_args = InputsHandler(self.input_shardings)
handle_args = InputsHandler(self.input_shardings, self.in_layouts)
handle_outs = global_avals_to_results_handler(
self.output_avals, self.output_shardings, self.committed)
@ -2882,9 +2920,7 @@ class MeshExecutableFastpathData(NamedTuple):
out_avals: Sequence[ShapedArray]
out_committed: Sequence[bool]
kept_var_bitvec: Iterable[bool]
# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
arg_handler_devices: Sequence[xc.Device]
arg_handler_indices: Sequence[tuple[Index | None, ...]]
in_device_local_layouts: Sequence[DeviceLocalLayout | None]
def reflatten_outputs_for_dispatch(out_tree, out_flat):
@ -2992,18 +3028,36 @@ class MeshExecutable(stages.XlaExecutable):
else s
for s, a in zip(self._in_shardings, self.in_avals)
]
in_dlls = get_layouts_for_fasthpath_data(
self._in_layouts, in_shardings, self.in_avals)
fastpath_data = MeshExecutableFastpathData(
self.xla_executable, out_tree_dispatch, in_shardings,
self._out_shardings, out_avals, out_committed, kept_var_bitvec,
self.unsafe_call.in_handler.local_devices,
self.unsafe_call.in_handler.input_indices)
in_dlls)
else:
fastpath_data = None
return outs, fastpath_data, False # Do not remove cache entry
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0])
tree_util.dispatch_registry, cc_shard_arg)
if xla_extension_version < 282:
def cc_shard_arg(x, sharding):
return shard_args([sharding], [None], [x])[0]
else:
def cc_shard_arg(x, sharding, layout): # type: ignore
return shard_args([sharding], [layout], [x])[0]
def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals):
in_dlls = []
for l, s, a in zip(in_layouts, in_shardings, in_avals):
if is_default_layout(l, s, a):
in_dlls.append(None)
else:
in_dlls.append(l)
return in_dlls
def check_arg_avals_for_call(ref_avals, arg_avals,
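
To summarize the new calling convention as a hedged sketch: shard_args now takes a layouts sequence parallel to shardings and args (a None entry meaning "use the default layout for this sharding/aval"), and every handler registered in shard_arg_handlers receives that extra argument. The wrapper type in the handler below is hypothetical:

import jax
import numpy as np
from jax._src.interpreters import pxla

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
x = np.arange(8.0)

# Callers pass a layouts list parallel to shardings/args.
out = pxla.shard_args([sharding], [None], [x])[0]

# Handlers now take (xs, shardings, layouts), even if they only forward
# them, mirroring _shard_darray above.
def _my_wrapper_shard_arg(xs, shardings, layouts):  # hypothetical wrapper type
  return pxla.shard_args(shardings, layouts, [w.data for w in xs])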


@ -688,7 +688,7 @@ def _maybe_put(x):
aval = shaped_abstractify(x)
s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0])
result_handler = pxla.global_aval_to_result_handler(aval, s, False)
return result_handler(pxla.shard_args([s], [x]))
return result_handler(pxla.shard_args([s], [None], [x]))
else:
return x


@ -69,7 +69,7 @@ class DeviceLocalLayout:
self._tiling == other._tiling and
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
def _to_xla_layout(self, dtype) -> str:
def _to_xla_layout(self, dtype) -> xc.Layout:
if self._tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
else:
@ -81,7 +81,7 @@ class DeviceLocalLayout:
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling,
sub_byte_size)
return str(xla_layout)
return xla_layout
def check_compatible_aval(self, aval_shape: Shape):
if len(self.major_to_minor) != len(aval_shape):


@ -279,11 +279,12 @@ def _get_fastpath_data(
else s
for s, a in zip(executable._in_shardings, executable.in_avals)
]
in_dlls = pxla.get_layouts_for_fasthpath_data(
executable._in_layouts, in_shardings, executable.in_avals)
fastpath_data = pxla.MeshExecutableFastpathData(
executable.xla_executable, out_tree, in_shardings,
executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
executable.unsafe_call.in_handler.local_devices,
executable.unsafe_call.in_handler.input_indices)
in_dlls)
else:
fastpath_data = None
return fastpath_data
@ -302,9 +303,7 @@ def _read_most_recent_pjit_call_executable(jaxpr):
def _read_pgle_profiler(jaxpr):
return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(
jaxpr, None
)
return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None)
def _cpp_pjit_evict_fn(self):
self._clear_cache()
@ -343,8 +342,7 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
cpp_pjit_f = xc._xla.pjit(
fun_name(fun),
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry,
lambda x, sharding: pxla.shard_args([sharding], [x])[0],
jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(jit_info.has_explicit_sharding))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
@ -1729,8 +1727,7 @@ def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry,
lambda x, sharding: pxla.shard_args([sharding], [x])[0],
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)
pjit_p.def_impl(_pjit_call_impl)


@ -466,11 +466,12 @@ xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings):
def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts):
arrs = [x._base_array for x in xs]
phys_shardings = [physical_sharding(x.aval, sharding)
for x, sharding in zip(xs, shardings)]
return pxla.shard_args(phys_shardings, arrs)
# TODO(yashkatariya): `layouts` should be converted to physical layouts.
return pxla.shard_args(phys_shardings, layouts, arrs)
pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler


@ -3373,7 +3373,7 @@ class FooArray:
size = property(lambda self: self.data.size // 2)
ndim = property(lambda self: self.data.ndim - 1)
def shard_foo_array_handler(xs, shardings):
def shard_foo_array_handler(xs, shardings, layouts):
results = []
for x, sharding in safe_zip(xs, shardings):
device, = sharding._addressable_device_assignment


@ -500,6 +500,29 @@ class LayoutTest(jtu.JaxTestCase):
'Layout passed to jit does not match the layout on the respective arg'):
g(arr)
def test_in_layouts_jit_jnp_input(self):
major_last_layout = DLL(major_to_minor=(1, 0))
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
f = jax.jit(lambda x: x + 1,
in_shardings=Layout(major_last_layout, sharding))
arr = jnp.arange(8 * 128).reshape(8, 128)
out = f(arr)
self.assertArraysEqual(out, arr + 1)
# cpp dispatch should call into shard_args from cpp.
out2 = f(arr)
self.assertArraysEqual(out2, arr + 1)
np_inp = np.arange(8 * 128).reshape(8, 128)
out3 = f(np_inp)
self.assertArraysEqual(out3, np_inp + 1)
# cpp dispatch should call into shard_args from cpp.
out4 = f(np_inp)
self.assertArraysEqual(out4, np_inp + 1)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())


@ -3015,7 +3015,7 @@ class ShardArgsTest(jtu.JaxTestCase):
x = np.arange(math.prod(shape)).reshape(shape)
arg = make_arg(x)
sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
results = pxla.shard_args([sharding], [arg])
results = pxla.shard_args([sharding], [None], [arg])
self.assertEqual(len(results), 1)
if isinstance(results[0], array.ArrayImpl):
bufs = results[0]._arrays