# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definitions of Mesh and ResourceEnv."""

from __future__ import annotations

import collections
from collections.abc import Hashable, Sequence
import contextlib
import enum
import functools
import math
import threading
from typing import Any, NamedTuple

import numpy as np

from jax._src import config as jax_config
from jax._src import xla_bridge as xb
from jax._src.util import safe_zip, cache, tuple_delete
from jax._src.lib import xla_client as xc

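# Rebind zip to the length-checking safe_zip; the builtin zip remains available
# as unsafe_zip.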
zip, unsafe_zip = safe_zip, zip

MeshAxisName = Any
ResourceAxisName = Hashable


def show_axes(axes):
  return ", ".join(sorted(f"`{a}`" for a in axes))


class ResourceEnv(NamedTuple):
  physical_mesh: Mesh

  def with_mesh(self, mesh: Mesh):
    overlap = set(mesh.axis_names) & (self.resource_axes - set(self.physical_mesh.axis_names))
    if overlap:
      raise ValueError(f"Cannot update the mesh of the current resource "
                       f"environment. The new mesh shadows already defined axes "
                       f"{show_axes(overlap)}")
    return self._replace(physical_mesh=mesh)

  @property
  def physical_resource_axes(self) -> set[ResourceAxisName]:
    return set(self.physical_mesh.axis_names)

  @property
  def resource_axes(self) -> set[ResourceAxisName]:
    return self.physical_resource_axes

  @property
  def shape(self):
    return self.physical_mesh.shape

  @property
  def local_shape(self):
    return self.physical_mesh.local_mesh.shape

  def __repr__(self):
    mesh_repr = ", ".join(
        f"'{k}': {v}" for k, v in self.physical_mesh.shape.items())
    return f"ResourceEnv(mesh=Mesh({mesh_repr}))"


@cache(max_size=128, trace_context_in_key=False)
def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
  if global_mesh.empty:
    return global_mesh
  is_local_device = np.vectorize(
      lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices)
  subcube_indices = []
  # We take the smallest slice of each dimension that doesn't skip any local device.
  for axis in range(global_mesh.devices.ndim):
    other_axes = tuple_delete(tuple(range(global_mesh.devices.ndim)), axis)
    # NOTE: This re-reduces over many axes multiple times, so we could definitely
    # optimize it, but I hope it won't be a bottleneck anytime soon.
    local_slices = is_local_device.any(other_axes, keepdims=False)
    nonzero_indices = np.flatnonzero(local_slices)
    start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
    subcube_indices.append(slice(start, end + 1))
  subcube_indices_tuple = tuple(subcube_indices)
  # We only end up with all conditions being true if the local devices formed a
  # subcube of the full array. This is because we were biased towards taking a
  # "hull" spanned by the devices, and in case the local devices don't form a
  # subcube that hull will contain non-local devices.
  if not is_local_device[subcube_indices_tuple].all():
    raise ValueError(
        "When passing host local inputs to pjit, devices connected to a single"
        " host must form a contiguous subcube of the global device mesh"
    )
  return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)
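
# Informal example: with two processes and a (4, 2) global device mesh in which
# each process owns a contiguous (2, 2) block of devices, _get_local_mesh returns
# that block as a Mesh with the same axis names; non-contiguous placements hit
# the ValueError above.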


class AxisType(enum.Enum):
  Auto = enum.auto()
  Explicit = enum.auto()
  Manual = enum.auto()

  def __repr__(self):
    return self.name
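
# Example (sketch): one AxisType is supplied per mesh axis when constructing a
# mesh, e.g. Mesh(devices, ('x', 'y'), axis_types=(AxisType.Explicit, AxisType.Auto)).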


def _normalize_axis_types(axis_names, axis_types, name):
  axis_types = ((AxisType.Auto,) * len(axis_names)
                if axis_types is None else axis_types)
  if not isinstance(axis_types, tuple):
    axis_types = (axis_types,)

  if not all(isinstance(a, AxisType) for a in axis_types):
    raise TypeError(
        f"axis_types passed to {name} must be of type `jax.sharding.AxisType`."
        f" Got {axis_types} of type {tuple(type(a) for a in axis_types)}")
  if len(axis_names) != len(axis_types):
    raise ValueError(
        "Number of axis names should match the number of axis_types. Got"
        f" axis_names={axis_names} and axis_types={axis_types}")
  return axis_types


def all_axis_types_match(axis_types, ty: AxisType) -> bool:
  if not axis_types:
    return False
  return all(t == ty for t in axis_types)


def any_axis_types_match(axis_types, ty: AxisType) -> bool:
  if not axis_types:
    return False
  return any(t == ty for t in axis_types)
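
# For example, all_axis_types_match((AxisType.Auto, AxisType.Auto), AxisType.Auto)
# is True, while an empty axis_types sequence always yields False for both helpers.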


class _BaseMesh:
  axis_names: tuple[MeshAxisName, ...]
  shape_tuple: tuple[tuple[str, int], ...]
  _axis_types: tuple[AxisType, ...]

  @property
  def axis_types(self) -> tuple[AxisType, ...]:
    return self._axis_types

  @functools.cached_property
  def _are_all_axes_manual(self) -> bool:
    return all_axis_types_match(self._axis_types, AxisType.Manual)

  @functools.cached_property
  def _are_all_axes_auto(self) -> bool:
    return all_axis_types_match(self._axis_types, AxisType.Auto)

  @functools.cached_property
  def _are_all_axes_explicit(self) -> bool:
    return all_axis_types_match(self._axis_types, AxisType.Explicit)

  @functools.cached_property
  def _are_all_axes_auto_or_manual(self) -> bool:
    if not self._axis_types:
      return False
    return all(t == AxisType.Auto or t == AxisType.Manual
               for t in self._axis_types)

  @functools.cached_property
  def _any_axis_manual(self) -> bool:
    return any_axis_types_match(self._axis_types, AxisType.Manual)

  @functools.cached_property
  def _any_axis_auto(self) -> bool:
    return any_axis_types_match(self._axis_types, AxisType.Auto)

  @functools.cached_property
  def _any_axis_explicit(self) -> bool:
    return any_axis_types_match(self._axis_types, AxisType.Explicit)

  @functools.cached_property
  def auto_axes(self):
    return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types)
                 if t == AxisType.Auto)

  @functools.cached_property
  def explicit_axes(self):
    return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types)
                 if t == AxisType.Explicit)

  @functools.cached_property
  def manual_axes(self):
    return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types)
                 if t == AxisType.Manual)

  @functools.cached_property
  def _axis_types_dict(self):
    if not self.axis_names:
      return {}
    d = collections.defaultdict(list)
    for n, t in safe_zip(self.axis_names, self._axis_types):
      d[t].append(n)
    return {t: tuple(n) for t, n in d.items()}

  @functools.cached_property
  def _name_to_type(self):
    return dict(safe_zip(self.axis_names, self._axis_types))


_mesh_object_dict = {}  # type: ignore
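# The dict above interns Mesh objects: Mesh.__new__ returns a previously
# constructed instance when the axis names, device array and axis types all
# match (see the `key` construction in Mesh.__new__ below).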


class Mesh(_BaseMesh, contextlib.ContextDecorator):
  """Declare the hardware resources available in the scope of this manager.

  See the Distributed arrays and automatic parallelization tutorial
  (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
  and the Explicit sharding tutorial
  (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html).

  Args:
    devices: A NumPy ndarray object containing JAX device objects (as
      obtained e.g. from :py:func:`jax.devices`).
    axis_names: A sequence of resource axis names to be assigned to the
      dimensions of the ``devices`` argument. Its length should match the
      rank of ``devices``.

  Examples:

    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P, NamedSharding
    >>> import numpy as np
    ...
    >>> # Declare a 2D mesh with axes `x` and `y`.
    >>> devices = np.array(jax.devices()).reshape(4, 2)
    >>> mesh = Mesh(devices, ('x', 'y'))
    >>> inp = np.arange(16).reshape(8, 2)
    >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y')))
    >>> out = jax.jit(lambda x: x * 2)(arr)
    >>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))
  """

  devices: np.ndarray
  axis_names: tuple[MeshAxisName, ...]

  def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
              axis_names: str | Sequence[MeshAxisName], *,
              axis_types: tuple[AxisType, ...] | None = None):
    if not isinstance(devices, np.ndarray):
      devices = np.array(devices)
    if isinstance(axis_names, str):
      axis_names = (axis_names,)
    axis_names = tuple(axis_names)
    if any(i is None for i in axis_names):
      raise ValueError(f"Mesh axis names cannot be None. Got: {axis_names}")

    if devices.ndim != len(axis_names):
      raise ValueError(
          "Mesh requires the ndim of its first argument (`devices`) to equal "
          "the length of its second argument (`axis_names`), but got "
          f"devices.ndim == {devices.ndim} and "
          f"len(axis_names) == {len(axis_names)}.")

    axis_types = _normalize_axis_types(axis_names, axis_types, 'Mesh')

    key = (axis_names, devices.shape, tuple(devices.flat), axis_types)
    val = _mesh_object_dict.get(key, None)
    if val is not None:
      return val

    self = super().__new__(cls)
    self.devices = devices.copy()
    self.devices.flags.writeable = False
    self.axis_names = axis_names
    self._axis_types = axis_types
    self._size = math.prod(self.shape.values()) if self.devices.ndim else 0
    _mesh_object_dict[key] = self
    return self

  def __reduce__(self):
    return (type(self), (self.devices, self.axis_names),
            {'axis_types': self._axis_types})

  def __eq__(self, other):
    # This is a performance optimization. Comparing thousands of devices
    # can be expensive.
    if self is other:
      return True
    if not isinstance(other, Mesh):
      return False
    return (self.axis_names == other.axis_names and
            self.devices.shape == other.devices.shape and
            self._axis_types == other._axis_types and
            self._internal_device_list == other._internal_device_list)

  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash(
          (self.axis_names, self._internal_device_list, self.devices.shape,
           self._axis_types))
    return self._hash

  def __setattr__(self, name, value):
    if hasattr(self, name):
      if getattr(self, name) == value:
        # This can happen if two threads race, for example if two threads
        # are trying to hash the same Mesh instance.
        return
      raise RuntimeError(
          f"Cannot reassign attributes ({name}) of immutable mesh objects"
      )
    super().__setattr__(name, value)

  def __enter__(self):
    if jax_config.disallow_mesh_context_manager.value:
      raise RuntimeError("Mesh context manager is disabled.")
    new_env = thread_resources.stack[-1].with_mesh(self)
    thread_resources.stack.append(new_env)
    thread_resources.env = new_env
    jax_config.mesh_context_manager.set_local(
        tuple(t.physical_mesh for t in thread_resources.stack
              if not t.physical_mesh.empty))
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    thread_resources.stack.pop()
    thread_resources.env = thread_resources.stack[-1]
    jax_config.mesh_context_manager.set_local(
        tuple(t.physical_mesh for t in thread_resources.stack
              if not t.physical_mesh.empty))
    return False
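
  # Usage sketch: Mesh is a context manager (and ContextDecorator), so
  # ``with mesh:`` pushes a ResourceEnv onto the thread-local
  # ``thread_resources`` stack via __enter__ above and pops it in __exit__.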

  @functools.cached_property
  def shape(self):
    return collections.OrderedDict(
        (name, size)
        for name, size in safe_zip(self.axis_names, self.devices.shape))

  @functools.cached_property
  def shape_tuple(self):
    return tuple(
        (name, size)
        for name, size in safe_zip(self.axis_names, self.devices.shape))

  @property
  def axis_sizes(self) -> tuple[int, ...]:
    return self.devices.shape

  @property
  def size(self):
    return self._size

  @property
  def empty(self):
    return self.size == 0

  # TODO(emilyaf): Remove this when the `enable_empty_arrays` flag is
  # removed.
  @functools.cached_property
  def is_multi_process(self):
    return self.devices.size != len(self.local_devices)

  @functools.cached_property
  def _process_indices(self):
    return {d.process_index for d in self._flat_devices_tuple}

  @property
  def local_mesh(self):
    return self._local_mesh(xb.process_index())

  def _local_mesh(self, process_index):
    return _get_local_mesh(self, process_index)

  @functools.cached_property
  def device_ids(self):
    assert not self.empty
    return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)

  @functools.cached_property
  def _local_devices_set(self):
    return set(self.local_devices)

  @functools.cached_property
  def _flat_devices_tuple(self):
    return tuple(self.devices.flat)

  @functools.cached_property
  def _internal_device_list(self):
    return xc.DeviceList(self._flat_devices_tuple)

  @functools.cached_property
  def _flat_devices_set(self):
    return set(self.devices.flat)

  def __str__(self):
    mesh_str = ", ".join(f"'{k}': {v}" for k, v in self.shape.items())
    atr = f", axis_types={self._axis_types}"
    return f"Mesh({mesh_str}{atr})"

  @functools.cached_property
  def _repr(self):
    if self.empty:
      return "Mesh(device_ids=[], axis_names=())"
    atr = f", axis_types={self._axis_types}"
    return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"

  def __repr__(self):
    return self._repr

  @functools.cached_property
  def local_devices(self):
    return [d for d in self.devices.flat
            if d.process_index == d.client.process_index()]

  @functools.cached_property
  def abstract_mesh(self):
    return AbstractMesh(self.axis_sizes, self.axis_names,
                        axis_types=self._axis_types)


EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))


class _ThreadResourcesLocalState(threading.local):

  def __init__(self):
    self.stack = [EMPTY_ENV]
    self.env = self.stack[-1]


thread_resources = _ThreadResourcesLocalState()


class AbstractMesh(_BaseMesh):
"""AbstractMesh contains only axis names and axis sizes.
|
|
|
|
|
|
|
|
It does not contain concrete devices compared to `jax.sharding.Mesh`. You
|
|
|
|
should use this as an input to the sharding passed to with_sharding_constraint
|
|
|
|
and mesh passed to shard_map to avoid tracing and lowering cache misses when
|
2024-12-12 17:22:05 +02:00
|
|
|
your mesh shape and axis names stay the same but the devices change.
|
2024-09-20 07:51:48 -07:00
|
|
|
See the description of https://github.com/jax-ml/jax/pull/23022 for more
|
Introduce `jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...])` and allow `with_sharding_constraint` and `shard_map` to accept an abstract mesh as input (`with_sharding_constraint` is via `NamedSharding(abstract_mesh, pspec)`).
**Semantics**
Inside jit, we don't need to talk about concrete devices ever so the semantics stay the same as today i.e. we can lower a NamedSharding with abstract mesh with only mesh axis names and sizes and PartitionSpec. The only restriction is that the number of devices need to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).
Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.
**Why do this?**
There are cases, where you want the change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.
So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # DEVICE MISMATCH ERROR!
```
The same problem exists for `shard_map` since it takes a mesh with concrete devices in it's signature.
**Okay, so how do you fix this?**
As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits is the most important** part here)
The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.
**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**
**What about `shard_map`?**
shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.
```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')
arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))
# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)
@jax.jit
def f(x):
y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))
return y * 2
f(arr_mesh1)
f(arr_mesh2) # tracing and lowering cache hit
```
This is a fully backwards change. So your current code will continue to work as is but you can opt-into this new behavior and get all the benefits!
PiperOrigin-RevId: 662670932
2024-08-13 15:17:30 -07:00
|
|
|
details.
|
|
|
|
"""

  def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...],
               *, axis_types: AxisType | tuple[AxisType, ...] | None = None):
    self.axis_sizes = axis_sizes
    self.axis_names = axis_names
    self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0
    self._axis_types = _normalize_axis_types(
        self.axis_names, axis_types, 'AbstractMesh')
    self._hash = hash((self.axis_sizes, self.axis_names, self._axis_types))

  def __hash__(self):
    return self._hash

  def __eq__(self, other):
    if self is other:
      return True
    if not isinstance(other, AbstractMesh):
      return False
    return (self.axis_sizes == other.axis_sizes and
            self.axis_names == other.axis_names and
            self._axis_types == other._axis_types)
|
2024-08-13 15:17:30 -07:00
|
|
|
|
|
|
|
def __repr__(self):
|
2025-01-28 11:04:05 -08:00
|
|
|
mesh_repr = (", ".join(f"'{n}': {v}" for n, v in self.shape_tuple)
|
|
|
|
if self.shape_tuple else "()")
|
2025-03-12 20:40:59 -07:00
|
|
|
atr = f", axis_types={self._axis_types}"
|
2024-12-17 09:16:38 -08:00
|
|
|
return f"AbstractMesh({mesh_repr}{atr})"
|
2024-08-13 15:17:30 -07:00
|
|
|
|
2025-02-07 12:20:08 -08:00
|
|
|
@property
|
2024-08-13 15:17:30 -07:00
|
|
|
def size(self):
|
2025-02-07 12:20:08 -08:00
|
|
|
return self._size
|
2024-08-13 15:17:30 -07:00
|
|
|
|
|
|
|
@functools.cached_property
|
|
|
|
def shape(self):
|
|
|
|
return collections.OrderedDict(self.shape_tuple)
|
|
|
|
|
Make the signature of `AbstractMesh` `AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types)` instead of `AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types)` so that it is consistent across all mesh APIs: `Mesh`, `AbstractMesh`, and `make_mesh`.
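For illustration, a minimal sketch of the new constructor call under the signature described above (the axis names and sizes here are arbitrary):
```
import jax

# Old: AbstractMesh((('x', 2), ('y', 4)))  # shape_tuple of (name, size) pairs
# New: axis sizes first, then axis names, matching Mesh and make_mesh.
abstract_mesh = jax.sharding.AbstractMesh((2, 4), ('x', 'y'))
```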
PiperOrigin-RevId: 736371111
2025-03-12 21:31:51 -07:00
|
|
|
@functools.cached_property
|
|
|
|
def shape_tuple(self):
|
|
|
|
return tuple(
|
|
|
|
(name, size)
|
|
|
|
for name, size in safe_zip(self.axis_names, self.axis_sizes))
|
|
|
|
|
2024-08-13 15:17:30 -07:00
|
|
|
@property
|
|
|
|
def _internal_device_list(self):
|
|
|
|
return None
|
|
|
|
|
2024-08-13 17:37:27 -07:00
|
|
|
@property
|
|
|
|
def empty(self):
|
|
|
|
return self.size == 0
|
|
|
|
|
2025-01-14 08:03:08 -08:00
|
|
|
@property
|
|
|
|
def abstract_mesh(self):
|
|
|
|
return self
|
|
|
|
|
2025-03-14 11:47:33 -07:00
|
|
|
def update_axis_types(self, name_to_type: dict[MeshAxisName, AxisType]):
|
2025-03-12 20:40:59 -07:00
|
|
|
new_axis_types = tuple(name_to_type[n] if n in name_to_type else a
|
|
|
|
for n, a in zip(self.axis_names, self._axis_types))
|
2025-03-12 21:31:51 -07:00
|
|
|
return AbstractMesh(self.axis_sizes, self.axis_names,
|
|
|
|
axis_types=new_axis_types)
|
2025-02-21 18:48:26 -08:00
|
|
|
|
2024-08-13 17:37:27 -07:00
|
|
|
@property
|
|
|
|
def devices(self):
|
|
|
|
_raise_value_error("devices")
|
|
|
|
|
|
|
|
@property
|
|
|
|
def device_ids(self):
|
|
|
|
_raise_value_error("device_ids")
|
|
|
|
|
|
|
|
@property
|
|
|
|
def is_multi_process(self):
|
|
|
|
_raise_value_error("is_multi_process")
|
|
|
|
|
|
|
|
@property
|
|
|
|
def local_devices(self):
|
|
|
|
_raise_value_error("local_devices")
|
|
|
|
|
|
|
|
@property
|
|
|
|
def local_mesh(self):
|
|
|
|
_raise_value_error("local_mesh")
|
|
|
|
|
2024-12-04 15:57:20 -08:00
|
|
|
def __enter__(self):
|
|
|
|
_raise_value_error("__enter__")
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
|
|
_raise_value_error("__exit__")
|
|
|
|
|
2024-10-06 14:49:43 -07:00
|
|
|
@staticmethod
|
|
|
|
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
|
2025-02-04 16:35:32 -08:00
|
|
|
prev = jax_config.abstract_mesh_context_manager.swap_local(mesh)
|
|
|
|
return prev
|
2024-10-06 14:49:43 -07:00
|
|
|
|
2024-08-13 17:37:27 -07:00
|
|
|
|
|
|
|
# Create this indirection because pytype fails to recognize a property if a
|
|
|
|
# property raises an exception unconditionally. Remove this once that is fixed.
|
|
|
|
def _raise_value_error(name):
|
|
|
|
raise ValueError(f"AbstractMesh does not implement {name}")
|
2024-11-20 13:06:39 -08:00
|
|
|
|
2025-03-12 22:29:08 -07:00
|
|
|
empty_abstract_mesh = AbstractMesh((), ())
|
|
|
|
|
|
|
|
class UseAbstractMeshContextManager:
|
2025-02-25 10:29:39 -08:00
|
|
|
__slots__ = ['mesh', 'prev']
|
|
|
|
|
|
|
|
def __init__(self, mesh: AbstractMesh):
|
2025-03-18 16:56:50 -07:00
|
|
|
if not isinstance(mesh, AbstractMesh):
|
|
|
|
raise ValueError(
|
|
|
|
"Expected mesh of type `jax.sharding.AbstractMesh`. Got type:"
|
|
|
|
f" {type(mesh)}")
|
2025-02-25 10:29:39 -08:00
|
|
|
self.mesh = mesh
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
self.prev = jax_config.abstract_mesh_context_manager.swap_local(self.mesh)
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
|
|
jax_config.abstract_mesh_context_manager.set_local(self.prev)
|
|
|
|
|
2025-03-12 22:29:08 -07:00
|
|
|
use_abstract_mesh = UseAbstractMeshContextManager
|
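A brief usage sketch (an illustration only, not part of this module, assuming it runs after the definitions above): `use_abstract_mesh` swaps the ambient abstract mesh stored in `jax_config`, so `get_abstract_mesh()` returns the supplied mesh inside the `with` block, and the previous value is restored on exit.
```
abs_mesh = AbstractMesh((2,), ('x',))
with use_abstract_mesh(abs_mesh):
  assert get_abstract_mesh() is abs_mesh  # the ambient abstract mesh inside the block
# on exit, the previous ambient abstract mesh is restored
```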
2025-01-28 11:04:05 -08:00
|
|
|
|
2024-12-04 14:03:45 -08:00
|
|
|
def get_abstract_mesh():
|
2025-01-28 11:04:05 -08:00
|
|
|
val = jax_config.abstract_mesh_context_manager.value
|
|
|
|
return empty_abstract_mesh if val is None else val
|
2024-11-25 18:29:52 -08:00
|
|
|
|
2025-03-14 18:57:39 -07:00
|
|
|
def get_concrete_mesh() -> Mesh | None:
|
2024-12-04 14:03:45 -08:00
|
|
|
return jax_config.device_context.value
|