Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding

This CL changes `shard_arg_handlers` to be batched: each handler now receives a list of objects and a list of shardings and returns a list of arrays. This makes it possible to batch backend calls whenever it's beneficial to do so.
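
For orientation, a minimal sketch of the new handler contract follows. The handler below is hypothetical (it is not one added by this CL) and just loops; the point of the change is that a real handler may instead issue one bulk backend call over all of `xs`:

```python
from collections.abc import Sequence

import jax
import numpy as np

# Old contract:  handler(arg, sharding)   -> jax.Array        (one call per argument)
# New contract:  handler(args, shardings) -> list[jax.Array]  (one call per argument *type*)
def _numpy_batched_handler(
    xs: Sequence[np.ndarray],
    shardings: Sequence[jax.sharding.Sharding],
) -> list[jax.Array]:
  # Illustrative only: a real handler can replace this loop with a single
  # batched transfer covering every element of `xs`.
  return [jax.device_put(x, s) for x, s in zip(xs, shardings)]
```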

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copies cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance when there are many arrays.
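
Illustratively, the grouping works as in the Python sketch below (hypothetical; the actual grouping in `PyArray::BatchedCopyToDeviceWithSharding()` lives in C++). Each transfer is keyed on its source devices, destination devices, and destination memory kind, so that every group can be served by one `Client::CopyArrays()` call:

```python
from collections import defaultdict

def _group_for_copy(arrays, dst_shardings):
  # Key: (source devices, destination devices, destination memory kind).
  # Arrays sharing a key can be copied with a single CopyArrays() call.
  groups = defaultdict(list)
  for i, (arr, dst) in enumerate(zip(arrays, dst_shardings)):
    key = (tuple(arr.sharding._addressable_device_assignment),
           tuple(dst._addressable_device_assignment),
           dst.memory_kind)
    # Keep the original index so results can be scattered back in order.
    groups[key].append(i)
  return groups
```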

PiperOrigin-RevId: 643097852
commit 5046cedbfc (parent 023bc7856b)
Author: Junwhan Ahn, committed by jax authors
Date: 2024-06-13 13:09:35 -07:00
10 changed files with 147 additions and 75 deletions


@@ -1807,7 +1807,8 @@ def _cpp_pmap(
   cpp_mapped_f = pmap_lib.pmap(
       fun, cache_miss, static_broadcasted_tuple,
-      pxla.shard_arg, pytree_registry=tree_util.default_registry)
+      lambda x, s: pxla.shard_args([s], [x])[0],
+      pytree_registry=tree_util.default_registry)
   _pmap_cache_clears.add(cpp_mapped_f)
   pmap_f = wraps(fun)(cpp_mapped_f)


@ -40,6 +40,7 @@ from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension as xe
from jax._src.lib import xla_extension_version
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding,
@@ -1068,7 +1069,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   if not candidates_list:
     # This array isn't sharded correctly. Reshard it via host roundtrip.
     # TODO(skye): more efficient reshard?
-    return pxla.shard_arg(x._value, sharding, canonicalize=False)
+    return pxla.shard_args([sharding], [x._value], canonicalize=False)[0]
   # Try to find a candidate buffer already on the correct device,
   # otherwise copy one of them.
   for buf in candidates_list:
@@ -1088,25 +1089,50 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   return dst_indices, tuple(src_indices) == tuple(dst_indices)

-def _array_shard_arg(x, sharding):
-  x._check_if_deleted()
-  indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
-  if not x.is_fully_addressable:
-    if same_indices:
-      return x
-    else:
-      raise NotImplementedError(
-          "Cannot reshard an input that is not fully addressable")
-  else:
-    devices = sharding._addressable_device_assignment
-    if same_indices:
-      return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding)
-    # Resharding starts here:
-    if dispatch.is_single_device_sharding(x.sharding):
-      return shard_device_array(x, devices, indices, sharding)
-    else:
-      return shard_sharded_device_array_slow_path(x, devices, indices, sharding)
+def _array_shard_arg(xs, shardings):
+  results = []
+  batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
+  for i, (x, sharding) in enumerate(safe_zip(xs, shardings)):
+    x._check_if_deleted()
+    indices, same_indices = _sharding_indices_and_eq(
+        x.sharding, x.shape, sharding)
+    if not x.is_fully_addressable:
+      if same_indices:
+        results.append(x)
+      else:
+        raise NotImplementedError(
+            "Cannot reshard an input that is not fully addressable")
+    else:
+      devices = sharding._addressable_device_assignment
+      if same_indices:
+        # Add a placeholder result that will be filled in later.
+        results.append(None)
+        # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`.
+        batch_xs.append(x)
+        batch_devs.append(list(devices))
+        batch_shardings.append(sharding)
+        batch_indices.append(i)
+      # Resharding starts here:
+      elif dispatch.is_single_device_sharding(x.sharding):
+        results.append(shard_device_array(x, devices, indices, sharding))
+      else:
+        results.append(
+            shard_sharded_device_array_slow_path(x, devices, indices, sharding))
+
+  if xla_extension_version < 271:
+    copy_outs = [
+        xc.copy_array_to_devices_with_sharding(x, d, s)  # pytype: disable=module-attr
+        for x, d, s in safe_zip(batch_xs, batch_devs, batch_shardings)
+    ]
+  else:
+    copy_outs = xc.batched_copy_array_to_devices_with_sharding(
+        batch_xs, batch_devs, batch_shardings)
+  for i, copy_out in safe_zip(batch_indices, copy_outs):
+    assert results[i] is None
+    results[i] = copy_out
+  return results

 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
@@ -1139,8 +1165,8 @@ pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
 # Token handlers

-def _token_shard_arg(x, sharding):
-  return _array_shard_arg(x._buf, sharding)
+def _token_shard_arg(xs, shardings):
+  return _array_shard_arg([x._buf for x in xs], shardings)

 pxla.shard_arg_handlers[core.Token] = _token_shard_arg


@@ -324,7 +324,7 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
 def _put_x(x, s: Sharding, aval: core.AbstractValue, committed: bool):
   result_handler = pxla.global_aval_to_result_handler(aval, s, committed)
-  return result_handler(pxla.shard_arg(x, s))
+  return result_handler(pxla.shard_args([s], [x])[0])

 def _override_get_device_assignment(sharding, *args, **kwargs):
   da = sharding._device_assignment

@@ -98,10 +98,11 @@ class EArray(basearray.Array):
 # TODO(mattjj): _set_array_base_attributes

-def _earray_shard_arg_handler(x, sharding):
-  arr = x._data
-  phys_sharding = sharding_impls.physical_sharding(x.aval, sharding)
-  return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
+def _earray_shard_arg_handler(xs, shardings):
+  arrs = [x._data for x in xs]
+  phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding)
+                    for x, sharding in zip(xs, shardings)]
+  return pxla.shard_args(phys_shardings, arrs)
 pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

 api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval


@ -17,6 +17,7 @@ from __future__ import annotations
import enum
from contextlib import contextmanager
import collections
from collections import namedtuple
from collections.abc import Sequence, Iterable
import dataclasses
@@ -108,18 +109,40 @@ ShardingSpec = sharding_specs.ShardingSpec
 def identity(x): return x

-def shard_arg(arg, sharding, canonicalize=True):
-  if canonicalize:
-    arg = xla.canonicalize_dtype(arg)
-  return shard_arg_handlers[type(arg)](arg, sharding)
-
 @profiler.annotate_function
-def shard_args(shardings: Sequence[JSharding], args
-               ) -> Sequence[jax.Array]:
-  return [shard_arg(arg, shardings[i]) for i, arg in enumerate(args)]
+def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]:
+  # Fast path for one argument.
+  if len(args) == 1:
+    arg = args[0]
+    if canonicalize:
+      arg = xla.canonicalize_dtype(arg)
+    return shard_arg_handlers[type(arg)]([arg], shardings)
+
+  # type(arg) -> (indices, args, shardings)
+  batches = collections.defaultdict(lambda: ([], [], []))  # type: ignore
+  for i, (arg, sharding) in enumerate(safe_zip(args, shardings)):
+    if canonicalize:
+      arg = xla.canonicalize_dtype(arg)
+    batch = batches[type(arg)]
+    batch[0].append(i)
+    batch[1].append(arg)
+    batch[2].append(sharding)
+
+  # Call `shard_arg_handlers` per batch and build a flat list of arrays returned
+  # from each call in the same order as `args`. Since `batches` is grouped by
+  # types, we cannot simply flatten the results and we have to use the original
+  # indices to put each array back to its original position.
+  results: list[jax.Array | None] = [None] * len(args)
+  for t, (indices, a, s) in batches.items():
+    outs = shard_arg_handlers[t](a, s)
+    for i, out in safe_zip(indices, outs):
+      results[i] = out
+  assert all(result is not None for result in results)
+  return results

-shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
+shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {}

 @lru_cache(maxsize=1024)
@@ -127,31 +150,34 @@ def _get_replicated_slices(num_addressable_devices: int):
   return ((slice(None),),) * num_addressable_devices

-def _masked_array_error(x, sharding):
+def _masked_array_error(xs, shardings):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error

-def _shard_array(x, sharding):
-  devices = sharding._addressable_device_assignment
-  if x.dtype == dtypes.float0:
-    x = np.zeros(x.shape, dtype=np.dtype(bool))
-  aval = api_util.shaped_abstractify(x)
-  if sharding.is_fully_replicated:
-    shards = [x] * len(devices)
-  else:
-    indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
-    shards = [x[i] for i in indices]
-  return batched_device_put(aval, sharding, shards, devices)
+def _shard_array(xs, shardings):
+  results = []
+  for x, sharding in safe_zip(xs, shardings):
+    devices = sharding._addressable_device_assignment
+    if x.dtype == dtypes.float0:
+      x = np.zeros(x.shape, dtype=np.dtype(bool))
+    aval = api_util.shaped_abstractify(x)
+    if sharding.is_fully_replicated:
+      shards = [x] * len(devices)
+    else:
+      indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
+      shards = [x[i] for i in indices]
+    results.append(batched_device_put(aval, sharding, shards, devices))
+  return results
 for _t in array_types:
   shard_arg_handlers[_t] = _shard_array

-def _shard_darray(x, sharding):
-  return shard_arg(x._data, sharding)
+def _shard_darray(xs, shardings):
+  return shard_args(shardings, [x._data for x in xs])
 shard_arg_handlers[core.DArray] = _shard_darray

-def _shard_mutable_array(x, sharding):
-  return shard_arg(x._buf, sharding)
+def _shard_mutable_array(xs, shardings):
+  return shard_args(shardings, [x._buf for x in xs])
 shard_arg_handlers[core.MutableArray] = _shard_mutable_array

 def batched_device_put(aval: core.ShapedArray,
@@ -3151,7 +3177,7 @@ class MeshExecutable(stages.XlaExecutable):
     return xc._xla.pjit(
         self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-        tree_util.dispatch_registry, shard_arg)
+        tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0])

 def check_arg_avals_for_call(ref_avals, arg_avals,


@@ -343,7 +343,7 @@ def _cpp_pjit(jit_info: PjitInfo):
       fun_name(fun),
       fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
       jit_info.donate_argnums, tree_util.dispatch_registry,
-      pxla.shard_arg,
+      lambda x, sharding: pxla.shard_args([sharding], [x])[0],
       _get_cpp_global_cache(jit_info.has_explicit_sharding))

   cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
@@ -1636,7 +1636,7 @@ def _pjit_call_impl(*args, jaxpr,
   return xc._xla.pjit(
       name, f, call_impl_cache_miss, [], [], donated_argnums,
       tree_util.dispatch_registry,
-      pxla.shard_arg,
+      lambda x, sharding: pxla.shard_args([sharding], [x])[0],
       _get_cpp_global_cache(has_explicit_sharding))(*args)

 pjit_p.def_impl(_pjit_call_impl)


@@ -467,10 +467,11 @@ xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

-def key_array_shard_arg_handler(x: PRNGKeyArray, sharding):
-  arr = x._base_array
-  phys_sharding = physical_sharding(x.aval, sharding)
-  return pxla.shard_arg_handlers[type(arr)](arr, phys_sharding)
+def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings):
+  arrs = [x._base_array for x in xs]
+  phys_shardings = [physical_sharding(x.aval, sharding)
+                    for x, sharding in zip(xs, shardings)]
+  return pxla.shard_args(phys_shardings, arrs)
+
 pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler


@@ -50,6 +50,7 @@ from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
 from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
@@ -244,18 +245,32 @@ def count_primitive_compiles():
 @contextmanager
 def count_device_put_fast_path_hit():
-  original_fn = xc.copy_array_to_devices_with_sharding
-  count = [0]
+  if xla_extension_version < 271:
+    original_fn = xc.copy_array_to_devices_with_sharding
+    count = [0]

-  def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
-    count[0] += 1
-    return original_fn(*args, **kwargs)
+    def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
+      count[0] += 1
+      return original_fn(*args, **kwargs)

-  xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
-  try:
-    yield count
-  finally:
-    xc.copy_array_to_devices_with_sharding = original_fn
+    xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
+    try:
+      yield count
+    finally:
+      xc.copy_array_to_devices_with_sharding = original_fn
+  else:
+    original_fn = xc.batched_copy_array_to_devices_with_sharding
+    count = [0]
+
+    def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
+      count[0] += 1
+      return original_fn(*args, **kwargs)
+
+    xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
+    try:
+      yield count
+    finally:
+      xc.batched_copy_array_to_devices_with_sharding = original_fn

 @contextmanager


@@ -23,7 +23,6 @@ from jax._src.interpreters.pxla import (
     global_avals_to_results_handler as global_avals_to_results_handler,
     global_result_handlers as global_result_handlers,
     parallel_callable as parallel_callable,
-    shard_arg as shard_arg,
     shard_args as shard_args,
     xla_pmap_p as xla_pmap_p,
 )


@@ -45,7 +45,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
-from jax._src.util import NumpyComplexWarning
+from jax._src.util import NumpyComplexWarning, safe_zip
 from jax._src.tree_util import tree_map

 config.parse_flags_with_absl()
@@ -3394,11 +3394,14 @@ class FooArray:
   size = property(lambda self: self.data.size // 2)
   ndim = property(lambda self: self.data.ndim - 1)

-def shard_foo_array_handler(x, sharding):
-  device, = sharding._addressable_device_assignment
-  aval = core.raise_to_shaped(core.get_aval(x.data))
-  return pxla.batched_device_put(
-      aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
+def shard_foo_array_handler(xs, shardings):
+  results = []
+  for x, sharding in safe_zip(xs, shardings):
+    device, = sharding._addressable_device_assignment
+    aval = core.raise_to_shaped(core.get_aval(x.data))
+    results.append(pxla.batched_device_put(
+        aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device]))
+  return results

 def foo_array_constant_handler(x):
   return array._array_mlir_constant_handler(x.data)