mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make `device_put_sharded` and `device_put_replicated` return Arrays.
PiperOrigin-RevId: 456525113
This commit is contained in:
parent
6f3b3ac8f9
commit
dce8f64b40
@ -2794,7 +2794,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
|
||||
if not isinstance(shards, Sequence):
|
||||
raise ValueError("device_put_sharded `shards` input must be a sequence; "
|
||||
f"got {type(shards)}")
|
||||
if not len(shards) == len(devices):
|
||||
if len(shards) != len(devices):
|
||||
raise ValueError(f"len(shards) = {len(shards)} must equal "
|
||||
f"len(devices) = {len(devices)}.")
|
||||
|
||||
@ -2808,7 +2808,15 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
|
||||
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
|
||||
buffers = [buf for x, d in zip(xs, devices)
|
||||
for buf in dispatch.device_put(x, d)]
|
||||
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
|
||||
if config.jax_array:
|
||||
from jax.experimental import array, sharding
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
|
||||
return array.Array(
|
||||
stacked_aval.shape,
|
||||
sharding.PmapSharding(np.array(devices), sharding_spec),
|
||||
buffers, committed=True)
|
||||
else:
|
||||
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
|
||||
|
||||
with config_explicit_device_put_scope():
|
||||
return tree_map(_device_put_sharded, *shards)
|
||||
@ -2855,7 +2863,14 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
|
||||
len(xla.aval_to_xla_shapes(aval)) == 1)
|
||||
buf, = dispatch.device_put(x, devices[0])
|
||||
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
|
||||
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
|
||||
if config.jax_array:
|
||||
from jax.experimental import array, sharding
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(aval)
|
||||
return array.Array(
|
||||
aval.shape, sharding.PmapSharding(np.array(devices), sharding_spec),
|
||||
[buf, *rest_bufs], committed=True)
|
||||
else:
|
||||
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
|
||||
|
||||
with config_explicit_device_put_scope():
|
||||
return tree_map(_device_put_replicated, x)
|
||||
|
@ -530,6 +530,19 @@ global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRe
|
||||
_USE_CPP_SDA = True
|
||||
|
||||
|
||||
def _create_pmap_sharding_spec(aval, sharded_dim=0):
  """Build a sharding spec for `aval` by delegating to `_pmap_sharding_spec`.

  Args:
    aval: abstract value exposing `.shape` and `.update(shape=...)`.
    sharded_dim: index of the dimension of `aval` that is dropped to form
      the per-shard aval, or `None` to pass `aval` through unchanged.

  Returns:
    Whatever `_pmap_sharding_spec` returns for the derived arguments.
  """
  if sharded_dim is None:
    # No dimension is removed; the size is still read from the leading axis.
    per_shard_aval = aval
    axis_size = aval.shape[0]
  else:
    # Drop `sharded_dim` from the shape to get the aval of a single shard.
    leading = aval.shape[:sharded_dim]
    trailing = aval.shape[sharded_dim+1:]
    per_shard_aval = aval.update(shape=leading + trailing)
    axis_size = aval.shape[sharded_dim]

  # The same size is supplied for the first two positional arguments.
  # NOTE(review): presumably these are the number of participants and the
  # mapped axis size — confirm against `_pmap_sharding_spec`'s signature.
  return _pmap_sharding_spec(
      axis_size, axis_size, 1, None, per_shard_aval, sharded_dim)
|
||||
|
||||
|
||||
def make_sharded_device_array(
|
||||
aval: ShapedArray,
|
||||
sharding_spec: Optional[ShardingSpec],
|
||||
@ -552,9 +565,7 @@ def make_sharded_device_array(
|
||||
indices: For caching purposes, will be computed if `None`.
|
||||
"""
|
||||
if sharding_spec is None:
|
||||
sharded_aval = aval.update(shape=aval.shape[1:])
|
||||
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], 1, None,
|
||||
sharded_aval, 0)
|
||||
sharding_spec = _create_pmap_sharding_spec(aval)
|
||||
|
||||
if indices is None:
|
||||
indices = spec_to_indices(aval.shape, sharding_spec)
|
||||
|
@ -2537,52 +2537,81 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(sharded_x[i], device_array.DeviceArray)
|
||||
self.assertIsNone(sharded_x._npy_value)
|
||||
|
||||
def test_device_put_sharded_array(self):
|
||||
@parameterized.named_parameters(
|
||||
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
|
||||
('array', True, array.Array, '_arrays')
|
||||
)
|
||||
def test_device_put_sharded(self, is_jax_array, array_type, buffer_attr):
|
||||
devices = jax.local_devices()
|
||||
n_devices = len(devices)
|
||||
x = [np.arange(i, i + 4) for i in range(n_devices)]
|
||||
y = jax.device_put_sharded(x, devices)
|
||||
self.assertIsInstance(y, pxla.ShardedDeviceArray)
|
||||
self.assertEqual(len(y.device_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
|
||||
with jax._src.config.jax_array(is_jax_array):
|
||||
y = jax.device_put_sharded(x, devices)
|
||||
self.assertIsInstance(y, array_type)
|
||||
buffers = getattr(y, buffer_attr)
|
||||
self.assertEqual(len(buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
|
||||
self.assertArraysEqual(y, jnp.stack(x))
|
||||
|
||||
def test_device_put_sharded_pytree(self):
|
||||
@parameterized.named_parameters(
|
||||
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
|
||||
('array', True, array.Array, '_arrays')
|
||||
)
|
||||
def test_device_put_sharded_pytree(self, is_jax_array, array_type, buffer_attr):
|
||||
devices = jax.local_devices()
|
||||
n_devices = len(devices)
|
||||
x = [(i, np.arange(i, i + 4)) for i in range(n_devices)]
|
||||
y1, y2 = jax.device_put_sharded(x, devices)
|
||||
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
|
||||
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
|
||||
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
|
||||
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
|
||||
with jax._src.config.jax_array(is_jax_array):
|
||||
y1, y2 = jax.device_put_sharded(x, devices)
|
||||
|
||||
def test_device_put_replicated_array(self):
|
||||
self.assertIsInstance(y1, array_type)
|
||||
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
|
||||
y1_buffers = getattr(y1, buffer_attr)
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
|
||||
|
||||
self.assertIsInstance(y2, array_type)
|
||||
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
|
||||
y2_buffers = getattr(y2, buffer_attr)
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
|
||||
('array', True, array.Array, '_arrays')
|
||||
)
|
||||
def test_device_put_replicated(self, is_jax_array, array_type, buffer_attr):
|
||||
devices = jax.local_devices()
|
||||
x = np.arange(1, 5)
|
||||
y = jax.device_put_replicated(x, devices)
|
||||
self.assertIsInstance(y, pxla.ShardedDeviceArray)
|
||||
self.assertEqual(len(y.device_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
|
||||
with jax._src.config.jax_array(is_jax_array):
|
||||
y = jax.device_put_replicated(x, devices)
|
||||
|
||||
self.assertIsInstance(y, array_type)
|
||||
buffers = getattr(y, buffer_attr)
|
||||
self.assertEqual(len(buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
|
||||
self.assertArraysEqual(y, np.stack([x for _ in devices]))
|
||||
|
||||
def test_device_put_replicated_pytree(self):
|
||||
@parameterized.named_parameters(
|
||||
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
|
||||
('array', True, array.Array, '_arrays')
|
||||
)
|
||||
def test_device_put_replicated_pytree(self, is_jax_array, array_type, buffer_attr):
|
||||
devices = jax.local_devices()
|
||||
xs = {'a': np.arange(1, 5), 'b': np.arange(3)}
|
||||
ys = jax.device_put_replicated(xs, devices)
|
||||
with jax._src.config.jax_array(is_jax_array):
|
||||
ys = jax.device_put_replicated(xs, devices)
|
||||
self.assertIsInstance(ys, dict)
|
||||
y1, y2 = ys['a'], ys['b']
|
||||
|
||||
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
|
||||
self.assertEqual(len(y1.device_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
|
||||
self.assertIsInstance(y1, array_type)
|
||||
y1_buffers = getattr(y1, buffer_attr)
|
||||
self.assertEqual(len(y1_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
|
||||
self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
|
||||
|
||||
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
|
||||
self.assertEqual(len(y2.device_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
|
||||
self.assertIsInstance(y2, array_type)
|
||||
y2_buffers = getattr(y2, buffer_attr)
|
||||
self.assertEqual(len(y2_buffers), len(devices))
|
||||
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
|
||||
self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
|
||||
|
||||
def test_repr(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user