Optimizations for GDA to make creating GDA faster.

* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function, not the `_HashableIndex` class, which no longer exists) is 1.5-2 times slower than using math. markdaoust@ helped with the math here (going to the office has its own perks :) )

* Get rid of `_HashableIndex` class and replace it with a function `_hashed_index`. Dataclass is extremely slow.

* Only calculate `global_mesh.local_devices` once, even though it's a cached property (it is only cached on Python 3.8 and later).

```
name                                           old time/op             new time/op             delta
gda_construction_callback_(4, 2)_['x', 'y']    4.77ms ± 5%             4.74ms ± 5%     ~           (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y']       17.9ms ± 5%              9.0ms ± 2%  -49.92%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y']    11.4ms ± 2%              2.9ms ± 2%  -74.52%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None]        34.0ms ±20%             30.5ms ± 2%     ~             (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None]           15.9ms ± 2%              7.7ms ± 3%  -51.56%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None]        9.39ms ± 3%             1.74ms ± 2%  -81.44%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x']         8.87ms ± 2%             8.92ms ± 3%     ~             (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x']            16.4ms ± 2%              7.7ms ± 1%  -52.66%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x']         9.85ms ± 1%             1.90ms ± 2%  -80.68%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y']         15.9ms ± 3%             16.0ms ± 5%     ~             (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y']            15.8ms ± 3%              7.6ms ± 1%  -52.04%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y']         9.29ms ± 1%             1.78ms ± 1%  -80.79%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')]  4.65ms ± 2%             4.62ms ± 3%     ~            (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')]     18.6ms ± 3%              9.7ms ± 5%  -47.76%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')]  11.8ms ± 4%              3.5ms ± 2%  -70.28%          (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y']       8.54ms ± 1%             4.03ms ± 2%  -52.84%          (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y']    5.40ms ± 4%             1.10ms ± 1%  -79.69%          (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y']          173µs ± 1%              193µs ± 3%  +11.63%          (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y']       127µs ± 1%              147µs ± 1%  +15.57%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 421623147
This commit is contained in:
Yash Katariya 2022-01-13 11:52:44 -08:00 committed by jax authors
parent 2e4687a62e
commit 0532a63261
3 changed files with 66 additions and 44 deletions

View File

@ -30,6 +30,8 @@ mesh_shapes_axes = [
((256, 8), [("x", "y")]),
((128, 8), ["x", "y"]),
((4, 2), ["x", "y"]),
((16, 4), ["x", "y"]),
((16, 4), [("x", "y")]),
]

View File

@ -13,7 +13,6 @@
# limitations under the License.
"""Implementation of GlobalDeviceArray."""
from collections import defaultdict, Counter
import dataclasses
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
@ -25,6 +24,7 @@ from jax._src.lib import xla_client as xc
from jax.interpreters import pxla, xla
from jax._src.util import prod, safe_zip
from jax._src.api import device_put
from jax.tree_util import tree_flatten
from jax.interpreters.sharded_jit import PartitionSpec
Shape = Tuple[int, ...]
@ -35,26 +35,19 @@ ArrayLike = Union[np.ndarray, DeviceArray]
Index = Tuple[slice, ...]
@dataclasses.dataclass(frozen=True)
class _HashableIndex:
  """Hashable wrapper around an index (a tuple of slices).

  Lets an index be used as a dictionary key: the hash is built from each
  slice's `(start, stop, step)` triple, and equality compares the wrapped
  tuples directly.
  """
  val: Index

  def __hash__(self):
    return hash(tuple((v.start, v.stop, v.step) for v in self.val))

  def __eq__(self, other):
    # Defer to Python's fallback comparison for foreign types instead of
    # failing with AttributeError on `other.val`.
    if not isinstance(other, _HashableIndex):
      return NotImplemented
    return self.val == other.val
def _canonicalize_mesh_axes(mesh_axes):
  """Return `mesh_axes` as a `PartitionSpec`.

  A `PartitionSpec` argument is returned unchanged; anything else is unpacked
  into the `PartitionSpec` constructor.
  """
  if isinstance(mesh_axes, PartitionSpec):
    return mesh_axes
  return PartitionSpec(*mesh_axes)
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)
else:
pspec = mesh_axes
pspec = _canonicalize_mesh_axes(mesh_axes)
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
array_mapping = get_array_mapping(parsed_pspec)
# The dtype doesn't matter for creating sharding specs.
@ -78,18 +71,37 @@ def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
for d, i in safe_zip(global_mesh.devices.flat, indices)) # type: ignore
def _calc_replica_ids(global_mesh: pxla.Mesh, mesh_axes: MeshAxes):
  """Compute the replica id of every device in the mesh arithmetically.

  Args:
    global_mesh: The device mesh. Only `shape` (axis name -> axis size) and
      `devices` (array of devices) are read.
    mesh_axes: The partitioning of the array over the mesh axes; canonicalized
      to a `PartitionSpec` before use.

  Returns:
    One replica id per device, in `global_mesh.devices.flat` order: a list of
    zeros when every mesh axis is used for sharding, otherwise an integer
    NumPy array.
  """
  pspec = _canonicalize_mesh_axes(mesh_axes)
  mesh_axis_names = list(global_mesh.shape.keys())
  mesh_axis_sizes = np.array(list(global_mesh.shape.values()))
  num_devices = global_mesh.devices.size
  flattened_pspec, _ = tree_flatten(tuple(pspec))

  # Find all the axes that were replicated.
  # If mesh_axes = (('x', 'y'), None, 'z') and ('x', 'y', 'z') were the mesh's
  # axes, then no axis is replicated since all axes are being used to shard
  # the input.
  replicated_axis = np.isin(mesh_axis_names, flattened_pspec, invert=True)

  # If all elements in replicated_axis are False then the input is fully
  # sharded so replica ids should be all 0s.
  if not replicated_axis.any():
    return [0] * num_devices

  # Get the location (coordinates) of each device in the device mesh. The
  # coordinates come from the device's position in `devices.flat`, not from
  # `d.id`: a device id need not equal its position in the mesh (the mesh may
  # hold a permutation of the devices, or ids larger than the mesh size).
  device_location = np.array(
      np.unravel_index(np.arange(num_devices), mesh_axis_sizes))

  # Drop all the sharded axes and find the location of the remaining
  # coordinates in a linear array.
  return np.ravel_multi_index(device_location[replicated_axis],
                              mesh_axis_sizes[replicated_axis])
def get_shard_indices_replica_ids(
global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
indices = _get_indices(global_shape, global_mesh, mesh_axes)
index_to_replica: Dict[_HashableIndex, int] = Counter()
out = {}
for device, index in safe_zip(global_mesh.devices.flat, indices):
h_index = _HashableIndex(index)
replica_id = index_to_replica[h_index]
index_to_replica[h_index] += 1
out[device] = (index, replica_id)
return out
replica_ids = _calc_replica_ids(global_mesh, mesh_axes)
return dict((d, (i, r))
for d, i, r in safe_zip(global_mesh.devices.flat, indices, replica_ids))
def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
@ -107,6 +119,9 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
return tuple(chunk_size)
def _hashed_index(x):
  """Hash an index (an iterable of slices) by each slice's (start, stop) pair.

  The slice `step` does not take part in the hash.
  """
  bounds = tuple((s.start, s.stop) for s in x)
  return hash(bounds)
@dataclasses.dataclass(frozen=True)
class Shard:
"""A single data shard of a GlobalDeviceArray.
@ -243,13 +258,14 @@ class GlobalDeviceArray:
# Optionally precomputed for performance.
self._gda_fast_path_args = _gda_fast_path_args
self._current_process = xb.process_index()
self._local_shards = self._create_local_shards()
if self._gda_fast_path_args is None:
local_devices = self._global_mesh.local_devices
self._local_devices = self._global_mesh.local_devices
else:
local_devices = self._gda_fast_path_args.local_devices
assert len(device_buffers) == len(local_devices)
self._local_devices = self._gda_fast_path_args.local_devices
assert len(device_buffers) == len(self._local_devices)
self._local_shards = self._create_local_shards()
ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
assert all(db.shape == ss for db in device_buffers), (
@ -285,7 +301,7 @@ class GlobalDeviceArray:
global_indices_rid = get_shard_indices_replica_ids(
self._global_shape, self._global_mesh, self._mesh_axes)
local_idx_rid = dict((d, global_indices_rid[d])
for d in self._global_mesh.local_devices)
for d in self._local_devices)
device_to_buffer = dict((db.device(), db) for db in self._device_buffers)
return [
Shard(d, index, rid, device_to_buffer[d])
@ -373,13 +389,17 @@ class GlobalDeviceArray:
global_shape, global_mesh, mesh_axes)
local_devices = global_mesh.local_devices
index_to_device: Mapping[_HashableIndex, List[Device]] = defaultdict(list)
index_to_device: Dict[int, Tuple[Index, List[Device]]] = {}
for device in local_devices:
h_index = _HashableIndex(global_indices_rid[device][0])
index_to_device[h_index].append(device)
index = global_indices_rid[device][0]
h_index = _hashed_index(index)
if h_index not in index_to_device:
index_to_device[h_index] = (index, [device])
else:
index_to_device[h_index][1].append(device)
cb_inp = [
(index.val, tuple(devices)) for index, devices in index_to_device.items()
(index, tuple(devices)) for index, devices in index_to_device.values()
]
dbs = data_callback(cb_inp)
local_idx_rid = dict((d, global_indices_rid[d]) for d in local_devices)

View File

@ -71,9 +71,9 @@ class GDATest(jtu.JaxTestCase):
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(global_input_shape,
global_mesh,
mesh_axes, cb)
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
@ -121,9 +121,9 @@ class GDATest(jtu.JaxTestCase):
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(global_input_shape,
global_mesh,
mesh_axes, cb)
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
@ -154,9 +154,9 @@ class GDATest(jtu.JaxTestCase):
global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
def cb(index):
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(global_input_shape,
global_mesh,
mesh_axes, cb)
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])
@ -183,9 +183,9 @@ class GDATest(jtu.JaxTestCase):
prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(global_input_shape,
global_mesh,
mesh_axes, cb)
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
mesh_axes, cb)
self.assertEqual(gda.local_shards[0].index, expected_index[0])
self.assertArraysEqual(gda.local_data(0),
global_input_data[expected_index[0]])