diff --git a/jax/_src/api.py b/jax/_src/api.py
index e1498b65e..161390688 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -1807,7 +1807,8 @@ def _cpp_pmap(
   cpp_mapped_f = pmap_lib.pmap(
       fun, cache_miss, static_broadcasted_tuple,
-      pxla.shard_arg, pytree_registry=tree_util.default_registry)
+      lambda x, s: pxla.shard_args([s], [x])[0],
+      pytree_registry=tree_util.default_registry)
   _pmap_cache_clears.add(cpp_mapped_f)

   pmap_f = wraps(fun)(cpp_mapped_f)
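Note: every call site in this patch that hands a sharding callback to the C++
bindings (pmap_lib.pmap here, xc._xla.pjit below) still passes a per-argument
function, so the batched pxla.shard_args is wrapped in a one-element adapter
lambda. A minimal sketch of that adapter pattern, with illustrative names
rather than the real binding API:

    from typing import Any, Callable, Sequence

    def as_single_arg_callback(
        batched: Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]],
    ) -> Callable[[Any, Any], Any]:
      """Adapts a batched (shardings, args) function to an (arg, sharding) callback."""
      return lambda x, s: batched([s], [x])[0]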
diff --git a/jax/_src/array.py b/jax/_src/array.py
index 697d9b994..514c10f7f 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -40,6 +40,7 @@ from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension as xe
+from jax._src.lib import xla_extension_version
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding,
@@ -1068,7 +1069,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   if not candidates_list:
     # This array isn't sharded correctly. Reshard it via host roundtrip.
     # TODO(skye): more efficient reshard?
-    return pxla.shard_arg(x._value, sharding, canonicalize=False)
+    return pxla.shard_args([sharding], [x._value], canonicalize=False)[0]
   # Try to find a candidate buffer already on the correct device,
   # otherwise copy one of them.
   for buf in candidates_list:
@@ -1088,25 +1089,50 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   return dst_indices, tuple(src_indices) == tuple(dst_indices)


-def _array_shard_arg(x, sharding):
-  x._check_if_deleted()
+def _array_shard_arg(xs, shardings):
+  results = []
+  batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
+  for i, (x, sharding) in enumerate(safe_zip(xs, shardings)):
+    x._check_if_deleted()

-  indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
-  if not x.is_fully_addressable:
-    if same_indices:
-      return x
+    indices, same_indices = _sharding_indices_and_eq(
+        x.sharding, x.shape, sharding)
+    if not x.is_fully_addressable:
+      if same_indices:
+        results.append(x)
+      else:
+        raise NotImplementedError(
+            "Cannot reshard an input that is not fully addressable")
     else:
-      raise NotImplementedError(
-          "Cannot reshard an input that is not fully addressable")
+      devices = sharding._addressable_device_assignment
+      if same_indices:
+        # Add a placeholder result that will be filled in later.
+        results.append(None)
+        # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`.
+        batch_xs.append(x)
+        batch_devs.append(list(devices))
+        batch_shardings.append(sharding)
+        batch_indices.append(i)
+      # Resharding starts here:
+      elif dispatch.is_single_device_sharding(x.sharding):
+        results.append(shard_device_array(x, devices, indices, sharding))
+      else:
+        results.append(
+            shard_sharded_device_array_slow_path(x, devices, indices, sharding))
+
+  if xla_extension_version < 271:
+    copy_outs = [
+        xc.copy_array_to_devices_with_sharding(x, d, s)  # pytype: disable=module-attr
+        for x, d, s in safe_zip(batch_xs, batch_devs, batch_shardings)
+    ]
   else:
-    devices = sharding._addressable_device_assignment
-    if same_indices:
-      return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding)
-    # Resharding starts here:
-    if dispatch.is_single_device_sharding(x.sharding):
-      return shard_device_array(x, devices, indices, sharding)
-    else:
-      return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
+    copy_outs = xc.batched_copy_array_to_devices_with_sharding(
+        batch_xs, batch_devs, batch_shardings)
+  for i, copy_out in safe_zip(batch_indices, copy_outs):
+    assert results[i] is None
+    results[i] = copy_out
+  return results
+

 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
@@ -1139,8 +1165,8 @@ pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler

 # Token handlers

-def _token_shard_arg(x, sharding):
-  return _array_shard_arg(x._buf, sharding)
+def _token_shard_arg(xs, shardings):
+  return _array_shard_arg([x._buf for x in xs], shardings)


 pxla.shard_arg_handlers[core.Token] = _token_shard_arg
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 2836fd9b2..9f2fae105 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -324,7 +324,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:

 def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
   result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
-  return result_handler(pxla.shard_arg(x, s))
+  return result_handler(pxla.shard_args([s], [x])[0])

 def _override_get_device_assignment(sharding, *args, **kwargs):
   da = sharding._device_assignment
diff --git a/jax/_src/earray.py b/jax/_src/earray.py
index fcf0e9c3c..f4b5e232b 100644
--- a/jax/_src/earray.py
+++ b/jax/_src/earray.py
@@ -98,10 +98,11 @@ class EArray(basearray.Array):

 # TODO(mattjj): _set_array_base_attributes

-def _earray_shard_arg_handler(x, sharding):
-  arr = x._data
-  phys_sharding = sharding_impls.physical_sharding(x.aval, sharding)
-  return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
+def _earray_shard_arg_handler(xs, shardings):
+  arrs = [x._data for x in xs]
+  phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding)
+                    for x, sharding in zip(xs, shardings)]
+  return pxla.shard_args(phys_shardings, arrs)
 pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

 api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
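Note: the reworked _array_shard_arg defers every copy that qualifies for the
fast path: it appends a None placeholder to results, accumulates the copy
arguments, issues a single batched copy after the loop, and scatters the
outputs back by the saved indices. The same accumulate-then-fill pattern in
isolation (a sketch; fast_path, run_slow, and run_batched are stand-ins, not
JAX APIs):

    def accumulate_then_fill(items, fast_path, run_slow, run_batched):
      """Runs slow-path items inline; defers fast-path items to one batched call."""
      results = [None] * len(items)
      pending, pending_indices = [], []
      for i, item in enumerate(items):
        if fast_path(item):
          pending.append(item)  # results[i] stays None until the batched call
          pending_indices.append(i)
        else:
          results[i] = run_slow(item)
      for i, out in zip(pending_indices, run_batched(pending)):
        assert results[i] is None
        results[i] = out
      return results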
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 7ddac0e6f..7e9d15c7c 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -17,6 +17,7 @@ from __future__ import annotations

 import enum
 from contextlib import contextmanager
+import collections
 from collections import namedtuple
 from collections.abc import Sequence, Iterable
 import dataclasses
@@ -108,18 +109,40 @@ ShardingSpec = sharding_specs.ShardingSpec

 def identity(x): return x

-def shard_arg(arg, sharding, canonicalize=True):
-  if canonicalize:
-    arg = xla.canonicalize_dtype(arg)
-  return shard_arg_handlers[type(arg)](arg, sharding)
-
-
 @profiler.annotate_function
-def shard_args(shardings: Sequence[JSharding], args
-               ) -> Sequence[jax.Array]:
-  return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]
+def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]:
+  # Fast path for one argument.
+  if len(args) == 1:
+    arg = args[0]
+    if canonicalize:
+      arg = xla.canonicalize_dtype(arg)
+    return shard_arg_handlers[type(arg)]([arg], shardings)

-shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
+  # type(arg) -> (indices, args, shardings)
+  batches = collections.defaultdict(lambda: ([], [], []))  # type: ignore
+  for i, (arg, sharding) in enumerate(safe_zip(args, shardings)):
+    if canonicalize:
+      arg = xla.canonicalize_dtype(arg)
+    batch = batches[type(arg)]
+    batch[0].append(i)
+    batch[1].append(arg)
+    batch[2].append(sharding)
+
+  # Call `shard_arg_handlers` per batch and build a flat list of arrays returned
+  # from each call in the same order as `args`. Since `batches` is grouped by
+  # type, we cannot simply flatten the results; we use the original indices to
+  # put each array back in its original position.
+  results: list[jax.Array | None] = [None] * len(args)
+  for t, (indices, a, s) in batches.items():
+    outs = shard_arg_handlers[t](a, s)
+    for i, out in safe_zip(indices, outs):
+      results[i] = out
+
+  assert all(result is not None for result in results)
+  return results
+
+
+shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {}


 @lru_cache(maxsize=1024)
@@ -127,31 +150,34 @@ def _get_replicated_slices(num_addressable_devices: int):
   return ((slice(None),),) * num_addressable_devices


-def _masked_array_error(x, sharding):
+def _masked_array_error(xs, shardings):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error

-def _shard_array(x, sharding):
-  devices = sharding._addressable_device_assignment
-  if x.dtype == dtypes.float0:
-    x = np.zeros(x.shape, dtype=np.dtype(bool))
-  aval = api_util.shaped_abstractify(x)
-  if sharding.is_fully_replicated:
-    shards = [x] * len(devices)
-  else:
-    indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
-    shards = [x[i] for i in indices]
-  return batched_device_put(aval, sharding, shards, devices)
+def _shard_array(xs, shardings):
+  results = []
+  for x, sharding in safe_zip(xs, shardings):
+    devices = sharding._addressable_device_assignment
+    if x.dtype == dtypes.float0:
+      x = np.zeros(x.shape, dtype=np.dtype(bool))
+    aval = api_util.shaped_abstractify(x)
+    if sharding.is_fully_replicated:
+      shards = [x] * len(devices)
+    else:
+      indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
+      shards = [x[i] for i in indices]
+    results.append(batched_device_put(aval, sharding, shards, devices))
+  return results
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array

-def _shard_darray(x, sharding):
-  return shard_arg(x._data, sharding)
+def _shard_darray(xs, shardings):
+  return shard_args(shardings, [x._data for x in xs])
 shard_arg_handlers[core.DArray] = _shard_darray

-def _shard_mutable_array(x, sharding):
-  return shard_arg(x._buf, sharding)
+def _shard_mutable_array(xs, shardings):
+  return shard_args(shardings, [x._buf for x in xs])
 shard_arg_handlers[core.MutableArray] = _shard_mutable_array

 def batched_device_put(aval: core.ShapedArray,
@@ -3151,7 +3177,7 @@ class MeshExecutable(stages.XlaExecutable):

     return xc._xla.pjit(
         self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-        tree_util.dispatch_registry, shard_arg)
+        tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0])


 def check_arg_avals_for_call(ref_avals, arg_avals,
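Note: the new shard_args is a group-by-type dispatcher. Arguments are bucketed
by Python type so that each registered handler receives exactly one batched
call, and the saved indices restore the original argument order afterwards.
Stripped of dtype canonicalization and profiling, the core pattern looks
roughly like this (a sketch assuming a handlers dict keyed by type):

    import collections

    def dispatch_batched(handlers, args):
      """Calls each type's batched handler once, preserving argument order."""
      batches = collections.defaultdict(lambda: ([], []))  # type -> (indices, args)
      for i, arg in enumerate(args):
        indices, grouped = batches[type(arg)]
        indices.append(i)
        grouped.append(arg)
      results = [None] * len(args)
      for t, (indices, grouped) in batches.items():
        for i, out in zip(indices, handlers[t](grouped)):
          results[i] = out
      return results

For example, dispatch_batched({int: lambda xs: [x + 1 for x in xs], str: list},
[1, "a", 2]) makes one call per type and returns [2, "a", 3].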
" "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(x, sharding): - devices = sharding._addressable_device_assignment - if x.dtype == dtypes.float0: - x = np.zeros(x.shape, dtype=np.dtype(bool)) - aval = api_util.shaped_abstractify(x) - if sharding.is_fully_replicated: - shards = [x] * len(devices) - else: - indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) - shards = [x[i] for i in indices] - return batched_device_put(aval, sharding, shards, devices) +def _shard_array(xs, shardings): + results = [] + for x, sharding in safe_zip(xs, shardings): + devices = sharding._addressable_device_assignment + if x.dtype == dtypes.float0: + x = np.zeros(x.shape, dtype=np.dtype(bool)) + aval = api_util.shaped_abstractify(x) + if sharding.is_fully_replicated: + shards = [x] * len(devices) + else: + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + shards = [x[i] for i in indices] + results.append(batched_device_put(aval, sharding, shards, devices)) + return results for _t in array_types: shard_arg_handlers[_t] = _shard_array -def _shard_darray(x, sharding): - return shard_arg(x._data, sharding) +def _shard_darray(xs, shardings): + return shard_args(shardings, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(x, sharding): - return shard_arg(x._buf, sharding) +def _shard_mutable_array(xs, shardings): + return shard_args(shardings, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, @@ -3151,7 +3177,7 @@ class MeshExecutable(stages.XlaExecutable): return xc._xla.pjit( self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, shard_arg) + tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0]) def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fbf7b015b..9e5b54ce7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -343,7 +343,7 @@ def _cpp_pjit(jit_info: PjitInfo): fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, jit_info.donate_argnums, tree_util.dispatch_registry, - pxla.shard_arg, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], _get_cpp_global_cache(jit_info.has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) @@ -1636,7 +1636,7 @@ def _pjit_call_impl(*args, jaxpr, return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], donated_argnums, tree_util.dispatch_registry, - pxla.shard_arg, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 4d4a045a3..bcbbe1790 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -467,10 +467,11 @@ xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(x: PRNGKeyArray, sharding): - arr = x._base_array - phys_sharding = physical_sharding(x.aval, sharding) - return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings): + arrs = [x._base_array for x in xs] + phys_shardings = [physical_sharding(x.aval, sharding) + for x, sharding in zip(xs, shardings)] + return pxla.shard_args(phys_shardings, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = 
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 7e20e5911..7faa4c698 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -50,6 +50,7 @@ from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
 from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
@@ -244,18 +245,32 @@ def count_primitive_compiles():

 @contextmanager
 def count_device_put_fast_path_hit():
-  original_fn = xc.copy_array_to_devices_with_sharding
-  count = [0]
+  if xla_extension_version < 271:
+    original_fn = xc.copy_array_to_devices_with_sharding
+    count = [0]

-  def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
-    count[0] += 1
-    return original_fn(*args, **kwargs)
+    def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
+      count[0] += 1
+      return original_fn(*args, **kwargs)

-  xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
-  try:
-    yield count
-  finally:
-    xc.copy_array_to_devices_with_sharding = original_fn
+    xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
+    try:
+      yield count
+    finally:
+      xc.copy_array_to_devices_with_sharding = original_fn
+  else:
+    original_fn = xc.batched_copy_array_to_devices_with_sharding
+    count = [0]
+
+    def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
+      count[0] += 1
+      return original_fn(*args, **kwargs)
+
+    xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
+    try:
+      yield count
+    finally:
+      xc.batched_copy_array_to_devices_with_sharding = original_fn


 @contextmanager
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index a515f2293..c5aa31a53 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -23,7 +23,6 @@ from jax._src.interpreters.pxla import (
     global_avals_to_results_handler as global_avals_to_results_handler,
     global_result_handlers as global_result_handlers,
     parallel_callable as parallel_callable,
-    shard_arg as shard_arg,
     shard_args as shard_args,
     xla_pmap_p as xla_pmap_p,
 )
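Note: count_device_put_fast_path_hit now patches whichever copy entry point
the installed jaxlib exposes, selected by xla_extension_version. Both branches
are instances of the same count-the-calls monkeypatch, which could be factored
into a generic helper along these lines (a sketch; count_calls is hypothetical
and not part of test_util):

    import contextlib

    @contextlib.contextmanager
    def count_calls(module, name):
      """Temporarily wraps module.<name> to count how many times it is called."""
      original_fn = getattr(module, name)
      count = [0]

      def wrapper(*args, **kwargs):
        count[0] += 1
        return original_fn(*args, **kwargs)

      setattr(module, name, wrapper)
      try:
        yield count
      finally:
        setattr(module, name, original_fn)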
diff --git a/tests/lax_test.py b/tests/lax_test.py
index b69d623e7..ce1a2d4ff 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -45,7 +45,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
-from jax._src.util import NumpyComplexWarning
+from jax._src.util import NumpyComplexWarning, safe_zip
 from jax._src.tree_util import tree_map

 config.parse_flags_with_absl()
@@ -3394,11 +3394,14 @@ class FooArray:
   size = property(lambda self: self.data.size // 2)
   ndim = property(lambda self: self.data.ndim - 1)

-def shard_foo_array_handler(x, sharding):
-  device, = sharding._addressable_device_assignment
-  aval = core.raise_to_shaped(core.get_aval(x.data))
-  return pxla.batched_device_put(
-      aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
+def shard_foo_array_handler(xs, shardings):
+  results = []
+  for x, sharding in safe_zip(xs, shardings):
+    device, = sharding._addressable_device_assignment
+    aval = core.raise_to_shaped(core.get_aval(x.data))
+    results.append(pxla.batched_device_put(
+        aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]))
+  return results

 def foo_array_constant_handler(x):
   return array._array_mlir_constant_handler(x.data)
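Note: the shard_foo_array_handler change above doubles as a migration template
for any downstream code that registers its own shard_arg_handlers entry.
Handlers now receive parallel sequences of values and shardings and must
return a list of arrays in the same order. A hedged before/after sketch for a
hypothetical MyType, with put standing in for the actual device-put logic:

    # Before this change: one (value, sharding) pair per call.
    def _my_shard_arg(x, sharding):
      return put(x.data, sharding)

    # After this change: parallel sequences in, a list of arrays out.
    def _my_shard_arg_batched(xs, shardings):
      return [put(x.data, s) for x, s in zip(xs, shardings)]

    pxla.shard_arg_handlers[MyType] = _my_shard_arg_batched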