Add registration handler for TPU v5e in mesh_utils.

PiperOrigin-RevId: 643092629
This commit is contained in:
Yash Katariya 2024-06-13 12:51:43 -07:00 committed by jax authors
parent dd3b0a6981
commit 023bc7856b
2 changed files with 45 additions and 1 deletions

View File

@ -13,9 +13,11 @@ Remember to align the itemized text with the first line of an item within a list
bumped to 0.4.0 but this has been rolled back in this release to give users bumped to 0.4.0 but this has been rolled back in this release to give users
of both TensorFlow and JAX more time to migrate to a newer TensorFlow of both TensorFlow and JAX more time to migrate to a newer TensorFlow
release. release.
* `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
* jax now depends on jaxlib directly. This change was enabled by the CUDA * jax now depends on jaxlib directly. This change was enabled by the CUDA
plugin switch: there are no longer multiple jaxlib variants. You can install plugin switch: there are no longer multiple jaxlib variants. You can install
a CPU-only jax with `pip install jax`, no extras required. a CPU-only jax with `pip install jax`, no extras required.
* Deprecations * Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release. in a future release.

View File

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
_TPU_V2 = 'TPU v2' _TPU_V2 = 'TPU v2'
_TPU_V3 = 'TPU v3' _TPU_V3 = 'TPU v3'
_TPU_V4 = 'TPU v4' _TPU_V4 = 'TPU v4'
_TPU_V5_LITE = "TPU v5 lite"
# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's
# famous contiguous mesh trick. # famous contiguous mesh trick.
@ -64,7 +65,8 @@ _TRANSPOSE_TRICKS: dict[
# Physical ordering of core IDs in a tray that creates a ring # Physical ordering of core IDs in a tray that creates a ring
_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) _TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)
_TRAY_2x2_RING_ORDER = (0, 1, 3, 2)
_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4)
def _tpu_v2_v3_create_device_mesh( def _tpu_v2_v3_create_device_mesh(
mesh_shape: Sequence[int], mesh_shape: Sequence[int],
@ -94,6 +96,45 @@ def _tpu_v2_v3_create_device_mesh(
return np.asarray(devices).reshape(mesh_shape) return np.asarray(devices).reshape(mesh_shape)
def _vlc_create_device_mesh(
mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs
) -> np.ndarray | None:
"""Creates rotated pincer device assignment for selected topologies.
Args:
mesh_shape: Logical mesh shape used by the model.
devices: TPU devices.
**unused_kwargs: ...
Returns:
None or reordered devices reshaped as `mesh_shape`.
"""
max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices)
bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1
# Our ring re-ordering makes sense only if the passed-in devices are
# sequential, which may not always be the case. reversed() changes z-minor to
# x-minor.
sequential_devices = sorted(
devices,
key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0)))))
if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2
device_mesh = np.asarray(sequential_devices)
device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)]
device_mesh = device_mesh.reshape(mesh_shape)
return device_mesh
if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4
# Only uses ring order if the whole mesh is a replica group.
if max(mesh_shape) == len(devices):
device_mesh = np.asarray(sequential_devices)
device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)]
device_mesh = device_mesh.reshape(mesh_shape)
return device_mesh
return None
# Registers functions to create device mesh for specific device kinds. Takes # Registers functions to create device mesh for specific device kinds. Takes
# precedence over the more general logic in create_device_mesh(). Handler may # precedence over the more general logic in create_device_mesh(). Handler may
# return None; in that case, it will fall back to using the default logic. # return None; in that case, it will fall back to using the default logic.
@ -103,6 +144,7 @@ device_kind_handler_dict: dict[
] = { ] = {
_TPU_V2: _tpu_v2_v3_create_device_mesh, _TPU_V2: _tpu_v2_v3_create_device_mesh,
_TPU_V3: _tpu_v2_v3_create_device_mesh, _TPU_V3: _tpu_v2_v3_create_device_mesh,
_TPU_V5_LITE: _vlc_create_device_mesh,
} }