Implement the GSDA callback constructors `from_batched_callback` and `from_batched_callback_with_devices`.

Also implemented the `global_shards` property and made the `data` field of the `Shard` class optional (it is `None` for shards that live on non-local devices).

PiperOrigin-RevId: 407756749
Author: Yash Katariya, 2021-11-05 00:15:03 -07:00 (committed by jax authors)
parent 8757093315
commit 55ec94766f
2 changed files with 121 additions and 47 deletions
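
For context, a minimal usage sketch of the two new constructors, modeled on the tests added below (create_global_mesh is the test helper used there; an 8-device (4, 2) mesh is assumed):

  import jax
  import jax.numpy as jnp
  import numpy as np
  from jax._src.util import prod

  global_mesh = create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)

  # Batched callback: receives every local shard index in one call and
  # returns host arrays; GSDA transfers them to the local devices.
  gsda = GlobalShardedDeviceArray.from_batched_callback(
      global_input_shape, jnp.float32, global_mesh, [('x', 'y')],
      lambda indices: [global_input_data[index] for index in indices])

  # Batched callback with devices: receives one (index, devices) pair per
  # distinct index and must return buffers already placed on those devices.
  gsda2 = GlobalShardedDeviceArray.from_batched_callback_with_devices(
      global_input_shape, jnp.float32, global_mesh, ['x'],
      lambda cb_inp: [jax.device_put(global_input_data[index], d)
                      for index, devices in cb_inp for d in devices])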


@@ -13,10 +13,12 @@
 # limitations under the License.
 """Implementation of GlobalShardedDeviceArray."""

+from collections import defaultdict, Counter
 import dataclasses
 import numpy as np
-from typing import Callable, Sequence, Tuple, Union, Mapping
+from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict
 from .. import core
+from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from ..interpreters import pxla
 from .._src.util import prod, safe_zip
@@ -33,7 +35,7 @@ Index = Tuple[slice, ...]

 @dataclasses.dataclass(frozen=True)
-class _HashableSlice:
+class _HashableIndex:
   val: Index

   def __hash__(self):
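
Why the wrapper is needed at all: an Index is a tuple of slice objects, and slice is not hashable, so an Index cannot key a dict or Counter directly. A minimal illustration (the __hash__ body is elided in the hunk above; hashing the slices' (start, stop, step) fields is one assumed implementation):

  idx = (slice(0, 2), slice(0, 1))
  # hash(idx)  -> TypeError: unhashable type: 'slice'
  hash(tuple((s.start, s.stop, s.step) for s in idx))  # fine: hashes plain ints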
@@ -43,8 +45,8 @@ class _HashableSlice:
     return self.val == other.val

-def shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
-                  mesh_axes: MeshAxes) -> Mapping[Device, Index]:
+def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
+                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
   if not isinstance(mesh_axes, PartitionSpec):
     pspec = PartitionSpec(*mesh_axes)
   else:
@@ -61,10 +63,12 @@ def shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
   for idx in index:
     assert isinstance(idx, slice)
   # The type: ignore is to ignore the type returned by `spec_to_indices`.
-  return dict((d, i) for d, i in safe_zip(global_mesh.devices.flat, indices))  # type: ignore
+  return dict(
+      (d, i)
+      for d, i in safe_zip(global_mesh.devices.flat, indices))  # type: ignore

-def shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
+def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
   chunk_size = []
   for mesh_axis, size in zip(mesh_axes, global_shape):
     if not mesh_axis:
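
A worked example of the two helpers, with values assumed from the tests below (a (4, 2) mesh over ('x', 'y') and global shape (8, 2)):

  # mesh_axes = P('x', 'y') partitions dim 0 over 'x' (size 4) and dim 1
  # over 'y' (size 2), so each global dimension is divided by the size of
  # the mesh axis it is partitioned over:
  #   get_shard_shape((8, 2), mesh, P('x', 'y')) == (8 // 4, 2 // 2) == (2, 1)
  # and get_shard_indices maps the device at mesh position (i, j) to
  #   (slice(2 * i, 2 * (i + 1)), slice(j, j + 1))
  # Eight devices, eight distinct indices, so every shard gets replica_id 0.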
@@ -84,25 +88,23 @@ class Shard:
   device: Device
   index: Index
   replica_id: int
-  data: DeviceArray
+  # None if this `Shard` lives on a non-local device.
+  data: Optional[DeviceArray] = None


 class GlobalShardedDeviceArray:

-  def __init__(self,
-               global_shape: Shape,
-               dtype,
-               global_mesh: pxla.Mesh,
-               mesh_axes: MeshAxes,
-               device_buffers: Sequence[DeviceArray]):
+  def __init__(self, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
+               mesh_axes: MeshAxes, device_buffers: Sequence[DeviceArray]):
     self._global_shape = global_shape
     self._dtype = dtype
     self._global_mesh = global_mesh
     self._mesh_axes = mesh_axes
     assert len(device_buffers) == len(self._global_mesh.local_devices)
-    self._local_shards = self._create_local_shards(device_buffers)
+    self._global_shards, self._local_shards = self._create_shards(
+        device_buffers)

-    ss = shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
+    ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
     assert all(db.shape == ss for db in device_buffers), (
         f"Expected shard shape {ss} doesn't match the device buffer "
         f"shape {device_buffers[0].shape}")
@@ -111,39 +113,39 @@ class GlobalShardedDeviceArray:
   def shape(self) -> Shape:
     return self._global_shape

-  # TODO(yashkatariya): Make this `create_shards` and create global_shards
-  # Then source the local_shards and add the data to it.
-  def _create_local_shards(
-      self, device_buffers: Sequence[DeviceArray]) -> Sequence[Shard]:
-    indices = shard_indices(self._global_shape, self._global_mesh,
-                            self._mesh_axes)
-    device_to_replica = {}
-    index_to_replica = {}
+  def _create_shards(
+      self, device_buffers: Sequence[DeviceArray]
+  ) -> Tuple[Sequence[Shard], Sequence[Shard]]:
+    indices = get_shard_indices(self._global_shape, self._global_mesh,
+                                self._mesh_axes)
+    device_to_buffer = dict((db.device(), db) for db in device_buffers)
+    gs, ls = [], []
+    index_to_replica: Dict[_HashableIndex, int] = Counter()
     for device, index in indices.items():
-      h_index = _HashableSlice(index)
-      if h_index not in index_to_replica:
-        index_to_replica[h_index] = 0
-      else:
-        index_to_replica[h_index] += 1
-      device_to_replica[device] = index_to_replica[h_index]
-
-    shards = []
-    # device_buffers are always local to the process.
-    for db in device_buffers:
-      d = db.device()
-      shards.append(Shard(d, indices[d], device_to_replica[d], db))
-    return shards
+      h_index = _HashableIndex(index)
+      replica_id = index_to_replica[h_index]
+      index_to_replica[h_index] += 1
+      local_shard = device.process_index == xb.process_index()
+      buf = device_to_buffer[device] if local_shard else None
+      sh = Shard(device, index, replica_id, buf)
+      gs.append(sh)
+      if local_shard:
+        ls.append(sh)
+    return gs, ls

   @property
   def local_shards(self) -> Sequence[Shard]:
     return self._local_shards

+  @property
+  def global_shards(self) -> Sequence[Shard]:
+    return self._global_shards
+
   @classmethod
   def from_callback(cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
-                    mesh_axes: MeshAxes,
-                    data_callback: Callable[[Index], ArrayLike]):
-    indices = shard_indices(global_shape, global_mesh, mesh_axes)
+                    mesh_axes: MeshAxes, data_callback: Callable[[Index],
+                                                                 ArrayLike]):
+    indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
     dbs = [
         device_put(data_callback(indices[device]), device)
         for device in global_mesh.local_devices
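
A self-contained sketch of the replica-id bookkeeping in _create_shards above. Plain (start, stop) tuples stand in for the real Index values, which must be wrapped in _HashableIndex before they can key a Counter:

  from collections import Counter

  indices = {'dev0': (0, 4), 'dev1': (4, 8),
             'dev2': (0, 4), 'dev3': (4, 8)}
  index_to_replica = Counter()
  for device, index in indices.items():
    replica_id = index_to_replica[index]  # 0 the first time an index is seen
    index_to_replica[index] += 1          # later holders become replicas 1, 2, ...
    print(device, index, replica_id)
  # dev0 (0, 4) 0
  # dev1 (4, 8) 0
  # dev2 (0, 4) 1
  # dev3 (4, 8) 1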
@@ -151,14 +153,31 @@ class GlobalShardedDeviceArray:
     return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)

   @classmethod
-  def from_batched_callback(
-      cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
-      mesh_axes: MeshAxes, data_callback: Callable[[Sequence[Index]], Sequence[ArrayLike]]):
-    raise NotImplementedError("Not implemented yet.")
+  def from_batched_callback(cls, global_shape: Shape, dtype,
+                            global_mesh: pxla.Mesh, mesh_axes: MeshAxes,
+                            data_callback: Callable[[Sequence[Index]],
+                                                    Sequence[ArrayLike]]):
+    indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
+    local_indices = [indices[d] for d in global_mesh.local_devices]
+    local_arrays = data_callback(local_indices)
+    dbs = pxla.device_put(local_arrays, global_mesh.local_devices)
+    return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)

   @classmethod
   def from_batched_callback_with_devices(
       cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
       mesh_axes: MeshAxes,
-      data_callback: Callable[[Sequence[Tuple[Index, Tuple[Device]]]], Sequence[DeviceArray]]):
-    raise NotImplementedError("Not implemented yet.")
+      data_callback: Callable[[List[Tuple[Index, Tuple[Device, ...]]]],
+                              Sequence[DeviceArray]]):
+    indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
+    index_to_device: Dict[_HashableIndex, List[Device]] = defaultdict(list)
+    for device in global_mesh.local_devices:
+      h_index = _HashableIndex(indices[device])
+      index_to_device[h_index].append(device)
+
+    cb_inp = [
+        (index.val, tuple(device)) for index, device in index_to_device.items()
+    ]
+    dbs = data_callback(cb_inp)
+    return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)
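
In a multi-process run, global_shards has one entry per device in the whole mesh while local_shards covers only this process's devices, and only local shards carry data. A hedged sketch of consuming the two views (gsda stands for any GlobalShardedDeviceArray):

  for shard in gsda.global_shards:
    if shard.data is None:
      # Non-local shard: its index and replica_id are known here, but its
      # buffer lives in another process and cannot be read directly.
      continue
    chunk = shard.data.to_py()  # local shard: materialize as a numpy array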


@@ -21,7 +21,7 @@ import numpy as np
 import jax
 import jax.numpy as jnp
 from jax._src import test_util as jtu
-from jax._src.util import prod
+from jax._src.util import prod, safe_zip
 from jax.experimental import PartitionSpec as P
 from jax.experimental.maps import Mesh
@@ -92,6 +92,11 @@ class GSDATest(jtu.JaxTestCase):
     self.assertListEqual(replica_ids, expected_replica_ids)
     self.assertListEqual([i.device.id for i in gsda.local_shards],
                          [0, 1, 2, 3, 4, 5, 6, 7])
+    for g, l in safe_zip(gsda.global_shards, gsda.local_shards):
+      self.assertEqual(g.device, l.device)
+      self.assertEqual(g.index, l.index)
+      self.assertEqual(g.replica_id, l.replica_id)
+      self.assertArraysEqual(g.data, l.data)

   @parameterized.named_parameters(
       ("mesh_x_y_z", ["x", "y", "z"],
@@ -191,7 +196,57 @@ class GSDATest(jtu.JaxTestCase):
     self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
     replica_ids = [i.replica_id for i in gsda.local_shards]
     self.assertListEqual(replica_ids, expected_replica_ids)
+    for g, l in safe_zip(gsda.global_shards, gsda.local_shards):
+      self.assertEqual(g.device, l.device)
+      self.assertEqual(g.index, l.index)
+      self.assertEqual(g.replica_id, l.replica_id)
+      self.assertArraysEqual(g.data, l.data)
+
+  def test_gsda_batched_callback(self):
+    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    mesh_axes = [('x', 'y')]
+    global_input_data = np.arange(
+        prod(global_input_shape)).reshape(global_input_shape)
+
+    def cb(indices):
+      self.assertEqual(len(indices), len(global_mesh.local_devices))
+      return [global_input_data[index] for index in indices]
+
+    gsda = GlobalShardedDeviceArray.from_batched_callback(
+        global_input_shape, jnp.float32, global_mesh, mesh_axes, cb)
+    expected_first_shard_value = np.array([[0, 1]])
+    self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
+                           expected_first_shard_value)
+    expected_second_shard_value = np.array([[2, 3]])
+    self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
+                           expected_second_shard_value)
+
+  def test_gsda_batched_callback_with_devices(self):
+    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    mesh_axes = ['x']
+    global_input_data = np.arange(
+        prod(global_input_shape)).reshape(global_input_shape)
+
+    def cb(cb_inp):
+      self.assertLen(cb_inp, 4)
+      dbs = []
+      for inp in cb_inp:
+        index, devices = inp
+        self.assertLen(devices, 2)
+        array = global_input_data[index]
+        dbs.extend([jax.device_put(array, device) for device in devices])
+      return dbs
+
+    gsda = GlobalShardedDeviceArray.from_batched_callback_with_devices(
+        global_input_shape, jnp.float32, global_mesh, mesh_axes, cb)
+    expected_first_shard_value = np.array([[0, 1], [2, 3]])
+    self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
+                           expected_first_shard_value)
+    expected_second_shard_value = np.array([[0, 1], [2, 3]])
+    self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
+                           expected_second_shard_value)

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())