From 5046cedbfc2a4263739065982e5f4966cad5fdc0 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Thu, 13 Jun 2024 13:09:35 -0700 Subject: [PATCH] Make `pxla.shard_arg` batch calls to `xc.copy_array_to_devices_with_sharding` This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so. Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays. PiperOrigin-RevId: 643097852 --- jax/_src/api.py | 3 +- jax/_src/array.py | 64 +++++++++++++++++++--------- jax/_src/dispatch.py | 2 +- jax/_src/earray.py | 9 ++-- jax/_src/interpreters/pxla.py | 80 +++++++++++++++++++++++------------ jax/_src/pjit.py | 4 +- jax/_src/prng.py | 9 ++-- jax/_src/test_util.py | 35 ++++++++++----- jax/interpreters/pxla.py | 1 - tests/lax_test.py | 15 ++++--- 10 files changed, 147 insertions(+), 75 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index e1498b65e..161390688 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1807,7 +1807,8 @@ def _cpp_pmap( cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - pxla.shard_arg, pytree_registry=tree_util.default_registry) + lambda x, s: pxla.shard_args([s], [x])[0], + pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f) diff --git a/jax/_src/array.py b/jax/_src/array.py index 697d9b994..514c10f7f 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -40,6 +40,7 @@ from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension as xe +from jax._src.lib import xla_extension_version from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, @@ -1068,7 +1069,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): if not candidates_list: # This array isn't sharded correctly. Reshard it via host roundtrip. # TODO(skye): more efficient reshard? - return pxla.shard_arg(x._value, sharding, canonicalize=False) + return pxla.shard_args([sharding], [x._value], canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. for buf in candidates_list: @@ -1088,25 +1089,50 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): return dst_indices, tuple(src_indices) == tuple(dst_indices) -def _array_shard_arg(x, sharding): - x._check_if_deleted() +def _array_shard_arg(xs, shardings): + results = [] + batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] + for i, (x, sharding) in enumerate(safe_zip(xs, shardings)): + x._check_if_deleted() - indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) - if not x.is_fully_addressable: - if same_indices: - return x + indices, same_indices = _sharding_indices_and_eq( + x.sharding, x.shape, sharding) + if not x.is_fully_addressable: + if same_indices: + results.append(x) + else: + raise NotImplementedError( + "Cannot reshard an input that is not fully addressable") else: - raise NotImplementedError( - "Cannot reshard an input that is not fully addressable") + devices = sharding._addressable_device_assignment + if same_indices: + # Add a placeholder result that will be filled in later. + results.append(None) + # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`. + batch_xs.append(x) + batch_devs.append(list(devices)) + batch_shardings.append(sharding) + batch_indices.append(i) + # Resharding starts here: + elif dispatch.is_single_device_sharding(x.sharding): + results.append(shard_device_array(x, devices, indices, sharding)) + else: + results.append( + shard_sharded_device_array_slow_path(x, devices, indices, sharding)) + + if xla_extension_version < 271: + copy_outs = [ + xc.copy_array_to_devices_with_sharding(x, d, s) # pytype: disable=module-attr + for x, d, s in safe_zip(batch_xs, batch_devs, batch_shardings) + ] else: - devices = sharding._addressable_device_assignment - if same_indices: - return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding) - # Resharding starts here: - if dispatch.is_single_device_sharding(x.sharding): - return shard_device_array(x, devices, indices, sharding) - else: - return shard_sharded_device_array_slow_path(x, devices, indices, sharding) + copy_outs = xc.batched_copy_array_to_devices_with_sharding( + batch_xs, batch_devs, batch_shardings) + for i, copy_out in safe_zip(batch_indices, copy_outs): + assert results[i] is None + results[i] = copy_out + return results + pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg @@ -1139,8 +1165,8 @@ pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler # Token handlers -def _token_shard_arg(x, sharding): - return _array_shard_arg(x._buf, sharding) +def _token_shard_arg(xs, shardings): + return _array_shard_arg([x._buf for x in xs], shardings) pxla.shard_arg_handlers[core.Token] = _token_shard_arg diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2836fd9b2..9f2fae105 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -324,7 +324,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool): result_handler = pxla.global_aval_to_result_handler(aval, s, committed) - return result_handler(pxla.shard_arg(x, s)) + return result_handler(pxla.shard_args([s], [x])[0]) def _override_get_device_assignment(sharding, *args, **kwargs): da = sharding._device_assignment diff --git a/jax/_src/earray.py b/jax/_src/earray.py index fcf0e9c3c..f4b5e232b 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -98,10 +98,11 @@ class EArray(basearray.Array): # TODO(mattjj): _set_array_base_attributes -def _earray_shard_arg_handler(x, sharding): - arr = x._data - phys_sharding = sharding_impls.physical_sharding(x.aval, sharding) - return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) +def _earray_shard_arg_handler(xs, shardings): + arrs = [x._data for x in xs] + phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) + for x, sharding in zip(xs, shardings)] + return pxla.shard_args(phys_shardings, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7ddac0e6f..7e9d15c7c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -17,6 +17,7 @@ from __future__ import annotations import enum from contextlib import contextmanager +import collections from collections import namedtuple from collections.abc import Sequence, Iterable import dataclasses @@ -108,18 +109,40 @@ ShardingSpec = sharding_specs.ShardingSpec def identity(x): return x -def shard_arg(arg, sharding, canonicalize=True): - if canonicalize: - arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)](arg, sharding) - - @profiler.annotate_function -def shard_args(shardings: Sequence[JSharding], args - ) -> Sequence[jax.Array]: - return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)] +def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]: + # Fast path for one argument. + if len(args) == 1: + arg = args[0] + if canonicalize: + arg = xla.canonicalize_dtype(arg) + return shard_arg_handlers[type(arg)]([arg], shardings) -shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {} + # type(arg) -> (indices, args, shardings) + batches = collections.defaultdict(lambda: ([], [], [])) # type: ignore + for i, (arg, sharding) in enumerate(safe_zip(args, shardings)): + if canonicalize: + arg = xla.canonicalize_dtype(arg) + batch = batches[type(arg)] + batch[0].append(i) + batch[1].append(arg) + batch[2].append(sharding) + + # Call `shard_arg_handlers` per batch and build a flat list of arrays returned + # from each call in the same order as `args`. Since `batches` is grouped by + # types, we cannot simply flatten the results and we have to use the original + # indices to put each array back to its original position. + results: list[jax.Array | None] = [None] * len(args) + for t, (indices, a, s) in batches.items(): + outs = shard_arg_handlers[t](a, s) + for i, out in safe_zip(indices, outs): + results[i] = out + + assert all(result is not None for result in results) + return results + + +shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {} @lru_cache(maxsize=1024) @@ -127,31 +150,34 @@ def _get_replicated_slices(num_addressable_devices: int): return ((slice(None),),) * num_addressable_devices -def _masked_array_error(x, sharding): +def _masked_array_error(xs, shardings): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(x, sharding): - devices = sharding._addressable_device_assignment - if x.dtype == dtypes.float0: - x = np.zeros(x.shape, dtype=np.dtype(bool)) - aval = api_util.shaped_abstractify(x) - if sharding.is_fully_replicated: - shards = [x] * len(devices) - else: - indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) - shards = [x[i] for i in indices] - return batched_device_put(aval, sharding, shards, devices) +def _shard_array(xs, shardings): + results = [] + for x, sharding in safe_zip(xs, shardings): + devices = sharding._addressable_device_assignment + if x.dtype == dtypes.float0: + x = np.zeros(x.shape, dtype=np.dtype(bool)) + aval = api_util.shaped_abstractify(x) + if sharding.is_fully_replicated: + shards = [x] * len(devices) + else: + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + shards = [x[i] for i in indices] + results.append(batched_device_put(aval, sharding, shards, devices)) + return results for _t in array_types: shard_arg_handlers[_t] = _shard_array -def _shard_darray(x, sharding): - return shard_arg(x._data, sharding) +def _shard_darray(xs, shardings): + return shard_args(shardings, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(x, sharding): - return shard_arg(x._buf, sharding) +def _shard_mutable_array(xs, shardings): + return shard_args(shardings, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, @@ -3151,7 +3177,7 @@ class MeshExecutable(stages.XlaExecutable): return xc._xla.pjit( self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, shard_arg) + tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0]) def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fbf7b015b..9e5b54ce7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -343,7 +343,7 @@ def _cpp_pjit(jit_info: PjitInfo): fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, jit_info.donate_argnums, tree_util.dispatch_registry, - pxla.shard_arg, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], _get_cpp_global_cache(jit_info.has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) @@ -1636,7 +1636,7 @@ def _pjit_call_impl(*args, jaxpr, return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], donated_argnums, tree_util.dispatch_registry, - pxla.shard_arg, + lambda x, sharding: pxla.shard_args([sharding], [x])[0], _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 4d4a045a3..bcbbe1790 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -467,10 +467,11 @@ xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(x: PRNGKeyArray, sharding): - arr = x._base_array - phys_sharding = physical_sharding(x.aval, sharding) - return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding) +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings): + arrs = [x._base_array for x in xs] + phys_shardings = [physical_sharding(x.aval, sharding) + for x, sharding in zip(xs, shardings)] + return pxla.shard_args(phys_shardings, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 7e20e5911..7faa4c698 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -50,6 +50,7 @@ from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, @@ -244,18 +245,32 @@ def count_primitive_compiles(): @contextmanager def count_device_put_fast_path_hit(): - original_fn = xc.copy_array_to_devices_with_sharding - count = [0] + if xla_extension_version < 271: + original_fn = xc.copy_array_to_devices_with_sharding + count = [0] - def copy_array_to_devices_with_sharding_and_count(*args, **kwargs): - count[0] += 1 - return original_fn(*args, **kwargs) + def copy_array_to_devices_with_sharding_and_count(*args, **kwargs): + count[0] += 1 + return original_fn(*args, **kwargs) - xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count - try: - yield count - finally: - xc.copy_array_to_devices_with_sharding = original_fn + xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count + try: + yield count + finally: + xc.copy_array_to_devices_with_sharding = original_fn + else: + original_fn = xc.batched_copy_array_to_devices_with_sharding + count = [0] + + def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs): + count[0] += 1 + return original_fn(*args, **kwargs) + + xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count + try: + yield count + finally: + xc.batched_copy_array_to_devices_with_sharding = original_fn @contextmanager diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index a515f2293..c5aa31a53 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -23,7 +23,6 @@ from jax._src.interpreters.pxla import ( global_avals_to_results_handler as global_avals_to_results_handler, global_result_handlers as global_result_handlers, parallel_callable as parallel_callable, - shard_arg as shard_arg, shard_args as shard_args, xla_pmap_p as xla_pmap_p, ) diff --git a/tests/lax_test.py b/tests/lax_test.py index b69d623e7..ce1a2d4ff 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -45,7 +45,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning +from jax._src.util import NumpyComplexWarning, safe_zip from jax._src.tree_util import tree_map config.parse_flags_with_absl() @@ -3394,11 +3394,14 @@ class FooArray: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(x, sharding): - device, = sharding._addressable_device_assignment - aval = core.raise_to_shaped(core.get_aval(x.data)) - return pxla.batched_device_put( - aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]) +def shard_foo_array_handler(xs, shardings): + results = [] + for x, sharding in safe_zip(xs, shardings): + device, = sharding._addressable_device_assignment + aval = core.raise_to_shaped(core.get_aval(x.data)) + results.append(pxla.batched_device_put( + aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])) + return results def foo_array_constant_handler(x): return array._array_mlir_constant_handler(x.data)