Make device_put_sharded and device_put_replicated return Arrays.

PiperOrigin-RevId: 456525113
This commit is contained in:
Yash Katariya 2022-06-22 08:50:54 -07:00 committed by jax authors
parent 6f3b3ac8f9
commit dce8f64b40
3 changed files with 87 additions and 32 deletions

View File

@ -2794,7 +2794,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
if not isinstance(shards, Sequence):
raise ValueError("device_put_sharded `shards` input must be a sequence; "
f"got {type(shards)}")
if not len(shards) == len(devices):
if len(shards) != len(devices):
raise ValueError(f"len(shards) = {len(shards)} must equal "
f"len(devices) = {len(devices)}.")
@ -2808,7 +2808,15 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
buffers = [buf for x, d in zip(xs, devices)
for buf in dispatch.device_put(x, d)]
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
if config.jax_array:
from jax.experimental import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
return array.Array(
stacked_aval.shape,
sharding.PmapSharding(np.array(devices), sharding_spec),
buffers, committed=True)
else:
return pxla.make_sharded_device_array(stacked_aval, None, buffers)
with config_explicit_device_put_scope():
return tree_map(_device_put_sharded, *shards)
@ -2855,7 +2863,14 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
len(xla.aval_to_xla_shapes(aval)) == 1)
buf, = dispatch.device_put(x, devices[0])
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
if config.jax_array:
from jax.experimental import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(aval)
return array.Array(
aval.shape, sharding.PmapSharding(np.array(devices), sharding_spec),
[buf, *rest_bufs], committed=True)
else:
return pxla.make_sharded_device_array(aval, None, [buf, *rest_bufs])
with config_explicit_device_put_scope():
return tree_map(_device_put_replicated, x)

View File

@ -530,6 +530,19 @@ global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRe
_USE_CPP_SDA = True
def _create_pmap_sharding_spec(aval, sharded_dim=0):
  """Builds a pmap-style sharding spec for `aval`.

  Args:
    aval: abstract value whose `shape` includes the mapped axis (when
      `sharded_dim` is not None). Must provide `.shape` and `.update(shape=...)`.
    sharded_dim: index of the axis of `aval` that is split across devices, or
      None. Defaults to 0.

  Returns:
    Whatever `_pmap_sharding_spec` returns for the derived arguments.

  NOTE(review): the slice arithmetic below assumes a non-negative
  `sharded_dim`; e.g. for -1, `aval.shape[sharded_dim + 1:]` is the whole
  shape rather than an empty tail. Confirm callers never pass negatives.
  """
  if sharded_dim is None:
    # No axis is removed: the per-shard aval is `aval` itself, and the size
    # of the leading axis is what gets forwarded as the axis size.
    per_shard_aval = aval
    axis_size = aval.shape[0]
  else:
    # Drop the mapped axis from the shape to obtain the per-shard aval.
    leading = aval.shape[:sharded_dim]
    trailing = aval.shape[sharded_dim + 1:]
    per_shard_aval = aval.update(shape=leading + trailing)
    axis_size = aval.shape[sharded_dim]

  # The same size is passed for the first two positional arguments, followed
  # by the constants 1 and None, exactly as the None-spec path of
  # `make_sharded_device_array` did before this helper existed.
  return _pmap_sharding_spec(
      axis_size, axis_size, 1, None, per_shard_aval, sharded_dim)
def make_sharded_device_array(
aval: ShapedArray,
sharding_spec: Optional[ShardingSpec],
@ -552,9 +565,7 @@ def make_sharded_device_array(
indices: For caching purposes, will be computed if `None`.
"""
if sharding_spec is None:
sharded_aval = aval.update(shape=aval.shape[1:])
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], 1, None,
sharded_aval, 0)
sharding_spec = _create_pmap_sharding_spec(aval)
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)

View File

@ -2537,52 +2537,81 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
self.assertIsInstance(sharded_x[i], device_array.DeviceArray)
self.assertIsNone(sharded_x._npy_value)
def test_device_put_sharded_array(self):
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
)
def test_device_put_sharded(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
n_devices = len(devices)
x = [np.arange(i, i + 4) for i in range(n_devices)]
y = jax.device_put_sharded(x, devices)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertEqual(len(y.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
with jax._src.config.jax_array(is_jax_array):
y = jax.device_put_sharded(x, devices)
self.assertIsInstance(y, array_type)
buffers = getattr(y, buffer_attr)
self.assertEqual(len(buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
self.assertArraysEqual(y, jnp.stack(x))
def test_device_put_sharded_pytree(self):
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
)
def test_device_put_sharded_pytree(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
n_devices = len(devices)
x = [(i, np.arange(i, i + 4)) for i in range(n_devices)]
y1, y2 = jax.device_put_sharded(x, devices)
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
with jax._src.config.jax_array(is_jax_array):
y1, y2 = jax.device_put_sharded(x, devices)
def test_device_put_replicated_array(self):
self.assertIsInstance(y1, array_type)
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
y1_buffers = getattr(y1, buffer_attr)
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
self.assertIsInstance(y2, array_type)
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
y2_buffers = getattr(y2, buffer_attr)
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
)
def test_device_put_replicated(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
x = np.arange(1, 5)
y = jax.device_put_replicated(x, devices)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertEqual(len(y.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
with jax._src.config.jax_array(is_jax_array):
y = jax.device_put_replicated(x, devices)
self.assertIsInstance(y, array_type)
buffers = getattr(y, buffer_attr)
self.assertEqual(len(buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(buffers, devices)))
self.assertArraysEqual(y, np.stack([x for _ in devices]))
def test_device_put_replicated_pytree(self):
@parameterized.named_parameters(
('sda', False, pxla.ShardedDeviceArray, 'device_buffers'),
('array', True, array.Array, '_arrays')
)
def test_device_put_replicated_pytree(self, is_jax_array, array_type, buffer_attr):
devices = jax.local_devices()
xs = {'a': np.arange(1, 5), 'b': np.arange(3)}
ys = jax.device_put_replicated(xs, devices)
with jax._src.config.jax_array(is_jax_array):
ys = jax.device_put_replicated(xs, devices)
self.assertIsInstance(ys, dict)
y1, y2 = ys['a'], ys['b']
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
self.assertEqual(len(y1.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
self.assertIsInstance(y1, array_type)
y1_buffers = getattr(y1, buffer_attr)
self.assertEqual(len(y1_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y1_buffers, devices)))
self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
self.assertEqual(len(y2.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
self.assertIsInstance(y2, array_type)
y2_buffers = getattr(y2, buffer_attr)
self.assertEqual(len(y2_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y2_buffers, devices)))
self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
def test_repr(self):