rocm_jax/jax/experimental/mesh_utils.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for building a device mesh."""
from __future__ import annotations
import collections
from collections.abc import Sequence
import itertools
import logging
import math
from typing import Any, Callable, Generator, MutableMapping
from jax._src import xla_bridge as xb
import numpy as np
logger = logging.getLogger(__name__)
_TPU_V2 = 'TPU v2'
_TPU_V3 = 'TPU v3'
_TPU_V4 = 'TPU v4'
_TPU_V5_LITE = "TPU v5 lite"
# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's
# famous contiguous mesh trick.
#
# The trick only works for certain topologies and mesh shapes. Shapes that
# differ from those listed only by trivial dims of size 1 are also supported.
_TRANSPOSE_TRICKS: dict[
tuple[int, ...], dict[tuple[int, ...], tuple[int, ...]]
] = {
(2, 2, 1): {
(2, 2): (0, 1, 2),
},
(2, 2, 4): {
(4, 4): (0, 1, 2),
},
(4, 4, 4): {
(16, 4): (0, 2, 1),
},
(4, 8, 8): {
(64, 4): (0, 2, 1),
(4, 64): (0, 2, 1),
},
(8, 8, 8): {
(64, 8): (0, 2, 1),
},
(8, 16, 16): {
(256, 8): (0, 2, 1),
(8, 256): (0, 2, 1),
},
}
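# For illustration (a sketch of how this table is consumed; see
# _transpose_trick() and create_device_mesh() below): on a physical topology
# of shape (4, 4, 4) with logical mesh_shape (16, 4), the entry (0, 2, 1)
# means the physical mesh axes are permuted with that transpose before the
# usual logical-axis assignment runs, roughly:
#
#   contiguous_physical = physical_mesh.transpose(0, 2, 1)
#   device_mesh, _ = _create_device_mesh_for_nd_torus(
#       contiguous_physical, (16, 4))
#
# `contiguous_physical` is an illustrative local name, not part of this module.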
# Physical ordering of core IDs in a tray that creates a ring
_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)
_TRAY_2x2_RING_ORDER = (0, 1, 3, 2)
_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4)
def _tpu_v2_v3_create_device_mesh(
mesh_shape: Sequence[int],
devices: Sequence[Any],
**unused_kwargs,
) -> np.ndarray:
if len(devices) == 8:
logger.info(
'Reordering mesh to physical ring order on single-tray TPU v2/v3.'
)
device_mesh = np.asarray(devices)
device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)]
device_mesh = device_mesh.reshape(mesh_shape)
return device_mesh
elif mesh_shape[-1] == 8:
device_mesh = np.asarray(devices).reshape(mesh_shape)
logger.info(
'Reordering mesh to physical ring order on each TPU v2/v3 tray.'
)
perm = np.array(_TRAY_RING_ORDER)
device_mesh = device_mesh[..., perm]
return device_mesh
else:
# TODO(skye): implement 2D mesh_shape logic here:
# https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187
# (possibly replaces above mesh_shape[-1] == 8 case)
return np.asarray(devices).reshape(mesh_shape)
def _vlc_create_device_mesh(
mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs
) -> np.ndarray | None:
"""Creates rotated pincer device assignment for selected topologies.
Args:
mesh_shape: Logical mesh shape used by the model.
devices: TPU devices.
**unused_kwargs: ...
Returns:
None or reordered devices reshaped as `mesh_shape`.
"""
max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices)
bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1
# Our ring re-ordering makes sense only if the passed-in devices are
# sequential, which may not always be the case. reversed() changes z-minor to
# x-minor.
sequential_devices = sorted(
devices,
key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0)))))
if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2
device_mesh = np.asarray(sequential_devices)
device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)]
device_mesh = device_mesh.reshape(mesh_shape)
return device_mesh
if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4
# Only uses ring order if the whole mesh is a replica group.
if max(mesh_shape) == len(devices):
device_mesh = np.asarray(sequential_devices)
device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)]
device_mesh = device_mesh.reshape(mesh_shape)
return device_mesh
return None
# Registers functions to create a device mesh for specific device kinds. These
# take precedence over the more general logic in create_device_mesh(). A
# handler may return None; in that case, create_device_mesh() falls back to
# the default logic.
device_kind_handler_dict: dict[
str,
Callable[..., np.ndarray | None],
] = {
_TPU_V2: _tpu_v2_v3_create_device_mesh,
_TPU_V3: _tpu_v2_v3_create_device_mesh,
_TPU_V5_LITE: _vlc_create_device_mesh,
}
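# A hypothetical sketch of how an additional handler could be registered; the
# device kind string and handler below are made up for illustration only:
#
#   def _my_device_create_device_mesh(mesh_shape, devices, **kwargs):
#     ...  # return an np.ndarray of shape `mesh_shape`, or None to fall back
#
#   device_kind_handler_dict['My device kind'] = _my_device_create_device_mesh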
def _create_device_mesh_for_nd_torus(
physical_mesh: np.ndarray,
mesh_shape: Sequence[int],
*,
allow_split_physical_axes: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
"""Assigns logical parallelism axes to physical axes of an N-D torus network.
Given logical parallelism axes with sizes in `mesh_shape` and devices in an
N-dimensional torus network represented by `physical_mesh`, maps each logical
axis to one or more physical axes. Prefer to map more-performance-sensitive
logical axes to larger numbers of physical axes to maximize the bandwidth
available to them. Also prefer to assign logical axes to multiple physical
axes of the same size (e.g., a 2D square) rather than multiple physical axes
of different sizes when possible.
If allow_split_physical_axes = False (default), this routine will error out
instead of splitting a physical axis over more than one logical axis (which
would reduce total usable bandwidth).
Let's use a concrete example to explain the concepts and considerations.
Suppose the logical mesh is [data, model], for data and model parallelism
respectively, and that data parallelism is less performance sensitive than
model parallelism. Consider a 3D TPU pod slice of
shape 4x4x16, represented by a physical mesh of shape (4, 4, 16).
A TPU pod slice has equal bandwidth along all axes with wraparound links, but
a 2D plane of size 4x4 may have faster XLA collective implementations than a
non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want
the more performance sensitive `model` axis to be mapped to the 4x4 XY plane.
Args:
physical_mesh: a np.ndarray of devices in the shape of the N-D torus
physical topology.
mesh_shape: shape of the logical mesh (size of the various logical
parallelism axes), with axes ordered by increasing network intensity.
allow_split_physical_axes: If True, we will split physical axes if
necessary to fit the desired mesh shape.
Returns:
An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with
each logical parallelism axis mapped to one or more physical mesh axes.
The axis assignment matrix, which is a 2-d array mapping from
(physical_axis, logical_axis) to the size assigned, with the invariant
np.prod(assignment, axis=1) = physical_mesh_shape, and
np.prod(assignment, axis=0) = mesh_shape.
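
Example (illustrative, continuing the 4x4x16 scenario above): for a physical
mesh of shape (4, 4, 16) and mesh_shape [16, 16], the more network-intensive
logical axis (listed last) is mapped to the 4x4 XY plane and the other
logical axis to the Z axis of size 16, so the returned assignment matrix is

  [[1, 4],
   [1, 4],
   [16, 1]]

and the returned device mesh is equivalent to
physical_mesh.transpose((2, 0, 1)).reshape((16, 16)).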
"""
# Remaining physical axes to be assigned to logical axes.
assignable_physical_mesh = list(physical_mesh.shape)
# Map each logical axis to a subset of physical axes.
assignment: list[tuple[int, ...]] = [() for _ in mesh_shape]
# Assign logical axes from highest network intensity to lowest.
# `mesh_shape` is assumed to be ordered by lowest network intensity first, so
# reverse it first.
for logical_axis_index, logical_axis_size in reversed(
list(enumerate(mesh_shape))
):
# Preferentially map to more physical axes first for higher bandwidth.
for num_axes in range(3, 0, -1):
# Try to assign to any subset of size num_axes. Generate all candidates.
indices_and_axes = itertools.combinations(
enumerate(assignable_physical_mesh), num_axes
)
for elem in indices_and_axes:
c_indices, c_axes = zip(*elem)
# TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives are only
# implemented for square 2D planes. Mapping a logical axis to two physical
# axes might be slower for a non-square 2D plane, e.g., mapping 32 to 4x8
# rather than to a single axis. If XLA 2D collectives support non-square
# planes soon, we can continue to preferentially map to 2D planes in
# general; otherwise, we should treat non-square 2D planes and 1D submeshes
# equally.
if np.prod(c_axes) == logical_axis_size:
assignment[logical_axis_index] = c_indices
# Zero the assigned physical axes.
assignable_physical_mesh = [
0 if i in c_indices else v
for i, v in enumerate(assignable_physical_mesh)
]
break
if assignment[logical_axis_index]:
# We already found an assignment from one candidate above.
break
else:
# If the num_axes for loop did not break, i.e. none of the candidates
# worked, we land here via this for-else construct.
if logical_axis_size > 1:
if not allow_split_physical_axes:
# Although this is now implemented, there are downstream tasks
# counting on this being a NotImplementedError.
raise NotImplementedError(
'Failed to find assignment for logical_axis_index'
f' {logical_axis_index} of size {logical_axis_size} with'
f' remaining assignable mesh {assignable_physical_mesh}. The size'
' of each axis in your logical mesh must be equal to the product'
' of some subset of the physical mesh axis sizes. E.g. logical'
' mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4'
' and 16=4x4. If you want to split physical axes, set'
' allow_split_physical_axes to True.'
)
else:
# We will try finding an assignment, even if that means splitting the
# physical axes, which requires a more sophisticated implementation.
return _create_device_mesh_for_nd_torus_splitting_axes(
physical_mesh, mesh_shape
)
# Flatten the assignment, e.g., [(), (2,), (0, 1)] -> (2, 0, 1).
transpose: list[int] = []
assignment_array = np.ones(
[len(physical_mesh.shape), len(mesh_shape)], dtype=np.int64
)
for i, x in enumerate(assignment):
for y in x:
physical_mesh_axis = int(y)
assignment_array[physical_mesh_axis, i] = physical_mesh.shape[
physical_mesh_axis
]
transpose.append(physical_mesh_axis)
return (
physical_mesh.transpose(transpose).reshape(mesh_shape),
assignment_array,
)
def _create_device_mesh_for_nd_torus_splitting_axes(
physical_mesh: np.ndarray,
mesh_shape: Sequence[int],
) -> tuple[np.ndarray, np.ndarray]:
"""Assigns logical parallelism axes to physical axes of an N-D torus network.
This implementation allows creating meshes that require splitting physical
axes, and thus can produce a logical mesh of any shape, as long as the
number of devices matches, e.g.,
- Creating 2x2x4 from 4x4;
- Creating 2x2x16 from 8x8;
Args:
physical_mesh: a np.ndarray of devices in the shape of the N-D torus
physical topology.
mesh_shape: shape of the logical mesh (size of the various logical
parallelism axes), with axes ordered by increasing network intensity.
Returns:
An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with
each logical parallelism axis mapped to one or more physical mesh axes.
The axis assignment matrix, which is a 2-d array mapping from
(physical_axis, logical_axis) to the size assigned, with the invariant
np.prod(assignment, axis=1) = physical_mesh_shape, and
np.prod(assignment, axis=0) = mesh_shape.
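
For example (illustrative), when splitting an 8x8 physical mesh into a
2x2x16 logical mesh, the returned assignment must satisfy the invariants
above; one such matrix is

  [[1, 1, 8],
   [2, 2, 2]]

i.e. the first physical axis contributes its full size 8 to the last logical
axis, while the second physical axis is split as 2x2x2 across the three
logical axes. The heuristics in _prefer_first_logical_axis_assignment()
choose among such candidate assignments.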
"""
if np.prod(physical_mesh.shape) != np.prod(mesh_shape):
raise ValueError(
'The number of devices in physical mesh'
f' {physical_mesh.shape} does not match the number of devices'
f' in logical mesh {mesh_shape}.'
)
physical_mesh_shape = physical_mesh.shape
logical_mesh_shape = tuple(mesh_shape)
# (Partial) assignment map as a 2-d array [p_axis, l_axis] -> size.
assignment = np.ones(
[len(physical_mesh_shape), len(logical_mesh_shape)], dtype=np.int64
)
# Process logical axes from highest network intensity to lowest.
# `mesh_shape` is assumed to be ordered by lowest network intensity first, so
# reverse it.
for logical_axis, logical_axis_size in reversed(
list(enumerate(logical_mesh_shape))
):
# Go over all the possible assignments for the logical axis, including the
# ones that split multiple physical axes.
best_logical_axis_assignment = None
for logical_axis_assignment in _enumerate_feasible_logical_axis_assignments(
physical_mesh_shape, assignment, logical_axis_size
):
# TODO(rosun): Instead of using heuristics, replace this with a proper
# scoring function reflecting the underlying hardware properties.
if (
best_logical_axis_assignment is None
or _prefer_first_logical_axis_assignment(
logical_axis_assignment,
best_logical_axis_assignment,
physical_mesh_shape=physical_mesh_shape,
assignment=assignment,
)
):
best_logical_axis_assignment = logical_axis_assignment
assignment[:, logical_axis] = best_logical_axis_assignment
# Read out the assignment.
logical_mesh = _generate_logical_mesh(
physical_mesh, logical_mesh_shape, assignment
)
return logical_mesh, assignment
def _get_prime_factors(x: int) -> list[int]:
"""Returns a sorted list of prime factors for the given number."""
assert x > 0
factors = []
for p in range(2, math.isqrt(x) + 2):
while x % p == 0:
factors.append(p)
x //= p
if x == 1:
return factors
else:
return [x] # x is a prime number.
def _enumerate_feasible_logical_axis_assignments(
physical_mesh_shape: Sequence[int],
assignment: np.ndarray,
logical_axis_size: int,
) -> Generator[np.ndarray, None, None]:
"""Yields feasible assignments for a single logical axis.
For a physical mesh of shape [x_1, ..., x_n], where the products of all
previous assignments on the physical axes are [y_1, ..., y_n], this function
yields all possible assignments for the axis as 1-d arrays [z_1, ..., z_n],
so that:
- prod(z_1, ..., z_n) = logical_axis_size
- x_i % (z_i * y_i) = 0
Args:
physical_mesh_shape: Physical mesh shape.
assignment: Existing assignment matrix.
logical_axis_size: Size of the logical axis to assign.
Yields:
All valid assignments for the logical axis. Each assignment is represented
as an integer array of length len(physical_mesh_shape).
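
For instance (illustrative), with physical_mesh_shape (4, 4), an all-ones
(empty) assignment matrix, and logical_axis_size 4, this yields the arrays
[4, 1], [2, 2] and [1, 4], in an unspecified order.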
"""
logical_axis_factors: MutableMapping[int, int] = collections.defaultdict(int)
for factor in _get_prime_factors(logical_axis_size):
logical_axis_factors[factor] += 1
available_physical_mesh_shape = np.array(physical_mesh_shape) // np.prod(
assignment, axis=-1
)
# To enable efficient enumerations, we first index physical axes by their
# prime factors. Since we know the prime factorization of the logical axis
# size, we could simply enumerate by picking the correct count for each
# prime factor.
physical_axes_by_factor: MutableMapping[int, list[int]] = (
collections.defaultdict(list)
)
for physical_axis, physical_axis_size in enumerate(
available_physical_mesh_shape
):
for factor in _get_prime_factors(physical_axis_size):
if factor not in logical_axis_factors:
continue
physical_axes_by_factor[factor].append(physical_axis)
factors = []
assignments_by_factor = []
for factor, multiplicity in logical_axis_factors.items():
factors.append(factor)
assignments_by_factor.append(
set(
itertools.combinations(
physical_axes_by_factor[factor], multiplicity
)
)
)
for axis_assignment in itertools.product(*assignments_by_factor):
result = np.ones([len(physical_mesh_shape)], dtype=np.int64)
for factor_index, per_factor_assignment in enumerate(axis_assignment):
for physical_axis in per_factor_assignment:
result[physical_axis] *= factors[factor_index]
yield result
def _prefer_first_logical_axis_assignment(
x: np.ndarray,
y: np.ndarray,
*,
physical_mesh_shape: Sequence[int],
assignment: np.ndarray,
) -> bool:
"""Returns True if the first axis assignment is preferred over the second.
For now, this is implemented with some very simple heuristics. However,
it is possible to introduce e.g., a value function here based on a more
precise model of the underlying hardware.
TODO(rosun): Use a proxy of network capacity to select the partitions.
Args:
x: Logical axis assignment as [len(physical_mesh_shape)] array.
y: Logical axis assignment as [len(physical_mesh_shape)] array.
physical_mesh_shape: Physical mesh shape.
assignment: Assignment matrix.
Returns:
True if x is preferred over y.
"""
# Prefer occupying complete physical axes. I don't have a good reason for
# this, except that it is compatible with the existing behavior.
#
# E.g., on 4 x 4 x 8, [4, 4, -] will be preferred over [4, -, 4], and then
# over [2, 2, 4].
x_whole_axis_size = np.prod(
[s for i, s in enumerate(x) if s == physical_mesh_shape[i]]
)
y_whole_axis_size = np.prod(
[s for i, s in enumerate(y) if s == physical_mesh_shape[i]]
)
if x_whole_axis_size != y_whole_axis_size:
return x_whole_axis_size > y_whole_axis_size
# Prefer occupying more whole physical axes for better bandwidth.
#
# This is consistent with existing logic, i.e., 2 x 2 is preferred over 4.
x_num_whole_axes = len(
[1 for i, s in enumerate(x) if s == physical_mesh_shape[i] and s > 1]
)
y_num_whole_axes = len(
[1 for i, s in enumerate(y) if s == physical_mesh_shape[i] and s > 1]
)
if x_num_whole_axes != y_num_whole_axes:
return x_num_whole_axes > y_num_whole_axes
# Prefer taking physical axes that are not taken by logical axes of higher
# network intensity. E.g., for a 4 x 4 x 4 mesh, if the previous assignments
# are 1 x 2 x 4 and we want to place a new logical axis of size 2, we will go
# for [2, 1, 1] instead of [1, 2, 1], as the latter choice would tap into
# bandwidth already taken by the higher-intensity axis.
assigned_physical_mesh_shape = np.prod(assignment, axis=-1)
x_non_overlapping_axis_size = np.prod(
[s for i, s in enumerate(x) if assigned_physical_mesh_shape[i] > 1]
)
y_non_overlapping_axis_size = np.prod(
[s for i, s in enumerate(y) if assigned_physical_mesh_shape[i] > 1]
)
if x_non_overlapping_axis_size != y_non_overlapping_axis_size:
return x_non_overlapping_axis_size > y_non_overlapping_axis_size
# Otherwise sort by reverse lexicographical order, to be consistent with
# existing behavior.
return tuple(x) > tuple(y)
def _generate_logical_mesh(
physical_mesh: np.ndarray,
logical_mesh_shape: Sequence[int],
assignment: np.ndarray,
) -> np.ndarray:
"""Compute the logical mesh from assignment map.
Args:
physical_mesh: Physical device mesh.
logical_mesh_shape: Logical mesh shape.
assignment: 2-d assignment matrix shape [physical_dims, logical_dims].
Returns:
Logical mesh reshaped from physical mesh.
"""
physical_indices = np.broadcast_to(
np.expand_dims(
np.arange(len(physical_mesh.shape), dtype=np.int64), axis=-1
),
assignment.shape,
).reshape([-1])
logical_indices = np.broadcast_to(
np.expand_dims(
np.arange(len(logical_mesh_shape), dtype=np.int64), axis=0
),
assignment.shape,
).reshape([-1])
# Axes of the logical mesh are ordered by (physical_axis, logical_axis).
#
# Note that, for each physical_axis, the logical_axis factors are ordered so
# that higher-intensity logical axes land on the inner (minor) dimensions.
#
# E.g., if a dimension size is 12 = 3x4, where 3 is higher intensity and 4
# is lower, we want to reshape so that it becomes 12 = 4x3. Imagine in the
# 1-d case, this will allow more connections between the higher intensity
# axes.
logical_mesh = np.reshape(physical_mesh, assignment.reshape([-1]))
# We will then group by l_axis as this is what is expected from output.
_, _, transpose_axes = zip(
*sorted(
zip(logical_indices, physical_indices, range(len(logical_indices)))
)
)
logical_mesh = np.transpose(logical_mesh, transpose_axes)
# Reshape to add the trivial dimensions back.
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape)
return logical_mesh
def _bounds_from_last_device(last_device) -> Sequence[int]:
"""Gets the bound from the given last device."""
# Must be passed the device at the highest-coordinate corner of the
# relevant mesh, which is a requirement we know is satisfied by the last
# device in jax.devices().
assert hasattr(last_device, 'coords'), 'Only TPU supported'
x, y, z = last_device.coords
return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray:
r"""Rearrange TPU devices in a slice into a physical mesh.
Args:
jax_devices: A list of JAX devices in a TPU slice in process-tiled z, y, x,
core order, e.g. from jax.devices().
Returns:
A np.ndarray of JAX devices with shape [global_x, global_y, global_z]. On
v2 and v3, global_z is instead cores_per_chip (i.e., 2).
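
For example, an 8-device v3 donut (4 chips of 2 cores each, chip coordinates
forming a 2x2x1 grid) is returned as an array of shape (2, 2, 2) indexed by
(x, y, core_on_chip), while a 4x4x4 v4 slice is returned as an array of shape
(4, 4, 4) indexed by (x, y, z).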
"""
device_kind = jax_devices[0].device_kind
device_coords = [d.coords for d in jax_devices]
dims = tuple(d + 1 for d in max(device_coords))
assert len(dims) == 3, dims
if device_kind in (_TPU_V2, _TPU_V3):
cores_per_chip = max(d.core_on_chip for d in jax_devices) + 1
out = np.empty(dims[:2] + (cores_per_chip,), dtype=object)
for coords, d in zip(device_coords, jax_devices):
assert coords[2] == 0, d
out[coords[0], coords[1], d.core_on_chip] = d
else:
out = np.empty(dims, dtype=object)
for coords, d in zip(device_coords, jax_devices):
if d.core_on_chip != 0:
raise AssertionError(
'Creating meshes for TPU >v3 requires one device per chip'
f' ("megacore" mode). Got device id {d.core_on_chip} for a device'
f' of kind {device_kind}: {d}.'
)
out[coords[0], coords[1], coords[2]] = d
return out
# jekbradbury's famous trick for creating contiguous submeshes (where available)
def _transpose_trick(
physical_mesh: np.ndarray, mesh_shape: Sequence[int]
) -> np.ndarray:
mesh_shape = tuple(mesh_shape)
topology = physical_mesh.shape
if topology not in _TRANSPOSE_TRICKS:
raise ValueError(
'create_device_mesh cannot create contiguous submeshes for '
f'physical mesh topology {topology}'
)
mesh_shape_no_trivial_dims: tuple[int, ...] = ()
for dim_size in mesh_shape:
if dim_size != 1:
mesh_shape_no_trivial_dims += (dim_size,)
if mesh_shape_no_trivial_dims not in _TRANSPOSE_TRICKS[topology]:
raise ValueError(
'create_device_mesh cannot create contiguous submeshes for '
f'mesh_shape {mesh_shape} and physical mesh topology {topology}. '
f'Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}'
)
return physical_mesh.transpose(
*_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims]
)
def create_device_mesh(
mesh_shape: Sequence[int],
devices: Sequence[Any] | None = None,
*,
contiguous_submeshes: bool = False,
allow_split_physical_axes: bool = False,
) -> np.ndarray:
"""Creates a performant device mesh for jax.sharding.Mesh.
Args:
mesh_shape: shape of logical mesh, ordered by increasing network-intensity
e.g. [replica, data, mdl] where mdl has the most network communication
requirements.
devices: optionally, the devices to construct a mesh for. Defaults to
jax.devices().
contiguous_submeshes: if True, this function will attempt to create a mesh
where each process's local devices form a contiguous submesh. A ValueError
will be raised if this function can't produce a suitable mesh. This
setting was sometimes necessary before the introduction of jax.Array to
ensure non-ragged local arrays; if using jax.Arrays, it's better to keep
this set to False.
allow_split_physical_axes: If True, we will split physical axes if necessary
to produce the desired device mesh.
Raises:
ValueError: if the number of devices doesn't equal the product of
`mesh_shape`.
Returns:
A np.ndarray of JAX devices with mesh_shape as its shape that can be fed
into jax.sharding.Mesh with good collective performance.
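
Example (illustrative usage; assumes 8 available devices):

  from jax.experimental import mesh_utils
  import jax

  device_mesh = mesh_utils.create_device_mesh((2, 4))
  mesh = jax.sharding.Mesh(device_mesh, ('data', 'model'))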
"""
if devices is None:
devices = xb.devices()
if np.prod(mesh_shape) != len(devices):
raise ValueError(
f'Number of devices {len(devices)} must equal the product '
f'of mesh_shape {mesh_shape}'
)
last_device = devices[-1]
handler = device_kind_handler_dict.get(last_device.device_kind, None)
if handler is not None:
result = handler(
mesh_shape, devices, contiguous_submeshes=contiguous_submeshes
)
if result is not None:
return result
if last_device.platform == 'tpu':
physical_mesh = _get_physical_tpu_mesh(devices)
if contiguous_submeshes:
physical_mesh = _transpose_trick(physical_mesh, mesh_shape)
device_mesh, _ = _create_device_mesh_for_nd_torus(
physical_mesh,
mesh_shape,
allow_split_physical_axes=allow_split_physical_axes,
)
return device_mesh
else:
device_mesh = np.asarray(devices).reshape(mesh_shape)
return device_mesh
def create_hybrid_device_mesh(
mesh_shape: Sequence[int],
dcn_mesh_shape: Sequence[int],
devices: Sequence[Any] | None = None,
*,
process_is_granule: bool = False,
should_sort_granules_by_key: bool = True,
allow_split_physical_axes: bool = False,
) -> np.ndarray:
"""Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
Args:
mesh_shape: shape of the logical mesh for the faster/inner network, ordered
by increasing network intensity, e.g. [replica, data, mdl] where mdl has
the most network communication requirements.
dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in
the same order as mesh_shape.
devices: optionally, the devices to construct a mesh for. Defaults to
jax.devices().
process_is_granule: if True, this function will treat processes as the units
of the slower/outer network. Otherwise it will look for slice_index
attributes on devices and use slices as the units. Enabling this is meant
as a fallback for platforms that don't set slice_index.
should_sort_granules_by_key: Whether device granules should be sorted by the
granule key, either slice or process index, depending on
process_is_granule.
allow_split_physical_axes: If True, we will split physical axes if necessary
to produce the desired device mesh.
Raises:
ValueError: if the number of slices to which the `devices` belong doesn't
equal the product of `dcn_mesh_shape`, or if the number of devices
belonging to any single slice does not equal the product of `mesh_shape`.
Returns:
A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
that can be fed into jax.sharding.Mesh for hybrid parallelism.
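
Example (illustrative usage; assumes 2 slices of 4 devices each):

  from jax.experimental import mesh_utils
  import jax

  # Inner (ICI) mesh within each slice is (1, 4); outer (DCN) mesh across
  # slices is (2, 1), so the global device mesh has shape (2, 4).
  device_mesh = mesh_utils.create_hybrid_device_mesh(
      mesh_shape=(1, 4), dcn_mesh_shape=(2, 1))
  mesh = jax.sharding.Mesh(device_mesh, ('data', 'model'))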
"""
if devices is None:
devices = xb.devices()
attr = 'process_index' if process_is_granule else 'slice_index'
assert hasattr(devices[0], attr)
granule_dict = collections.defaultdict(list)
for dev in devices:
granule_dict[getattr(dev, attr)].append(dev)
granules = (
[granule_dict[key] for key in sorted(granule_dict.keys())]
if should_sort_granules_by_key
else granule_dict.values()
)
if np.prod(dcn_mesh_shape) != len(granules):
raise ValueError(
f'Number of slices {len(granules)} must equal the product of '
f'dcn_mesh_shape {dcn_mesh_shape}'
)
per_granule_meshes = [
create_device_mesh(
mesh_shape,
granule,
allow_split_physical_axes=allow_split_physical_axes,
)
for granule in granules
]
# TODO(jekbradbury): handle non-uniform DCN topologies
granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(
granule_mesh
)
device_mesh = np.block(blocks.tolist())
return device_mesh