mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 515926020
This commit is contained in:
parent
96da1c4b71
commit
3375f011bb
@ -3110,16 +3110,15 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
|
||||
raise ValueError("the shards passed to device_put_sharded must have "
|
||||
f"consistent shape and dtype, but got {a1} and {a2}.")
|
||||
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
|
||||
buffers = [buf for x, d in zip(xs, devices)
|
||||
for buf in dispatch.device_put(x, d)]
|
||||
if config.jax_array:
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
|
||||
xs = [xla.canonicalize_dtype(arg) for arg in xs]
|
||||
return pxla.batched_device_put(
|
||||
return array.ArrayImpl(
|
||||
stacked_aval,
|
||||
PmapSharding(np.array(devices), sharding_spec),
|
||||
xs, devices)
|
||||
buffers, committed=True, _skip_checks=True)
|
||||
else:
|
||||
buffers = [buf for x, d in zip(xs, devices)
|
||||
for buf in dispatch.device_put(x, d)]
|
||||
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
|
||||
|
||||
with config_explicit_device_put_scope():
|
||||
@ -3165,16 +3164,14 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
|
||||
core.raise_to_shaped(core.get_aval(x)))
|
||||
assert (isinstance(aval, ShapedArray) and
|
||||
len(xla.aval_to_xla_shapes(aval)) == 1)
|
||||
buf, = dispatch.device_put(x, devices[0])
|
||||
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
|
||||
if config.jax_array:
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(aval)
|
||||
x = xla.canonicalize_dtype(x)
|
||||
buf = jax.device_put(x, devices[0])
|
||||
return pxla.batched_device_put(
|
||||
return array.ArrayImpl(
|
||||
aval, PmapSharding(np.array(devices), sharding_spec),
|
||||
[buf] * len(devices), devices)
|
||||
[buf, *rest_bufs], committed=True, _skip_checks=True)
|
||||
else:
|
||||
buf, = dispatch.device_put(x, devices[0])
|
||||
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
|
||||
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
|
||||
|
||||
with config_explicit_device_put_scope():
|
||||
|
Loading…
x
Reference in New Issue
Block a user