Roll back the device_put_sharded and device_put_replicated change that used batched_device_put

PiperOrigin-RevId: 515926020
Yash Katariya 2023-03-11 15:31:15 -08:00 committed by jax authors
parent 96da1c4b71
commit 3375f011bb


@@ -3110,16 +3110,15 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  # noqa: F811
       raise ValueError("the shards passed to device_put_sharded must have "
                        f"consistent shape and dtype, but got {a1} and {a2}.")
     stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
+    buffers = [buf for x, d in zip(xs, devices)
+               for buf in dispatch.device_put(x, d)]
     if config.jax_array:
       sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
-      xs = [xla.canonicalize_dtype(arg) for arg in xs]
-      return pxla.batched_device_put(
+      return array.ArrayImpl(
           stacked_aval,
           PmapSharding(np.array(devices), sharding_spec),
-          xs, devices)
+          buffers, committed=True, _skip_checks=True)
     else:
-      buffers = [buf for x, d in zip(xs, devices)
-                 for buf in dispatch.device_put(x, d)]
       return pxla.make_sharded_device_array(stacked_aval, None, buffers)

   with config_explicit_device_put_scope():
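
For context, jax.device_put_sharded is the public entry point this hunk touches: it takes one shard per device and returns a single stacked, committed array. A minimal usage sketch (the shard shapes and device list below are illustrative assumptions, not part of the commit):

    import jax
    import numpy as np

    devices = jax.local_devices()
    # One shard per device; all shards must share shape and dtype.
    shards = [np.arange(4.0) + i for i in range(len(devices))]
    out = jax.device_put_sharded(shards, devices)
    # The shards are stacked along a new leading device axis.
    assert out.shape == (len(devices), 4)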
@@ -3165,16 +3164,14 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
                               core.raise_to_shaped(core.get_aval(x)))
     assert (isinstance(aval, ShapedArray) and
             len(xla.aval_to_xla_shapes(aval)) == 1)
+    buf, = dispatch.device_put(x, devices[0])
+    rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
     if config.jax_array:
       sharding_spec = pxla._create_pmap_sharding_spec(aval)
-      x = xla.canonicalize_dtype(x)
-      buf = jax.device_put(x, devices[0])
-      return pxla.batched_device_put(
+      return array.ArrayImpl(
           aval, PmapSharding(np.array(devices), sharding_spec),
-          [buf] * len(devices), devices)
+          [buf, *rest_bufs], committed=True, _skip_checks=True)
     else:
-      buf, = dispatch.device_put(x, devices[0])
-      rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
       return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])

   with config_explicit_device_put_scope():
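
jax.device_put_replicated is the replicated counterpart touched by this hunk: it copies a single host value onto every device and stacks the copies along a new leading axis. A minimal usage sketch under the same illustrative assumptions:

    import jax
    import numpy as np

    devices = jax.local_devices()
    x = np.arange(3.0)
    out = jax.device_put_replicated(x, devices)
    # One copy of x per device, stacked along a leading axis of length len(devices).
    assert out.shape == (len(devices), 3)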