diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5cc510647..87edde806 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,9 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
     bumped to 0.4.0 but this has been rolled back in this release to give users
     of both TensorFlow and JAX more time to migrate to a newer TensorFlow
     release.
+  * `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
   * jax now depends on jaxlib directly. This change was enabled by the CUDA
     plugin switch: there are no longer multiple jaxlib variants. You can install
     a CPU-only jax with `pip install jax`, no extras required.
+
 * Deprecations
   * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be
     removed in a future release.
diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py
index 0060a954f..7157be1e2 100644
--- a/jax/experimental/mesh_utils.py
+++ b/jax/experimental/mesh_utils.py
@@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
 _TPU_V2 = 'TPU v2'
 _TPU_V3 = 'TPU v3'
 _TPU_V4 = 'TPU v4'
+_TPU_V5_LITE = "TPU v5 lite"
 
 # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's
 # famous contiguous mesh trick.
@@ -64,7 +65,8 @@ _TRANSPOSE_TRICKS: dict[
 
 # Physical ordering of core IDs in a tray that creates a ring
 _TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)
-
+_TRAY_2x2_RING_ORDER = (0, 1, 3, 2)
+_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4)
 
 def _tpu_v2_v3_create_device_mesh(
     mesh_shape: Sequence[int],
@@ -94,6 +96,45 @@ def _tpu_v2_v3_create_device_mesh(
   return np.asarray(devices).reshape(mesh_shape)
 
 
+def _vlc_create_device_mesh(
+    mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs
+) -> np.ndarray | None:
+  """Creates rotated pincer device assignment for selected topologies.
+
+  Args:
+    mesh_shape: Logical mesh shape used by the model.
+    devices: TPU devices.
+    **unused_kwargs: ...
+
+  Returns:
+    None or reordered devices reshaped as `mesh_shape`.
+  """
+  max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices)
+  bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1
+  # Our ring re-ordering makes sense only if the passed-in devices are
+  # sequential, which may not always be the case. reversed() changes z-minor to
+  # x-minor.
+  sequential_devices = sorted(
+      devices,
+      key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0)))))
+
+  if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4:  # VLC2x2
+    device_mesh = np.asarray(sequential_devices)
+    device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)]
+    device_mesh = device_mesh.reshape(mesh_shape)
+    return device_mesh
+
+  if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16:  # VLP4x4
+    # Only uses ring order if the whole mesh is a replica group.
+    if max(mesh_shape) == len(devices):
+      device_mesh = np.asarray(sequential_devices)
+      device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)]
+      device_mesh = device_mesh.reshape(mesh_shape)
+      return device_mesh
+
+  return None
+
+
 # Registers functions to create device mesh for specific device kinds. Takes
 # precedence over the more general logic in create_device_mesh(). Handler may
 # return None; in that case, it will fall back to using the default logic.
@@ -103,6 +144,7 @@ device_kind_handler_dict: dict[
 ] = {
     _TPU_V2: _tpu_v2_v3_create_device_mesh,
     _TPU_V3: _tpu_v2_v3_create_device_mesh,
+    _TPU_V5_LITE: _vlc_create_device_mesh,
 }
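
Note (not part of the diff): a minimal sketch of what the new handler computes, traced by hand for review. Only `_TRAY_2x2_RING_ORDER` and `mesh_utils.create_device_mesh` come from the change above; the coordinate tuples are hypothetical stand-ins for TPU v5e device coords.

```python
# Sketch: the 2x2 "rotated pincer" ring order on x-minor-sorted devices.
# The handler sorts devices by reversed coords (z, y, x), i.e. x-minor,
# then permutes them so consecutive entries are physical neighbors.
_TRAY_2x2_RING_ORDER = (0, 1, 3, 2)

# x-minor ordering of a 2x2 tray, as (x, y) pairs:
coords = [(0, 0), (1, 0), (0, 1), (1, 1)]
ring = [coords[i] for i in _TRAY_2x2_RING_ORDER]
print(ring)  # [(0, 0), (1, 0), (1, 1), (0, 1)] -- a closed physical ring

# On a real 4-chip v5e host, create_device_mesh() would dispatch to the new
# handler via device_kind_handler_dict (assuming jax reports 4 v5e devices):
#
#   import jax
#   from jax.experimental import mesh_utils
#   mesh = mesh_utils.create_device_mesh((4,), devices=jax.devices())
```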