mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
If the bufs are already on the devices passed to batched_device_put, create an Array directly rather than going via xc.batched_device_put. Fixing the transfer-guard problem should also help remove this workaround.
PiperOrigin-RevId: 516561791
This commit is contained in:
parent 8c7ba99f82
commit b97fb56e95
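
For orientation before the diff, a minimal standalone sketch of the residency check this change centralizes. The helper name can_skip_transfer is hypothetical; the isinstance/sharding/device conditions mirror the bufs filter in the new pxla.batched_device_put below.

# Hypothetical helper, for illustration only: a transfer can be skipped
# when every input is already a single-device jax.Array resident on its
# matching target device.
from jax._src import array, dispatch

def can_skip_transfer(xs, devices):
  return all(
      isinstance(x, array.ArrayImpl)
      and dispatch.is_single_device_sharding(x.sharding)
      and x.device() == d
      for x, d in zip(xs, devices))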
@@ -3105,21 +3105,10 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
                        f"consistent shape and dtype, but got {a1} and {a2}.")
     stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
     if config.jax_array:
-      sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
-      bufs = [x for x, d in safe_zip(xs, devices)
-              if (isinstance(x, array.ArrayImpl) and
-                  dispatch.is_single_device_sharding(x.sharding) and
-                  x.device() == d)]
-      if len(bufs) == len(xs):
-        return array.ArrayImpl(
-            stacked_aval,
-            PmapSharding(np.array(devices), sharding_spec),
-            bufs, committed=True, _skip_checks=True)
-
       xs = [xla.canonicalize_dtype(arg) for arg in xs]
+      sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
       return pxla.batched_device_put(
-          stacked_aval,
-          PmapSharding(np.array(devices), sharding_spec),
+          stacked_aval, PmapSharding(np.array(devices), sharding_spec),
           xs, list(devices))
     else:
       buffers = [buf for x, d in zip(xs, devices)
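
A usage sketch of the path above, assuming at least one local device and the jax_array config enabled (its default in later releases):

import jax
import jax.numpy as jnp

devices = jax.local_devices()
# Commit one shard to each device up front, so every buffer already
# lives where device_put_sharded wants it and the no-copy fast path
# in pxla.batched_device_put can apply.
shards = [jax.device_put(jnp.full((4,), i, jnp.float32), d)
          for i, d in enumerate(devices)]
stacked = jax.device_put_sharded(shards, devices)
assert stacked.shape == (len(devices), 4)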
@@ -1352,17 +1352,6 @@ def _copy_device_array_to_device(
   return device_array.make_device_array(x.aval, device, moved_buf)
 
 
-def _copy_array_to_device(x: jax.Array, aval: core.AbstractValue,
-                          device: xc.Device) -> array.ArrayImpl:
-  """Copies `Array`s with SingleDeviceSharding to a different device."""
-  if xb.get_device_backend(device).platform == x.device().platform:
-    # source and target platforms are the same
-    if x.device() == device:
-      return array._single_device_array_from_buf(x, True)
-  return pxla.batched_device_put(  # type: ignore
-      aval, SingleDeviceSharding(device), [x], [device])
-
-
 # TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
 # to check if shardings are compatible with the input.
 def _check_sharding(aval, s):
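
Callers of the deleted helper now go through the pxla wrapper directly. A hedged sketch of the equivalent call, using the internal module layout as of this commit (these paths have since moved):

import jax
import jax.numpy as jnp
from jax._src import core
from jax._src.interpreters import pxla

x = jnp.arange(8.0)
target = jax.devices()[-1]
aval = core.raise_to_shaped(core.get_aval(x))
# Rough equivalent of the removed dispatch._copy_array_to_device(x, aval, target):
y = pxla.batched_device_put(
    aval, jax.sharding.SingleDeviceSharding(target), [x], [target])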
@@ -1417,7 +1406,8 @@ def _device_put_impl(
     if device is None:
       return x
     elif is_single_device_sharding(x.sharding):
-      return _copy_array_to_device(x, aval, device)
+      return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
+                                     [device])
 
   if device_array.type_is_device_array(x):
     return _copy_device_array_to_device(x, device)
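
The observable effect at the public API level, sketched for a committed single-device input:

import jax
import jax.numpy as jnp

dev = jax.devices()[0]
x = jax.device_put(jnp.ones((2, 2)), dev)
# The second device_put targets the device x already lives on, so it
# should hit the new no-copy path inside pxla.batched_device_put.
y = jax.device_put(x, dev)
assert y.sharding.device_set == {dev}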
@@ -461,7 +461,20 @@ for t in device_array.device_array_types:
   shard_arg_handlers[t] = shard_device_array
 
 
-batched_device_put = xc.batched_device_put  # pytype: disable=module-attr
+def batched_device_put(aval: core.AbstractValue,
+                       sharding: jax.sharding.Sharding, xs: Sequence[Any],
+                       devices: Sequence[jax.Device], committed: bool = True):
+  from jax._src import array
+
+  bufs = [x for x, d in safe_zip(xs, devices)
+          if (isinstance(x, array.ArrayImpl) and
+              dispatch.is_single_device_sharding(x.sharding) and
+              x.device() == d)]
+  if len(bufs) == len(xs):
+    return array.ArrayImpl(
+        aval, sharding, bufs, committed=committed, _skip_checks=True)
+  return xc.batched_device_put(aval, sharding, xs, devices, committed)  # type: ignore
 
 
 # NOTE(skye): we could refactor to generate _multi_slice parameters directly
 # from the input ShardingSpec, rather than the indices. However, this would
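
A quick exercise of both branches of the wrapper above; a sketch against the internal APIs as of this commit, not a test from the change itself:

import jax
import jax.numpy as jnp
import numpy as np
from jax._src import core
from jax._src.interpreters import pxla

dev = jax.devices()[0]
sharding = jax.sharding.SingleDeviceSharding(dev)
committed = jax.device_put(jnp.arange(4, dtype=jnp.float32), dev)
aval = core.raise_to_shaped(core.get_aval(committed))

# Already resident on the target device: the ArrayImpl fast path fires.
fast = pxla.batched_device_put(aval, sharding, [committed], [dev])

# A host numpy input fails the residency filter, so the wrapper falls
# back to xc.batched_device_put, which performs the actual transfer.
slow = pxla.batched_device_put(
    aval, sharding, [np.arange(4, dtype=np.float32)], [dev])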
@@ -38,11 +38,11 @@ import jax.util
 from jax.interpreters import xla
 from jax._src.interpreters import mlir
 from jax.interpreters import batching
-from jax.interpreters import pxla
 from jax._src import array
 from jax._src.lib.mlir.dialects import hlo
 from jax._src import dispatch
 from jax._src import dtypes
+from jax._src.interpreters import pxla
 from jax._src import test_util as jtu
 from jax._src import lax_reference
 from jax._src.lax import lax as lax_internal
@@ -2985,7 +2985,8 @@ def shard_foo_array_handler(x, devices, indices, sharding):
   device, = devices
   if config.jax_array:
     aval = core.raise_to_shaped(core.get_aval(x.data))
-    return dispatch._copy_array_to_device(x.data, aval, device)
+    return pxla.batched_device_put(
+        aval, jax.sharding.SingleDeviceSharding(device), [x.data], [device])
   bufs = dispatch._device_put_array(x.data, device)
   return bufs