mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
GSDA callback implementation for from_batched_callback
and from_batched_callback_with_devices
.
Also implemented `global_shards()` and made `Data` field optional in the `Shard` class. PiperOrigin-RevId: 407756749
This commit is contained in:
parent
8757093315
commit
55ec94766f
@ -13,10 +13,12 @@
|
||||
# limitations under the License.
|
||||
"""Implementation of GlobalShardedDeviceArray."""
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict
|
||||
from .. import core
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from ..interpreters import pxla
|
||||
from .._src.util import prod, safe_zip
|
||||
@ -33,7 +35,7 @@ Index = Tuple[slice, ...]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _HashableSlice:
|
||||
class _HashableIndex:
|
||||
val: Index
|
||||
|
||||
def __hash__(self):
|
||||
@ -43,8 +45,8 @@ class _HashableSlice:
|
||||
return self.val == other.val
|
||||
|
||||
|
||||
def shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Mapping[Device, Index]:
|
||||
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Mapping[Device, Index]:
|
||||
if not isinstance(mesh_axes, PartitionSpec):
|
||||
pspec = PartitionSpec(*mesh_axes)
|
||||
else:
|
||||
@ -61,10 +63,12 @@ def shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
for idx in index:
|
||||
assert isinstance(idx, slice)
|
||||
# The type: ignore is to ignore the type returned by `spec_to_indices`.
|
||||
return dict((d, i) for d, i in safe_zip(global_mesh.devices.flat, indices)) # type: ignore
|
||||
return dict(
|
||||
(d, i)
|
||||
for d, i in safe_zip(global_mesh.devices.flat, indices)) # type: ignore
|
||||
|
||||
|
||||
def shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
|
||||
def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
|
||||
chunk_size = []
|
||||
for mesh_axis, size in zip(mesh_axes, global_shape):
|
||||
if not mesh_axis:
|
||||
@ -84,25 +88,23 @@ class Shard:
|
||||
device: Device
|
||||
index: Index
|
||||
replica_id: int
|
||||
data: DeviceArray
|
||||
# None if this `Shard` lives on a non-local device.
|
||||
data: Optional[DeviceArray] = None
|
||||
|
||||
|
||||
class GlobalShardedDeviceArray:
|
||||
|
||||
def __init__(self,
|
||||
global_shape: Shape,
|
||||
dtype,
|
||||
global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes,
|
||||
device_buffers: Sequence[DeviceArray]):
|
||||
def __init__(self, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes, device_buffers: Sequence[DeviceArray]):
|
||||
self._global_shape = global_shape
|
||||
self._dtype = dtype
|
||||
self._global_mesh = global_mesh
|
||||
self._mesh_axes = mesh_axes
|
||||
assert len(device_buffers) == len(self._global_mesh.local_devices)
|
||||
self._local_shards = self._create_local_shards(device_buffers)
|
||||
self._global_shards, self._local_shards = self._create_shards(
|
||||
device_buffers)
|
||||
|
||||
ss = shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
|
||||
assert all(db.shape == ss for db in device_buffers), (
|
||||
f"Expected shard shape {ss} doesn't match the device buffer "
|
||||
f"shape {device_buffers[0].shape}")
|
||||
@ -111,39 +113,39 @@ class GlobalShardedDeviceArray:
|
||||
def shape(self) -> Shape:
|
||||
return self._global_shape
|
||||
|
||||
# TODO(yashkatariya): Make this `create_shards` and create global_shards
|
||||
# Then source the local_shards and add the data to it.
|
||||
def _create_local_shards(
|
||||
self, device_buffers: Sequence[DeviceArray]) -> Sequence[Shard]:
|
||||
indices = shard_indices(self._global_shape, self._global_mesh,
|
||||
self._mesh_axes)
|
||||
|
||||
device_to_replica = {}
|
||||
index_to_replica = {}
|
||||
def _create_shards(
|
||||
self, device_buffers: Sequence[DeviceArray]
|
||||
) -> Tuple[Sequence[Shard], Sequence[Shard]]:
|
||||
indices = get_shard_indices(self._global_shape, self._global_mesh,
|
||||
self._mesh_axes)
|
||||
device_to_buffer = dict((db.device(), db) for db in device_buffers)
|
||||
gs, ls = [], []
|
||||
index_to_replica: Dict[_HashableIndex, int] = Counter()
|
||||
for device, index in indices.items():
|
||||
h_index = _HashableSlice(index)
|
||||
if h_index not in index_to_replica:
|
||||
index_to_replica[h_index] = 0
|
||||
else:
|
||||
index_to_replica[h_index] += 1
|
||||
device_to_replica[device] = index_to_replica[h_index]
|
||||
|
||||
shards = []
|
||||
# device_buffers are always local to the process.
|
||||
for db in device_buffers:
|
||||
d = db.device()
|
||||
shards.append(Shard(d, indices[d], device_to_replica[d], db))
|
||||
return shards
|
||||
h_index = _HashableIndex(index)
|
||||
replica_id = index_to_replica[h_index]
|
||||
index_to_replica[h_index] += 1
|
||||
local_shard = device.process_index == xb.process_index()
|
||||
buf = device_to_buffer[device] if local_shard else None
|
||||
sh = Shard(device, index, replica_id, buf)
|
||||
gs.append(sh)
|
||||
if local_shard:
|
||||
ls.append(sh)
|
||||
return gs, ls
|
||||
|
||||
@property
|
||||
def local_shards(self) -> Sequence[Shard]:
|
||||
return self._local_shards
|
||||
|
||||
@property
|
||||
def global_shards(self) -> Sequence[Shard]:
|
||||
return self._global_shards
|
||||
|
||||
@classmethod
|
||||
def from_callback(cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes,
|
||||
data_callback: Callable[[Index], ArrayLike]):
|
||||
indices = shard_indices(global_shape, global_mesh, mesh_axes)
|
||||
mesh_axes: MeshAxes, data_callback: Callable[[Index],
|
||||
ArrayLike]):
|
||||
indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
|
||||
dbs = [
|
||||
device_put(data_callback(indices[device]), device)
|
||||
for device in global_mesh.local_devices
|
||||
@ -151,14 +153,31 @@ class GlobalShardedDeviceArray:
|
||||
return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)
|
||||
|
||||
@classmethod
|
||||
def from_batched_callback(
|
||||
cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes, data_callback: Callable[[Sequence[Index]], Sequence[ArrayLike]]):
|
||||
raise NotImplementedError("Not implemented yet.")
|
||||
def from_batched_callback(cls, global_shape: Shape, dtype,
|
||||
global_mesh: pxla.Mesh, mesh_axes: MeshAxes,
|
||||
data_callback: Callable[[Sequence[Index]],
|
||||
Sequence[ArrayLike]]):
|
||||
indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
|
||||
local_indices = [indices[d] for d in global_mesh.local_devices]
|
||||
local_arrays = data_callback(local_indices)
|
||||
dbs = pxla.device_put(local_arrays, global_mesh.local_devices)
|
||||
return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)
|
||||
|
||||
@classmethod
|
||||
def from_batched_callback_with_devices(
|
||||
cls, global_shape: Shape, dtype, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes,
|
||||
data_callback: Callable[[Sequence[Tuple[Index, Tuple[Device]]]], Sequence[DeviceArray]]):
|
||||
raise NotImplementedError("Not implemented yet.")
|
||||
data_callback: Callable[[List[Tuple[Index, Tuple[Device, ...]]]],
|
||||
Sequence[DeviceArray]]):
|
||||
indices = get_shard_indices(global_shape, global_mesh, mesh_axes)
|
||||
|
||||
index_to_device: Dict[_HashableIndex, List[Device]] = defaultdict(list)
|
||||
for device in global_mesh.local_devices:
|
||||
h_index = _HashableIndex(indices[device])
|
||||
index_to_device[h_index].append(device)
|
||||
|
||||
cb_inp = [
|
||||
(index.val, tuple(device)) for index, device in index_to_device.items()
|
||||
]
|
||||
dbs = data_callback(cb_inp)
|
||||
return cls(global_shape, dtype, global_mesh, mesh_axes, dbs)
|
||||
|
@ -21,7 +21,7 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import prod
|
||||
from jax._src.util import prod, safe_zip
|
||||
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.maps import Mesh
|
||||
@ -92,6 +92,11 @@ class GSDATest(jtu.JaxTestCase):
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
self.assertListEqual([i.device.id for i in gsda.local_shards],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7])
|
||||
for g, l in safe_zip(gsda.global_shards, gsda.local_shards):
|
||||
self.assertEqual(g.device, l.device)
|
||||
self.assertEqual(g.index, l.index)
|
||||
self.assertEqual(g.replica_id, l.replica_id)
|
||||
self.assertArraysEqual(g.data, l.data)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y_z", ["x", "y", "z"],
|
||||
@ -191,7 +196,57 @@ class GSDATest(jtu.JaxTestCase):
|
||||
self.assertEqual(gsda.local_shards[0].data.shape, expected_shard_shape)
|
||||
replica_ids = [i.replica_id for i in gsda.local_shards]
|
||||
self.assertListEqual(replica_ids, expected_replica_ids)
|
||||
for g, l in safe_zip(gsda.global_shards, gsda.local_shards):
|
||||
self.assertEqual(g.device, l.device)
|
||||
self.assertEqual(g.index, l.index)
|
||||
self.assertEqual(g.replica_id, l.replica_id)
|
||||
self.assertArraysEqual(g.data, l.data)
|
||||
|
||||
def test_gsda_batched_callback(self):
|
||||
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = [('x', 'y')]
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
def cb(indices):
|
||||
self.assertEqual(len(indices), len(global_mesh.local_devices))
|
||||
return [global_input_data[index] for index in indices]
|
||||
|
||||
gsda = GlobalShardedDeviceArray.from_batched_callback(
|
||||
global_input_shape, jnp.float32, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1]])
|
||||
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[2, 3]])
|
||||
self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
|
||||
expected_second_shard_value)
|
||||
|
||||
def test_gsda_batched_callback_with_devices(self):
|
||||
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x']
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
def cb(cb_inp):
|
||||
self.assertLen(cb_inp, 4)
|
||||
dbs = []
|
||||
for inp in cb_inp:
|
||||
index, devices = inp
|
||||
self.assertLen(devices, 2)
|
||||
array = global_input_data[index]
|
||||
dbs.extend([jax.device_put(array, device) for device in devices])
|
||||
return dbs
|
||||
|
||||
gsda = GlobalShardedDeviceArray.from_batched_callback_with_devices(
|
||||
global_input_shape, jnp.float32, global_mesh, mesh_axes, cb)
|
||||
expected_first_shard_value = np.array([[0, 1], [2, 3]])
|
||||
self.assertArraysEqual(gsda.local_shards[0].data.to_py(),
|
||||
expected_first_shard_value)
|
||||
expected_second_shard_value = np.array([[0, 1], [2, 3]])
|
||||
self.assertArraysEqual(gsda.local_shards[1].data.to_py(),
|
||||
expected_second_shard_value)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user