From b97fb56e95811fa8c00d114ac1b9955019c83f0a Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 14 Mar 2023 10:19:03 -0700
Subject: [PATCH] If the bufs are on the same devices passed to
 batched_device_put, then create an Array directly rather than going via
 xc.batched_device_put.

Fixing the transfer guard problem should help in removing this workaround
too.

PiperOrigin-RevId: 516561791
---
 jax/_src/api.py               | 15 ++-------------
 jax/_src/dispatch.py          | 14 ++------------
 jax/_src/interpreters/pxla.py | 15 ++++++++++++++-
 tests/lax_test.py             |  5 +++--
 4 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index 44e1899c2..d1e9e6cb4 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -3105,21 +3105,10 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  #
                        f"consistent shape and dtype, but got {a1} and {a2}.")
     stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
     if config.jax_array:
-      sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
-      bufs = [x for x, d in safe_zip(xs, devices)
-              if (isinstance(x, array.ArrayImpl) and
-                  dispatch.is_single_device_sharding(x.sharding) and
-                  x.device() == d)]
-      if len(bufs) == len(xs):
-        return array.ArrayImpl(
-            stacked_aval,
-            PmapSharding(np.array(devices), sharding_spec),
-            bufs, committed=True, _skip_checks=True)
-
       xs = [xla.canonicalize_dtype(arg) for arg in xs]
+      sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
       return pxla.batched_device_put(
-          stacked_aval,
-          PmapSharding(np.array(devices), sharding_spec),
+          stacked_aval, PmapSharding(np.array(devices), sharding_spec),
           xs, list(devices))
     else:
       buffers = [buf for x, d in zip(xs, devices)
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index b52c88c1f..5fca4878a 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -1352,17 +1352,6 @@ def _copy_device_array_to_device(
   return device_array.make_device_array(x.aval, device, moved_buf)
 
 
-def _copy_array_to_device(x: jax.Array, aval: core.AbstractValue,
-                          device: xc.Device) -> array.ArrayImpl:
-  """Copies `Array`s with SingleDeviceSharding to a different device."""
-  if xb.get_device_backend(device).platform == x.device().platform:
-    # source and target platforms are the same
-    if x.device() == device:
-      return array._single_device_array_from_buf(x, True)
-  return pxla.batched_device_put(  # type: ignore
-      aval, SingleDeviceSharding(device), [x], [device])
-
-
 # TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
 # to check if shardings are compatible with the input.
 def _check_sharding(aval, s):
@@ -1417,7 +1406,8 @@ def _device_put_impl(
     if device is None:
       return x
     elif is_single_device_sharding(x.sharding):
-      return _copy_array_to_device(x, aval, device)
+      return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
+                                     [device])
 
   if device_array.type_is_device_array(x):
     return _copy_device_array_to_device(x, device)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index d567902ff..5bf1fc2b3 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -461,7 +461,20 @@ for t in device_array.device_array_types:
   shard_arg_handlers[t] = shard_device_array
 
 
-batched_device_put = xc.batched_device_put  # pytype: disable=module-attr
+def batched_device_put(aval: core.AbstractValue,
+                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
+                       devices: Sequence[jax.Device], committed: bool = True):
+  from jax._src import array
+
+  bufs = [x for x, d in safe_zip(xs, devices)
+          if (isinstance(x, array.ArrayImpl) and
+              dispatch.is_single_device_sharding(x.sharding) and
+              x.device() == d)]
+  if len(bufs) == len(xs):
+    return array.ArrayImpl(
+        aval, sharding, bufs, committed=committed, _skip_checks=True)
+  return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
+
 
 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
 # from the input ShardingSpec, rather than the indices. However, this would
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 6cdf42776..c9e79fbe3 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -38,11 +38,11 @@ import jax.util
 from jax.interpreters import xla
 from jax._src.interpreters import mlir
 from jax.interpreters import batching
-from jax.interpreters import pxla
 from jax._src import array
 from jax._src.lib.mlir.dialects import hlo
 from jax._src import dispatch
 from jax._src import dtypes
+from jax._src.interpreters import pxla
 from jax._src import test_util as jtu
 from jax._src import lax_reference
 from jax._src.lax import lax as lax_internal
@@ -2985,7 +2985,8 @@ def shard_foo_array_handler(x, devices, indices, sharding):
   device, = devices
   if config.jax_array:
     aval = core.raise_to_shaped(core.get_aval(x.data))
-    return dispatch._copy_array_to_device(x.data, aval, device)
+    return pxla.batched_device_put(
+        aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
   bufs = dispatch._device_put_array(x.data, device)
   return bufs
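
A sketch (not part of the patch) of the behavior the new pxla.batched_device_put
enables: when every input array already lives on its destination device, the
function wraps the existing buffers in an ArrayImpl directly instead of calling
xc.batched_device_put, so no transfer is issued and the transfer guard is not
tripped. Minimal example using only public APIs; it assumes a JAX build at this
revision with config.jax_array enabled and at least one device available:

  import jax
  import jax.numpy as jnp

  dev = jax.devices()[0]
  x = jax.device_put(jnp.arange(4), dev)  # commits x to dev

  # Re-putting x onto the same device routes through
  # pxla.batched_device_put, which now takes the fast path: the existing
  # buffer is rewrapped in an ArrayImpl rather than copied.
  y = jax.device_put(x, dev)
  assert y.device() == dev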