# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utils for building a device mesh.""" from __future__ import annotations import collections from collections.abc import Sequence import itertools import logging import math from typing import Any, Callable, Generator, MutableMapping from jax._src import xla_bridge as xb import numpy as np logger = logging.getLogger(__name__) _TPU_V2 = 'TPU v2' _TPU_V3 = 'TPU v3' _TPU_V4 = 'TPU v4' # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. # # The trick only works for certain topologies and mesh shapes. Trivial dims of # size 1 can be added to the shapes listed, and they are also supported. _TRANSPOSE_TRICKS: dict[ tuple[int, ...], dict[tuple[int, ...], tuple[int, ...]] ] = { (2, 2, 1): { (2, 2): (0, 1, 2), }, (2, 2, 4): { (4, 4): (0, 1, 2), }, (4, 4, 4): { (16, 4): (0, 2, 1), }, (4, 8, 8): { (64, 4): (0, 2, 1), (4, 64): (0, 2, 1), }, (8, 8, 8): { (64, 8): (0, 2, 1), }, (8, 16, 16): { (256, 8): (0, 2, 1), (8, 256): (0, 2, 1), }, } # Physical ordering of core IDs in a tray that creates a ring _TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) def _tpu_v2_v3_create_device_mesh( mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs, ) -> np.ndarray: if len(devices) == 8: logger.info( 'Reordering mesh to physical ring order on single-tray TPU v2/v3.' ) device_mesh = np.asarray(devices) device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)] device_mesh = device_mesh.reshape(mesh_shape) return device_mesh elif mesh_shape[-1] == 8: device_mesh = np.asarray(devices).reshape(mesh_shape) logger.info( 'Reordering mesh to physical ring order on each TPU v2/v3 tray.' ) perm = np.array(_TRAY_RING_ORDER) device_mesh = device_mesh[..., perm] return device_mesh else: # TODO(skye): implement 2D mesh_shape logic here: # https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187 # (possibly replaces above mesh_shape[-1] == 8 case) return np.asarray(devices).reshape(mesh_shape) # Registers functions to create device mesh for specific device kinds. Takes # precedence over the more general logic in create_device_mesh(). Handler may # return None; in that case, it will fall back to using the default logic. device_kind_handler_dict: dict[ str, Callable[..., np.ndarray | None], ] = { _TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh, } def _create_device_mesh_for_nd_torus( physical_mesh: np.ndarray, mesh_shape: Sequence[int], *, allow_split_physical_axes: bool = False, ) -> tuple[np.ndarray, np.ndarray]: """Assigns logical parallelism axes to physical axes of an N-D torus network. Given logical parallelism axes with sizes in `mesh_shape` and devices in an N-dimensional torus network represented by `physical_mesh`, maps each logical axis to one or more physical axes. Prefer to map more-performance-sensitive logical axes to larger numbers of physical axes to maximize the bandwidth available to them. Also prefer to assign logical axes to multiple physical axes of the same size (e.g., a 2D square) rather than multiple physical axes of different sizes when possible. If allow_split_physical_axes = False (default), this routine will error out instead of splitting a physical axis over more than one logical axis (which would reduce total usable bandwidth). Let's use a concrete example to explain the concepts and considerations. As an example, suppose the logical mesh is [data, model], for data and model parallelism respectively. Also suppose that data parallelism is less performance sensitive than model parallelism. Consider a 3D TPU pod slice of shape 4x4x16, represented by a physical mesh of shape (4, 4, 16). A TPU pod slice has equal bandwidth along all axes with wraparound links, but a 2D plane of size 4x4 may have faster XLA collective implementations than a non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want the more performance sensitive `model` axis to be mapped to the 4x4 XY plane. Args: physical_mesh: a np.ndarray of devices in the shape of the N-D torus physical topology. mesh_shape: shape of the logical mesh (size of the various logical parallelism axes), with axes ordered by increasing network intensity. allow_split_physical_axes: If True, we would split physical axes if necessary to fit the desired mesh shape. Returns: An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with each logical parallelism axis mapped to one or more physical mesh axes. The axis assignment matrix, which is a 2-d array mapping from (physical_axis, logical_axis) to the size assigned, with the invariant np.prod(assignment, axis=1) = physical_mesh_shape, and np.prod(assignment, axis=0) = mesh_shape. """ # Remaining physical axes to be assigned to logical axes. assignable_physical_mesh = list(physical_mesh.shape) # Map each logical axis to a subset of physical axes. assignment: list[tuple[int, ...]] = [() for _ in mesh_shape] # Assign logical axes from highest network intensity to lowest. # `mesh_shape` is assumed to ordered by lowest network intensity first, so # reverse it first. for logical_axis_index, logical_axis_size in reversed( list(enumerate(mesh_shape)) ): # Preferentially map to more physical axes first for higher bandwidth. for num_axes in range(3, 0, -1): # Try assign to any subset of size num_axes. Generate all candidates. indices_and_axes = itertools.combinations( enumerate(assignable_physical_mesh), num_axes ) for elem in indices_and_axes: c_indices, c_axes = zip(*elem) # TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only # implemented for square 2D plane. Mapping a physical axis to two # logical axes might be slower for non-square 2D plane, e.g., map 32 to # 4x8 or a single axis. If XLA 2D collectives support non-square plane # soon, we can continue to preferentially map to 2D plane in general, # otherwise, we should treat non-square 2D plane and 1D submesh equally. if np.prod(c_axes) == logical_axis_size: assignment[logical_axis_index] = c_indices # Zero the assigned physical axes. assignable_physical_mesh = [ 0 if i in c_indices else v for i, v in enumerate(assignable_physical_mesh) ] break if assignment[logical_axis_index]: # We already found an assignment from one candidate above. break else: # If the num_axes for loop did not break, i.e. none of the candidates work # goto here with this while-else construct. if logical_axis_size > 1: if not allow_split_physical_axes: # Although this is now implemented, there are downstream tasks # counting on this being a NotImplementedError. raise NotImplementedError( 'Failed to find assignment for logical_axis_index' f' {logical_axis_index} of size {logical_axis_size} with' f' remaining assignable mesh {assignable_physical_mesh}. The size' ' of each axis in your logical mesh must be equal to the product' ' of some subset of the physical mesh axis sizes. E.g. logical' ' mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4' ' and 16=4x4. If you want to split physical axes, set ' ' allow_split_physical_axes to True.' ) else: # We will try finding an assignment, even if that means splitting the # physical axes, which requires a more sophisticated implementation. return _create_device_mesh_for_nd_torus_splitting_axes( physical_mesh, mesh_shape ) # Flatten the assignment, e.g., [(), (2,), (0, 1)] -> (2, 0, 1). transpose: list[int] = [] assignment_array = np.ones( [len(physical_mesh.shape), len(mesh_shape)], dtype=np.int64 ) for i, x in enumerate(assignment): for y in x: physical_mesh_axis = int(y) assignment_array[physical_mesh_axis, i] = physical_mesh.shape[ physical_mesh_axis ] transpose.append(physical_mesh_axis) return ( physical_mesh.transpose(transpose).reshape(mesh_shape), assignment_array, ) def _create_device_mesh_for_nd_torus_splitting_axes( physical_mesh: np.ndarray, mesh_shape: Sequence[int], ) -> tuple[np.ndarray, np.ndarray]: """Assigns logical parallelism axes to physical axes of an N-D torus network. This implementation allows creating meshes that requires splitting physical axes, and thus one could produce logical mesh of any shape, as long as the number of devices matches, e.g., - Creating 2x2x4 from 4x4; - Creating 2x2x16 from 8x8; Args: physical_mesh: a np.ndarray of devices in the shape of the N-D torus physical topology. mesh_shape: shape of the logical mesh (size of the various logical parallelism axes), with axes ordered by increasing network intensity. Returns: An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with each logical parallelism axis mapped to one or more physical mesh axes. The axis assignment matrix, which is a 2-d array mapping from (physical_axis, logical_axis) to the size assigned, with the invariant np.prod(assignment, axis=1) = physical_mesh_shape, and np.prod(assignment, axis=0) = mesh_shape. """ if np.prod(physical_mesh.shape) != np.prod(mesh_shape): raise ValueError( 'The number of devices in physical mesh' f' {physical_mesh.shape} does not match the number of devices' f' in logical mesh {mesh_shape}.' ) physical_mesh_shape = physical_mesh.shape logical_mesh_shape = tuple(mesh_shape) # (Partial) assignment map as an 2-d array [p_axis, l_axis] -> size. assignment = np.ones( [len(physical_mesh_shape), len(logical_mesh_shape)], dtype=np.int64 ) # Process logical axes from highest network intensity to lowest. # `mesh_shape` is assumed to ordered by lowest network intensity first, so # reverse it. for logical_axis, logical_axis_size in reversed( list(enumerate(logical_mesh_shape)) ): # Go over all the possible assignment for the logical axis, including the # one that splits multiple physical axes. best_logical_axis_assignment = None for logical_axis_assignment in _enumerate_feasible_logical_axis_assignments( physical_mesh_shape, assignment, logical_axis_size ): # TODO(rosun): Instead of using heuristics, replace this with a proper # scoring function reflecting the underlying hardware properties. if ( best_logical_axis_assignment is None or _prefer_first_logical_axis_assignment( logical_axis_assignment, best_logical_axis_assignment, physical_mesh_shape=physical_mesh_shape, assignment=assignment, ) ): best_logical_axis_assignment = logical_axis_assignment assignment[:, logical_axis] = best_logical_axis_assignment # Read out the assignment. logical_mesh = _generate_logical_mesh( physical_mesh, logical_mesh_shape, assignment ) return logical_mesh, assignment def _get_prime_factors(x: int) -> list[int]: """Returns a sorted list of prime factors for the given number.""" assert x > 0 factors = [] for p in range(2, math.isqrt(x) + 2): while x % p == 0: factors.append(p) x //= p if x == 1: return factors else: return [x] # x is a prime number. def _enumerate_feasible_logical_axis_assignments( physical_mesh_shape: Sequence[int], assignment: np.ndarray, logical_axis_size: int, ) -> Generator[np.ndarray, None, None]: """Yields feasible assignments for a single logical axis. For a physical mesh of shape [x_1, ..., x_n], and the product of all previous assignments on each physical axes [y_1, ..., y_n], this function yields all possible assignments for the axis as 1-d arrays [z_1, ..., z_n], so that: - prod(z_1, ..., z_n) = logical_axis_size - x_i % (z_i * y_i) = 0 Args: physical_mesh_shape: Physical mesh shape. assignment: Existing assignment matrix. logical_axis_size: Size of the logical axis to assign. Yields: All valid assignments for the logical axis. Each assignment is represented as an integer array of length len(physical_mesh_shape). """ logical_axis_factors: MutableMapping[int, int] = collections.defaultdict(int) for factor in _get_prime_factors(logical_axis_size): logical_axis_factors[factor] += 1 available_physical_mesh_shape = np.array(physical_mesh_shape) // np.prod( assignment, axis=-1 ) # To enable efficient enumerations, we first index physical axes by their # prime factors. Since we know the prime factorization of the logical axis # size, we could simply enumerate by picking the correct count for each # prime factor. physical_axes_by_factor: MutableMapping[int, list[int]] = ( collections.defaultdict(list) ) for physical_axis, physical_axis_size in enumerate( available_physical_mesh_shape ): for factor in _get_prime_factors(physical_axis_size): if factor not in logical_axis_factors: continue physical_axes_by_factor[factor].append(physical_axis) factors = [] assignments_by_factor = [] for factor, multiplicity in logical_axis_factors.items(): factors.append(factor) assignments_by_factor.append( set( itertools.combinations( physical_axes_by_factor[factor], multiplicity ) ) ) for axis_assignment in itertools.product(*assignments_by_factor): result = np.ones([len(physical_mesh_shape)], dtype=np.int64) for factor_index, per_factor_assignment in enumerate(axis_assignment): for physical_axis in per_factor_assignment: result[physical_axis] *= factors[factor_index] yield result def _prefer_first_logical_axis_assignment( x: np.ndarray, y: np.ndarray, *, physical_mesh_shape: Sequence[int], assignment: np.ndarray, ) -> bool: """Returns True if the first axis assignment is preferred over the second. For now, this is implemented with some very simple heuristics. However, it is possible to introduce e.g., a value function here based on a more precise model of the underlying hardware. TODO(rosun): Use a proxy of network capacity to select the partitions. Args: x: Logical axis assignment as [len(physical_mesh_shape)] array. y: Logical axis assignment as [len(physical_mesh_shape)] array. physical_mesh_shape: Physical mesh shape. assignment: Assignment matrix. Returns: True if x is preferred over y. """ # Prefer occupying complete physical axes. I don't have a good reason for # this, except that it is compatible with the existing behavior. # # E.g., on 4 x 4 x 8, [4, 4, -] will be preferred over [4, -, 4], and then # over [2, 2, 4]. x_whole_axis_size = np.prod( [s for i, s in enumerate(x) if s == physical_mesh_shape[i]] ) y_whole_axis_size = np.prod( [s for i, s in enumerate(y) if s == physical_mesh_shape[i]] ) if x_whole_axis_size != y_whole_axis_size: return x_whole_axis_size > y_whole_axis_size # Prefer occupying more whole physical axes for better bandwidth. # # This is consistent with existing logic, i.e., 2 x 2 is preferred over 4. x_num_whole_axes = len( [1 for i, s in enumerate(x) if s == physical_mesh_shape[i] and s > 1] ) y_num_whole_axes = len( [1 for i, s in enumerate(y) if s == physical_mesh_shape[i] and s > 1] ) if x_num_whole_axes != y_num_whole_axes: return x_num_whole_axes > y_num_whole_axes # Prefer taking physical axes that are not taken by logical axes of higher # network intensity. E.g., for a 4 x 4 x 4, suppose that the previous # assignments are 1 x 2 x 4, and we want to place a new logical axis of size # 2, we will go for [2, 1, 1] instead of [1, 2, 1], as the latter choice will # tap into bandwidth already taken by the higher intensity axis. assigned_physical_mesh_shape = np.prod(assignment, axis=-1) x_non_overlapping_axis_size = np.prod( [s for i, s in enumerate(x) if assigned_physical_mesh_shape[i] > 1] ) y_non_overlapping_axis_size = np.prod( [s for i, s in enumerate(y) if assigned_physical_mesh_shape[i] > 1] ) if x_non_overlapping_axis_size != y_non_overlapping_axis_size: return x_non_overlapping_axis_size > y_non_overlapping_axis_size # Otherwise sort by reverse lexical graphical order, to be consistent with # existing behavior. return tuple(x) > tuple(y) def _generate_logical_mesh( physical_mesh: np.ndarray, logical_mesh_shape: Sequence[int], assignment: np.ndarray, ) -> np.ndarray: """Compute the logical mesh from assignment map. Args: physical_mesh: Physical device mesh. logical_mesh_shape: Logical mesh shape. assignment: 2-d assignment matrix shape [physical_dims, logical_dims]. Returns: Logical mesh reshaped from physical mesh. """ physical_indices = np.broadcast_to( np.expand_dims( np.arange(len(physical_mesh.shape), dtype=np.int64), axis=-1 ), assignment.shape, ).reshape([-1]) logical_indices = np.broadcast_to( np.expand_dims( np.arange(len(logical_mesh_shape), dtype=np.int64), axis=0 ), assignment.shape, ).reshape([-1]) # Axes of logical mesh is ordered by (physical_axis, logical_axis). # # Note that we sort for each physical_axis the logical_axis, so that higher # intensity logical axes are replicated at inner (minor) dimensions. # # E.g., if a dimension size is 12 = 3x4, where 3 is higher intensity and 4 # is lower, we want to reshape so that it becomes 12 = 4x3. Imagine in the # 1-d case, this will allow more connections between the higher intensity # axes. logical_mesh = np.reshape(physical_mesh, assignment.reshape([-1])) # We will then group by l_axis as this is what is expected from output. _, _, transpose_axes = zip( *sorted( zip(logical_indices, physical_indices, range(len(logical_indices))) ) ) logical_mesh = np.transpose(logical_mesh, transpose_axes) # Reshape to add the trivial dimensions back. logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) return logical_mesh def _bounds_from_last_device(last_device) -> Sequence[int]: """Gets the bound from the given last device.""" # Must be passed the device at the highest-coordinate corner of the # relevant mesh, which is a requirement we know is satisfied by the last # device in jax.devices(). assert hasattr(last_device, 'coords'), 'Only TPU supported' x, y, z = last_device.coords return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: r"""Rearrange TPU devices in a slice into a physical mesh. Args: jax_devices: A list of JAX devices in a TPU slice in process-tiled z, y, x, core order, e.g. from jax.devices(). Returns: A np.ndarray of JAX devices with shape [global_x, global_y, global_z]. On v2 and v3, global_z is instead cores_per_chip (i.e., 2). """ device_kind = jax_devices[0].device_kind device_coords = [d.coords for d in jax_devices] dims = tuple(d + 1 for d in max(device_coords)) assert len(dims) == 3, dims if device_kind in (_TPU_V2, _TPU_V3): cores_per_chip = max(d.core_on_chip for d in jax_devices) + 1 out = np.empty(dims[:2] + (cores_per_chip,), dtype=object) for coords, d in zip(device_coords, jax_devices): assert coords[2] == 0, d out[coords[0], coords[1], d.core_on_chip] = d else: out = np.empty(dims, dtype=object) for coords, d in zip(device_coords, jax_devices): if d.core_on_chip != 0: raise AssertionError( 'Creating meshes for TPU >v3 requires one device per chip' f' ("megacore" mode). Got device id {d.core_on_chip} for a device' f' of kind {device_kind}: {d}.' ) out[coords[0], coords[1], coords[2]] = d return out # jekbradbury's famous trick for creating contiguous submeshes (where available) def _transpose_trick( physical_mesh: np.ndarray, mesh_shape: Sequence[int] ) -> np.ndarray: mesh_shape = tuple(mesh_shape) topology = physical_mesh.shape if topology not in _TRANSPOSE_TRICKS: raise ValueError( 'create_device_mesh cannot create contiguous submeshes for ' f'physical mesh topology {topology}' ) mesh_shape_no_trivial_dims: tuple[int, ...] = () for dim_size in mesh_shape: if dim_size != 1: mesh_shape_no_trivial_dims += (dim_size,) if mesh_shape_no_trivial_dims not in _TRANSPOSE_TRICKS[topology]: raise ValueError( 'create_device_mesh cannot create contiguous submeshes for ' f'mesh_shape {mesh_shape} and physical mesh topology {topology}. ' f'Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}' ) return physical_mesh.transpose( *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] ) def create_device_mesh( mesh_shape: Sequence[int], devices: Sequence[Any] | None = None, *, contiguous_submeshes: bool = False, allow_split_physical_axes: bool = False, ) -> np.ndarray: """Creates a performant device mesh for jax.sharding.Mesh. Args: mesh_shape: shape of logical mesh, ordered by increasing network-intensity e.g. [replica, data, mdl] where mdl has the most network communication requirements. devices: optionally, the devices to construct a mesh for. Defaults to jax.devices(). contiguous_submeshes: if True, this function will attempt to create a mesh where each process's local devices form a contiguous submesh. A ValueError will be raised if this function can't produce a suitable mesh. This setting was sometimes necessary before the introduction of jax.Array to ensure non-ragged local arrays; if using jax.Arrays, it's better to keep this set to False. allow_split_physical_axes: If True, we will split physical axes if necessary to produce the desired device mesh. Raises: ValueError: if the number of devices doesn't equal the product of `mesh_shape`. Returns: A np.ndarray of JAX devices with mesh_shape as its shape that can be fed into jax.sharding.Mesh with good collective performance. """ if devices is None: devices = xb.devices() if np.prod(mesh_shape) != len(devices): raise ValueError( f'Number of devices {len(devices)} must equal the product ' f'of mesh_shape {mesh_shape}' ) last_device = devices[-1] handler = device_kind_handler_dict.get(last_device.device_kind, None) if handler is not None: result = handler( mesh_shape, devices, contiguous_submeshes=contiguous_submeshes ) if result is not None: return result if last_device.platform == 'tpu': physical_mesh = _get_physical_tpu_mesh(devices) if contiguous_submeshes: physical_mesh = _transpose_trick(physical_mesh, mesh_shape) device_mesh, _ = _create_device_mesh_for_nd_torus( physical_mesh, mesh_shape, allow_split_physical_axes=allow_split_physical_axes, ) return device_mesh else: device_mesh = np.asarray(devices).reshape(mesh_shape) return device_mesh def create_hybrid_device_mesh( mesh_shape: Sequence[int], dcn_mesh_shape: Sequence[int], devices: Sequence[Any] | None = None, *, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = False, ) -> np.ndarray: """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. Args: mesh_shape: shape of the logical mesh for the faster/inner network, ordered by increasing network intensity, e.g. [replica, data, mdl] where mdl has the most network communication requirements. dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in the same order as mesh_shape. devices: optionally, the devices to construct a mesh for. Defaults to jax.devices(). process_is_granule: if True, this function will treat processes as the units of the slower/outer network. Otherwise it will look for slice_index attributes on devices and use slices as the units. Enabling this is meant as a fallback for platforms that don't set slice_index. should_sort_granules_by_key: Whether device granules should be sorted by the granule key, either slice or process index, depending on process_is_granule. allow_split_physical_axes: If True, we will split physical axes if necessary to produce the desired device mesh. Raises: ValueError: if the number of slices to which the `devices` belong doesn't equal the product of `dcn_mesh_shape`, or if the number of devices belonging to any single slice does not equal the product of `mesh_shape`. Returns: A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape that can be fed into jax.sharding.Mesh for hybrid parallelism. """ if devices is None: devices = xb.devices() attr = 'process_index' if process_is_granule else 'slice_index' assert hasattr(devices[0], attr) granule_dict = collections.defaultdict(list) for dev in devices: granule_dict[getattr(dev, attr)].append(dev) granules = ( [granule_dict[key] for key in sorted(granule_dict.keys())] if should_sort_granules_by_key else granule_dict.values() ) if np.prod(dcn_mesh_shape) != len(granules): raise ValueError( f'Number of slices {len(granules)} must equal the product of ' f'dcn_mesh_shape {dcn_mesh_shape}' ) per_granule_meshes = [ create_device_mesh( mesh_shape, granule, allow_split_physical_axes=allow_split_physical_axes, ) for granule in granules ] # TODO(jekbradbury): handle non-uniform DCN topologies granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])( granule_mesh ) device_mesh = np.block(blocks.tolist()) return device_mesh