If the bufs are already on the devices passed to batched_device_put, create an Array directly rather than going via xc.batched_device_put. Fixing the transfer-guard problem should help remove this workaround too.

PiperOrigin-RevId: 516561791
Yash Katariya 2023-03-14 10:19:03 -07:00 committed by jax authors
parent 8c7ba99f82
commit b97fb56e95
4 changed files with 21 additions and 28 deletions
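A hedged usage sketch (not part of the diff) of the fast path this change adds: when every shard is already committed to its target device, device_put_sharded can assemble the output Array from the existing buffers instead of going through xc.batched_device_put. The shard values and device count below are illustrative.

import jax
import jax.numpy as jnp

devices = jax.local_devices()
# Commit each shard to its target device up front.
shards = [jax.device_put(jnp.full((4,), i), d) for i, d in enumerate(devices)]
# Every shard already lives on its matching device, so the fast path can build
# the stacked Array from the existing buffers; no transfer (and hence no
# transfer-guard violation) is involved.
out = jax.device_put_sharded(shards, devices)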


@@ -3105,21 +3105,10 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
                      f"consistent shape and dtype, but got {a1} and {a2}.")
   stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
   if config.jax_array:
-    sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
-    bufs = [x for x, d in safe_zip(xs, devices)
-            if (isinstance(x, array.ArrayImpl) and
-                dispatch.is_single_device_sharding(x.sharding) and
-                x.device() == d)]
-    if len(bufs) == len(xs):
-      return array.ArrayImpl(
-          stacked_aval,
-          PmapSharding(np.array(devices), sharding_spec),
-          bufs, committed=True, _skip_checks=True)
     xs = [xla.canonicalize_dtype(arg) for arg in xs]
+    sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
     return pxla.batched_device_put(
-        stacked_aval,
-        PmapSharding(np.array(devices), sharding_spec),
+        stacked_aval, PmapSharding(np.array(devices), sharding_spec),
         xs, list(devices))
   else:
     buffers = [buf for x, d in zip(xs, devices)


@@ -1352,17 +1352,6 @@ def _copy_device_array_to_device(
     return device_array.make_device_array(x.aval, device, moved_buf)
 
-def _copy_array_to_device(x: jax.Array, aval: core.AbstractValue,
-                          device: xc.Device) -> array.ArrayImpl:
-  """Copies `Array`s with SingleDeviceSharding to a different device."""
-  if xb.get_device_backend(device).platform == x.device().platform:
-    # source and target platforms are the same
-    if x.device() == device:
-      return array._single_device_array_from_buf(x, True)
-  return pxla.batched_device_put(  # type: ignore
-      aval, SingleDeviceSharding(device), [x], [device])
-
 # TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
 # to check if shardings are compatible with the input.
 def _check_sharding(aval, s):
@@ -1417,7 +1406,8 @@ def _device_put_impl(
     if device is None:
       return x
     elif is_single_device_sharding(x.sharding):
-      return _copy_array_to_device(x, aval, device)
+      return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
+                                     [device])
 
   if device_array.type_is_device_array(x):
     return _copy_device_array_to_device(x, device)
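A hedged sketch (not part of the diff) of the path the dispatch.py change above now takes: placing a committed single-device Array onto the device it already occupies goes through pxla.batched_device_put, which returns an Array built from the existing buffer rather than copying.

import jax
import jax.numpy as jnp

d = jax.devices()[0]
x = jax.device_put(jnp.arange(8), d)  # a committed, single-device Array
y = jax.device_put(x, d)              # same target device: no copy is made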


@@ -461,7 +461,20 @@ for t in device_array.device_array_types:
   shard_arg_handlers[t] = shard_device_array
 
-batched_device_put = xc.batched_device_put  # pytype: disable=module-attr
+
+def batched_device_put(aval: core.AbstractValue,
+                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
+                       devices: Sequence[jax.Device], committed: bool = True):
+  from jax._src import array
+
+  bufs = [x for x, d in safe_zip(xs, devices)
+          if (isinstance(x, array.ArrayImpl) and
+              dispatch.is_single_device_sharding(x.sharding) and
+              x.device() == d)]
+  if len(bufs) == len(xs):
+    return array.ArrayImpl(
+        aval, sharding, bufs, committed=committed, _skip_checks=True)
+  return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
+
 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
 # from the input ShardingSpec, rather than the indices. However, this would
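A hedged sketch (not part of the diff) of calling the new wrapper above directly. batched_device_put is an internal API (the import path matches the file this hunk edits and may change between releases), and the aval construction below is an assumption made for illustration.

import jax
import jax.numpy as jnp
from jax._src import core
from jax._src.interpreters import pxla
from jax.sharding import SingleDeviceSharding

d = jax.devices()[0]
x = jax.device_put(jnp.arange(4), d)       # ArrayImpl committed to d
aval = core.ShapedArray(x.shape, x.dtype)  # assumed aval construction
# The input already lives on its target device, so the wrapper assembles an
# ArrayImpl from the existing buffer instead of calling xc.batched_device_put.
out = pxla.batched_device_put(aval, SingleDeviceSharding(d), [x], [d])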


@@ -38,11 +38,11 @@ import jax.util
 from jax.interpreters import xla
 from jax._src.interpreters import mlir
 from jax.interpreters import batching
-from jax.interpreters import pxla
 from jax._src import array
 from jax._src.lib.mlir.dialects import hlo
 from jax._src import dispatch
 from jax._src import dtypes
+from jax._src.interpreters import pxla
 from jax._src import test_util as jtu
 from jax._src import lax_reference
 from jax._src.lax import lax as lax_internal
@@ -2985,7 +2985,8 @@ def shard_foo_array_handler(x, devices, indices, sharding):
   device, = devices
   if config.jax_array:
     aval = core.raise_to_shaped(core.get_aval(x.data))
-    return dispatch._copy_array_to_device(x.data, aval, device)
+    return pxla.batched_device_put(
+        aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
   bufs = dispatch._device_put_array(x.data, device)
   return bufs
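Finally, a hedged sketch (not part of the diff) of the wrapper's fallback branch, which relates to the transfer-guard caveat in the commit message: when an input is not already on its target device, xc.batched_device_put is still used and a real transfer occurs. Assumes at least two local devices.

import jax
import jax.numpy as jnp

devices = jax.local_devices()
if len(devices) >= 2:
  x = jax.device_put(jnp.arange(4), devices[0])  # committed to device 0
  # Cross-device placement cannot reuse the existing buffer, so it takes the
  # xc.batched_device_put fallback and is still subject to transfer guards.
  y = jax.device_put(x, devices[1])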