mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Optimizations for GDA to make creating GDA faster.
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function, not `_HashableIndex`, which was a class and no longer exists) is 1.5 - 2 times slower than using math. markdaoust@ helped with the math here (going to the office has its own perks :) )
* Get rid of the `_HashableIndex` class and replace it with a function, `_hashed_index`. Dataclasses are extremely slow.
* Only calculate `global_mesh.local_devices` once, even though it's a cached property (but only after Python 3.8).

```
name old time/op new time/op delta
gda_construction_callback_(4, 2)_['x', 'y'] 4.77ms ± 5% 4.74ms ± 5% ~ (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y'] 17.9ms ± 5% 9.0ms ± 2% -49.92% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y'] 11.4ms ± 2% 2.9ms ± 2% -74.52% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None] 34.0ms ±20% 30.5ms ± 2% ~ (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None] 15.9ms ± 2% 7.7ms ± 3% -51.56% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None] 9.39ms ± 3% 1.74ms ± 2% -81.44% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x'] 8.87ms ± 2% 8.92ms ± 3% ~ (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x'] 16.4ms ± 2% 7.7ms ± 1% -52.66% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x'] 9.85ms ± 1% 1.90ms ± 2% -80.68% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y'] 15.9ms ± 3% 16.0ms ± 5% ~ (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y'] 15.8ms ± 3% 7.6ms ± 1% -52.04% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y'] 9.29ms ± 1% 1.78ms ± 1% -80.79% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')] 4.65ms ± 2% 4.62ms ± 3% ~ (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')] 18.6ms ± 3% 9.7ms ± 5% -47.76% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')] 11.8ms ± 4% 3.5ms ± 2% -70.28% (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y'] 8.54ms ± 1% 4.03ms ± 2% -52.84% (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y'] 5.40ms ± 4% 1.10ms ± 1% -79.69% (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y'] 173µs ± 1% 193µs ± 3% +11.63% (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y'] 127µs ± 1% 147µs ± 1% +15.57% (p=0.008 n=5+5)
```

PiperOrigin-RevId: 421623147
This commit is contained in:
parent
2e4687a62e
commit
0532a63261
@ -30,6 +30,8 @@ mesh_shapes_axes = [
|
||||
((256, 8), [("x", "y")]),
|
||||
((128, 8), ["x", "y"]),
|
||||
((4, 2), ["x", "y"]),
|
||||
((16, 4), ["x", "y"]),
|
||||
((16, 4), [("x", "y")]),
|
||||
]
|
||||
|
||||
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
"""Implementation of GlobalDeviceArray."""
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
|
||||
@ -25,6 +24,7 @@ from jax._src.lib import xla_client as xc
|
||||
from jax.interpreters import pxla, xla
|
||||
from jax._src.util import prod, safe_zip
|
||||
from jax._src.api import device_put
|
||||
from jax.tree_util import tree_flatten
|
||||
from jax.interpreters.sharded_jit import PartitionSpec
|
||||
|
||||
Shape = Tuple[int, ...]
|
||||
@ -35,26 +35,19 @@ ArrayLike = Union[np.ndarray, DeviceArray]
|
||||
Index = Tuple[slice, ...]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class _HashableIndex:
  """Wrapper that lets an `Index` (a tuple of slices) be used as a dict key.

  The hash is derived from the `(start, stop, step)` triple of every slice in
  `val`; equality compares the wrapped tuples directly.
  """
  # The wrapped index: one slice per array dimension.
  val: Index

  def __hash__(self):
    return hash(tuple((v.start, v.stop, v.step) for v in self.val))

  def __eq__(self, other):
    # Defer to the other operand instead of raising AttributeError when
    # compared against something that is not a `_HashableIndex`.
    if not isinstance(other, _HashableIndex):
      return NotImplemented
    return self.val == other.val
|
||||
|
||||
def _canonicalize_mesh_axes(mesh_axes):
  """Returns `mesh_axes` as a `PartitionSpec`.

  A `PartitionSpec` instance is passed through unchanged; any other value is
  unpacked into the `PartitionSpec` constructor.
  """
  if isinstance(mesh_axes, PartitionSpec):
    return mesh_axes
  return PartitionSpec(*mesh_axes)
|
||||
|
||||
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Tuple[pxla.Index, ...]:
|
||||
# Import here to avoid cyclic import error when importing gda in pjit.py.
|
||||
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
|
||||
|
||||
if not isinstance(mesh_axes, PartitionSpec):
|
||||
pspec = PartitionSpec(*mesh_axes)
|
||||
else:
|
||||
pspec = mesh_axes
|
||||
pspec = _canonicalize_mesh_axes(mesh_axes)
|
||||
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
|
||||
array_mapping = get_array_mapping(parsed_pspec)
|
||||
# The dtype doesn't matter for creating sharding specs.
|
||||
@ -78,18 +71,37 @@ def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
for d, i in safe_zip(global_mesh.devices.flat, indices)) # type: ignore
|
||||
|
||||
|
||||
def _calc_replica_ids(global_mesh: pxla.Mesh, mesh_axes: MeshAxes):
  """Computes the replica id of every device in `global_mesh`.

  Devices that differ only along mesh axes that are not used by `mesh_axes`
  hold replicas of the same shard. The replica id of a device is the linear
  index of its coordinates along those unused (replicated) axes.

  Args:
    global_mesh: The device mesh; `global_mesh.shape` maps axis names to sizes
      and `global_mesh.devices` is the array of devices.
    mesh_axes: The partitioning of the array over the mesh axes.

  Returns:
    One replica id per entry of `global_mesh.devices.flat`, in the same order.
    This is a list of zeros when no mesh axis is replicated, and a NumPy
    integer array otherwise.
  """
  pspec = _canonicalize_mesh_axes(mesh_axes)
  mesh_values = list(global_mesh.shape.values())
  flattened_pspec, _ = tree_flatten(tuple(pspec))
  # Find all the axes that were replicated.
  # If mesh_axes = (('x', 'y'), None, 'z') and ('x', 'y', 'z') were the mesh's
  # axes, then no axis is replicated since all axes are being used to shard
  # the input.
  replicated_axis = np.isin(list(global_mesh.shape.keys()), flattened_pspec,
                            invert=True)
  # If all elements in replicated_axis are False then the input is fully sharded
  # so replica ids should be all 0s.
  if not replicated_axis.any():
    return [0] * global_mesh.devices.size
  # Get the location (coordinates) of each device in the device mesh. The
  # coordinates come from the device's flat position in `global_mesh.devices`
  # rather than from `device.id`: ids are not guaranteed to be contiguous or to
  # follow the mesh layout, while the flat position always matches the order
  # in which callers iterate over `global_mesh.devices.flat`.
  device_location = np.array(np.unravel_index(
      np.arange(global_mesh.devices.size), mesh_values))
  # Drop all the sharded axes and find the location of coordinates in a linear
  # array.
  return np.ravel_multi_index(device_location[replicated_axis],
                              np.array(mesh_values)[replicated_axis])
|
||||
|
||||
|
||||
def get_shard_indices_replica_ids(
    global_shape: Shape, global_mesh: pxla.Mesh,
    mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
  """Maps every device in `global_mesh` to its shard index and replica id.

  Args:
    global_shape: Shape of the global array.
    global_mesh: The device mesh the array is laid out on.
    mesh_axes: The partitioning of the array over the mesh axes.

  Returns:
    A dict keyed by device whose values are `(index, replica_id)` pairs, with
    one entry per device in `global_mesh.devices.flat`.
  """
  indices = _get_indices(global_shape, global_mesh, mesh_axes)
  replica_ids = _calc_replica_ids(global_mesh, mesh_axes)
  # `indices` and `replica_ids` are both ordered like `global_mesh.devices.flat`.
  return {
      d: (i, r)
      for d, i, r in safe_zip(global_mesh.devices.flat, indices, replica_ids)
  }
|
||||
|
||||
|
||||
def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
|
||||
@ -107,6 +119,9 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
|
||||
return tuple(chunk_size)
|
||||
|
||||
|
||||
_hashed_index = lambda x: hash(tuple((v.start, v.stop) for v in x))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Shard:
|
||||
"""A single data shard of a GlobalDeviceArray.
|
||||
@ -243,13 +258,14 @@ class GlobalDeviceArray:
|
||||
# Optionally precomputed for performance.
|
||||
self._gda_fast_path_args = _gda_fast_path_args
|
||||
self._current_process = xb.process_index()
|
||||
self._local_shards = self._create_local_shards()
|
||||
|
||||
if self._gda_fast_path_args is None:
|
||||
local_devices = self._global_mesh.local_devices
|
||||
self._local_devices = self._global_mesh.local_devices
|
||||
else:
|
||||
local_devices = self._gda_fast_path_args.local_devices
|
||||
assert len(device_buffers) == len(local_devices)
|
||||
self._local_devices = self._gda_fast_path_args.local_devices
|
||||
assert len(device_buffers) == len(self._local_devices)
|
||||
|
||||
self._local_shards = self._create_local_shards()
|
||||
|
||||
ss = get_shard_shape(self._global_shape, self._global_mesh, self._mesh_axes)
|
||||
assert all(db.shape == ss for db in device_buffers), (
|
||||
@ -285,7 +301,7 @@ class GlobalDeviceArray:
|
||||
global_indices_rid = get_shard_indices_replica_ids(
|
||||
self._global_shape, self._global_mesh, self._mesh_axes)
|
||||
local_idx_rid = dict((d, global_indices_rid[d])
|
||||
for d in self._global_mesh.local_devices)
|
||||
for d in self._local_devices)
|
||||
device_to_buffer = dict((db.device(), db) for db in self._device_buffers)
|
||||
return [
|
||||
Shard(d, index, rid, device_to_buffer[d])
|
||||
@ -373,13 +389,17 @@ class GlobalDeviceArray:
|
||||
global_shape, global_mesh, mesh_axes)
|
||||
local_devices = global_mesh.local_devices
|
||||
|
||||
index_to_device: Mapping[_HashableIndex, List[Device]] = defaultdict(list)
|
||||
index_to_device: Dict[int, Tuple[Index, List[Device]]] = {}
|
||||
for device in local_devices:
|
||||
h_index = _HashableIndex(global_indices_rid[device][0])
|
||||
index_to_device[h_index].append(device)
|
||||
index = global_indices_rid[device][0]
|
||||
h_index = _hashed_index(index)
|
||||
if h_index not in index_to_device:
|
||||
index_to_device[h_index] = (index, [device])
|
||||
else:
|
||||
index_to_device[h_index][1].append(device)
|
||||
|
||||
cb_inp = [
|
||||
(index.val, tuple(devices)) for index, devices in index_to_device.items()
|
||||
(index, tuple(devices)) for index, devices in index_to_device.values()
|
||||
]
|
||||
dbs = data_callback(cb_inp)
|
||||
local_idx_rid = dict((d, global_indices_rid[d]) for d in local_devices)
|
||||
|
@ -71,9 +71,9 @@ class GDATest(jtu.JaxTestCase):
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape,
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
@ -121,9 +121,9 @@ class GDATest(jtu.JaxTestCase):
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape,
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
@ -154,9 +154,9 @@ class GDATest(jtu.JaxTestCase):
|
||||
global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape,
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
@ -183,9 +183,9 @@ class GDATest(jtu.JaxTestCase):
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape,
|
||||
global_mesh,
|
||||
mesh_axes, cb)
|
||||
|
||||
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
|
||||
mesh_axes, cb)
|
||||
self.assertEqual(gda.local_shards[0].index, expected_index[0])
|
||||
self.assertArraysEqual(gda.local_data(0),
|
||||
global_input_data[expected_index[0]])
|
||||
|
Loading…
x
Reference in New Issue
Block a user