# rocm_jax/jax/_src/named_sharding.py
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import collections
import dataclasses
import functools
from typing import Any, Union
from jax._src import config
from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
from jax._src import sharding as JSharding
from jax._src import xla_bridge as xb
import numpy as np
Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]
class AUTO:
def __init__(self, mesh: mesh_lib.Mesh):
self.mesh = mesh
def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding:
dim_shardings = [SdyDimSharding(axes=[], is_closed=False)
for _ in range(ndim)]
return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
class UnspecifiedValue:
def __repr__(self):
return "UnspecifiedValue"
UNSPECIFIED = UnspecifiedValue()
MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). In that case, the
order of entries in the mapping determines a major-to-minor order on the mesh axes,
which governs how chunks of the value along the repeated dimension are assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping were {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = collections.OrderedDict[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
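# A minimal, illustrative sketch (not part of the library; doctest-style) of
# the ordering described in the string above, assuming a hypothetical 2x3 mesh
# whose devices are numbered 0..5:
#
#   >>> import numpy as np
#   >>> mesh_devices = np.arange(6).reshape(2, 3)   # mesh {'x': 2, 'y': 3}
#   >>> # Mapping {'x': 1, 'y': 1}: 'x' is major and 'y' is minor, so the flat
#   >>> # list of chunks lines up with the flattened mesh as-is.
#   >>> mesh_devices.reshape(-1).tolist()
#   [0, 1, 2, 3, 4, 5]
#   >>> # Mapping {'y': 1, 'x': 1}: 'y' is major, so the device ndarray must be
#   >>> # transposed before flattening and assignment.
#   >>> mesh_devices.T.reshape(-1).tolist()
#   [0, 3, 1, 4, 2, 5]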
@use_cpp_class(xc.NamedSharding)
class NamedSharding(JSharding.Sharding):
r"""A :class:`NamedSharding` expresses sharding using named axes.
A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
:class:`PartitionSpec` which describes how to shard an array across that
mesh.
A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.
A :class:`PartitionSpec` is a tuple whose elements can be ``None``,
a mesh axis, or a tuple of mesh axes. Each element describes how an input
dimension is partitioned across zero or more mesh dimensions. For example,
``PartitionSpec('x', 'y')`` says that the first dimension of data
is sharded across the ``x`` axis of the mesh, and the second dimension is
sharded across the ``y`` axis of the mesh.
The Distributed arrays and automatic parallelization
(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
tutorial has more details and diagrams that explain how
:class:`Mesh` and :class:`PartitionSpec` are used.
Args:
mesh: A :class:`jax.sharding.Mesh` object.
spec: A :class:`jax.sharding.PartitionSpec` object.
Examples:
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
"""
mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh
spec: PartitionSpec
_memory_kind: str | None
_manual_axes: frozenset[MeshAxisName]
_logical_device_ids: tuple[int, ...] | None
@use_cpp_method()
def __init__(
self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *,
memory_kind: str | None = None, _manual_axes=frozenset(),
_logical_device_ids=None):
self.mesh = mesh
self.spec = spec
self._memory_kind = memory_kind
self._manual_axes = _manual_axes
self._logical_device_ids = _logical_device_ids
check_pspec(self.mesh, self.spec, self._manual_axes)
def __repr__(self):
mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
ldi = ('' if self._logical_device_ids is None else
f', logical_device_ids={self._logical_device_ids}')
mesh_repr = f"{str(self.mesh)}"
return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'
def __reduce__(self):
return (type(self), (self.mesh, self.spec),
{'memory_kind': self.memory_kind,
'_manual_axes': self._manual_axes,
'_logical_device_ids': self._logical_device_ids})
@property
def memory_kind(self) -> str | None:
return self._memory_kind
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash(
(self.mesh, self.memory_kind, self.spec, self._manual_axes,
self._logical_device_ids))
return self._hash
def __eq__(self, other):
if not isinstance(other, NamedSharding):
return False
if self is other:
return True
if (self.spec != other.spec
or self.memory_kind != other.memory_kind
or self._manual_axes != other._manual_axes
or self._logical_device_ids != other._logical_device_ids):
return False
return self.mesh is other.mesh or self.mesh == other.mesh
def check_compatible_aval(self, aval_shape: Shape) -> None:
if len(aval_shape) < len(self.spec):
extra_msg = (' For scalars the PartitionSpec should be P()'
if len(aval_shape) == 0 else '')
raise ValueError(
f"Sharding {self} is only valid for values of rank at least "
f"{len(self.spec)}, but was applied to a value of rank "
f"{len(aval_shape)}.{extra_msg}")
@property
def num_devices(self) -> int:
return self.mesh.size
@property
def device_set(self) -> set[Device]:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError(
'device_set is not implemented for `jax.sharding.AbstractMesh`.')
return self.mesh._flat_devices_set
@property
def _device_assignment(self) -> XLADeviceAssignment:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError('_device_assignment is not implemented for'
' `jax.sharding.AbstractMesh`.')
return self.mesh._flat_devices_tuple
@property
def is_fully_addressable(self) -> bool:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError('is_fully_addressable is not implemented for '
'`jax.sharding.AbstractMesh`.')
# Speed up `is_fully_addressable` since there is a high chance that the
# mesh across multiple NamedSharding objects will be the same.
if config.enable_empty_arrays.value:
client = self._internal_device_list[0].client
return (len(self.mesh._process_indices) == 1 and
next(iter(self.mesh._process_indices)) ==
xb.process_index(client))
return not self.mesh.is_multi_process
@property
def _is_concrete(self) -> bool:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
return False
return True
@property
def addressable_devices(self) -> set[Device]:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError('addressable_devices is not implemented for '
'`jax.sharding.AbstractMesh`.')
# Override addressable devices because there is a high chance that the mesh
# across multiple NamedSharding objects will be the same.
return self.mesh._local_devices_set
@functools.cached_property
def is_fully_replicated(self) -> bool:
if self.mesh.size == 1:
return True
array_mapping = get_array_mapping(self.spec)
mesh_shape = self.mesh.shape
num_partitions = 1
for name in array_mapping: # type: ignore
num_partitions *= mesh_shape[name]
return num_partitions == 1
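# For example, NamedSharding(mesh, PartitionSpec()) and
# NamedSharding(mesh, PartitionSpec(None, None)) are fully replicated on any
# mesh: no mesh axis appears in the spec, so `num_partitions` above stays 1.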
def with_memory_kind(self, kind: str) -> NamedSharding:
return NamedSharding(self.mesh, self.spec, memory_kind=kind)
def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
if not isinstance(spec, PartitionSpec):
spec = PartitionSpec(*spec)
return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
dim_shardings = [SdyDimSharding(axes=[], is_closed=True)
for _ in range(num_dimensions)]
for i, dim_spec in enumerate(self.spec):
if dim_spec is PartitionSpec.UNCONSTRAINED:
dim_shardings[i].is_closed = False
elif dim_spec is None:
# Already empty and closed sharding.
pass
else:
dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,)
dim_shardings[i].axes = dim_spec
return SdyArraySharding(self.mesh.shape_tuple, dim_shardings,
self._logical_device_ids)
def get_array_mapping(
axis_resources: PartitionSpec | AUTO | UnspecifiedValue
) -> ArrayMappingOrAutoOrUnspecified:
if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
return axis_resources
d = collections.OrderedDict()
for i, axes in enumerate(axis_resources):
if axes is None or axes is PartitionSpec.UNCONSTRAINED:
continue
axes = axes if isinstance(axes, tuple) else (axes,)
for axis in axes:
d[axis] = i
return d
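# Illustrative only (doctest-style; assumes mesh axes named 'x', 'y' and 'z'):
# `get_array_mapping` maps each mesh axis in the spec to the positional
# dimension it shards, skipping None and UNCONSTRAINED entries.
#
#   >>> dict(get_array_mapping(PartitionSpec('x', None, ('y', 'z'))))
#   {'x': 0, 'y': 2, 'z': 2}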
@dataclasses.dataclass
class SdyDimSharding:
axes: Sequence[str]
is_closed: bool
priority: int | None = None
def build(self) -> sdy.DimensionShardingAttr:
return sdy.DimensionShardingAttr.get(
[sdy.AxisRefAttr.get(axis) for axis in self.axes],
is_closed=self.is_closed,
priority=self.priority)
def __repr__(self):
return f'SdyDimSharding({self._custom_repr()})'
def _custom_repr(self):
axes_repr = ', '.join(f"'{a}'" for a in self.axes)
open_repr = ''
if not self.is_closed:
open_repr = ', ?' if self.axes else '?'
priority_repr = '' if self.priority is None else f'p{self.priority}'
return f'{{{axes_repr}{open_repr}}}{priority_repr}'
@dataclasses.dataclass
class SdyArraySharding:
mesh_shape: tuple[tuple[str, int], ...] | None
dimension_shardings: Sequence[SdyDimSharding]
logical_device_ids: tuple[int, ...] | None = None
replicated_axes: tuple[str, ...] = ()
def build(self) -> sdy.TensorShardingAttr:
if self.mesh_shape is None:
mesh_attr = sdy.MeshAttr.get([])
else:
ldi = ([] if self.logical_device_ids is None else
list(self.logical_device_ids))
mesh_attr = sdy.MeshAttr.get(
[sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape],
ldi)
return sdy.TensorShardingAttr.get(
mesh_attr,
[dim_sharding.build() for dim_sharding in self.dimension_shardings],
replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes])
def __repr__(self):
dim_sharding_repr = ', '.join(
d._custom_repr() for d in self.dimension_shardings)
device_id_repr = (f', device_ids={self.logical_device_ids}'
if self.logical_device_ids is not None else '')
rar = (f', replicated_axes={self.replicated_axes}'
if self.replicated_axes else '')
return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"
# TODO(yashkatariya): Remove this after jax 0.5.2 release
class ParsedPartitionSpec:
__slots__ = ('_user_spec', 'partitions')
_user_spec: PartitionSpec | None
partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...]
def __init__(self, user_spec, partitions):
self._user_spec = user_spec
assert None not in partitions, partitions
self.partitions = tuple(partitions)
def get_partition_spec(self) -> PartitionSpec:
if isinstance(self._user_spec, PartitionSpec):
return self._user_spec
else:
return get_single_pspec(self)
def insert_axis_partitions(self, dim, val):
parts = self.partitions
too_short = dim - len(parts)
if too_short > 0:
parts += ((),) * too_short
new_partitions = tuple_insert(parts, dim, val)
return ParsedPartitionSpec(None, new_partitions)
@classmethod
def from_user_input(
cls,
entry: PartitionSpec | None,
arg_name: str,
allow_unconstrained_dims: bool = False,
) -> ParsedPartitionSpec:
if entry is None:
return cls(entry, ())
if not isinstance(entry, PartitionSpec):
raise TypeError(f"{arg_name} are expected to be "
f"PartitionSpec instances or None, but got {entry}")
axis_specs = []
for axis_spec in entry:
if axis_spec is None:
axis_spec = ()
elif isinstance(axis_spec, (list, tuple)):
axis_spec = tuple(axis_spec)
elif axis_spec is PartitionSpec.UNCONSTRAINED:
if not allow_unconstrained_dims:
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
axis_spec = PartitionSpec.UNCONSTRAINED
else:
axis_spec = (axis_spec,)
axis_specs.append(axis_spec)
new_entry = PartitionSpec(
*[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry])
return cls(new_entry, axis_specs)
def __hash__(self):
return hash(self.partitions)
def __eq__(self, other):
if not isinstance(other, ParsedPartitionSpec):
return False
return self.partitions == other.partitions
def __len__(self):
return len(self.partitions)
def __getitem__(self, i):
return self.partitions[i]
def __iter__(self):
return iter(self.partitions)
def __repr__(self):
return f"ParsedPartitionSpec(partitions={self.partitions})"
@cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
self, num_dimensions: int) -> xc.HloSharding:
mesh_shape = self.mesh.shape
array_mapping = get_array_mapping(self.spec)
mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}
special_axes = {}
mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items()
if t == mesh_lib.AxisType.Manual}
manual_axes = self._manual_axes.union(mesh_manual_axes)
if manual_axes:
axis_names = self.mesh.axis_names
for manual_axis in manual_axes:
special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
replicated_mesh_axes = []
for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):
if axis_name not in array_mapping: # type: ignore
replicated_mesh_axes.append((i, axis_val))
if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes:
return xc.HloSharding.replicate()
mesh_permutation = []
new_mesh_shape = [1] * num_dimensions
for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]): # type: ignore
new_mesh_shape[pos] *= mesh_shape[name]
mesh_permutation.append(mesh_axis_pos[name])
last_tile_dims = []
if replicated_mesh_axes:
axes_by_type = collections.defaultdict(list)
size_by_type = collections.defaultdict(lambda: 1) # type: ignore
assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys()))
for i, size in replicated_mesh_axes:
ty = special_axes.get(i, xc.OpSharding.Type.REPLICATED)
axes_by_type[ty].append(i)
size_by_type[ty] *= size
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
last_tile_dims.append(ty)
new_mesh_shape.append(size_by_type[ty])
mesh_permutation.extend(axes)
# Explanation of the parameters of `HloSharding.iota_tile`.
# This is the HloShardingV2 format:
# * dims: How many ways each dimension is sharded.
# Replicated/Manual dims are added at the end.
# * reshape_dims: This is just the shape of the mesh.
# * transpose_perm: This is the order in which mesh axes in PartitionSpec
# appear relative to mesh.axis_names order.
# * subgroup_types: List of OpSharding types; each can be REPLICATED or MANUAL.
# Let's see an example:
# Consider input_shape=(8, 4, 2, 2), mesh={'a': 2, 'b': 2, 'c': 2, 'd': 2}
# and partition_spec=P(None, ('d', 'b'), 'c').
# Arguments to iota_tile will be:
# dims = [1, 4, 2, 1, 2] # 'a' is replicated hence `2` is at the end.
# reshape_dims = [2, 2, 2, 2]
# transpose_perm = [3, 1, 2, 0] # 'a' is replicated hence 0 is at the end
# subgroup_types = [xc.OpSharding.Type.REPLICATED]
dims = new_mesh_shape
reshape_dims = self.mesh.axis_sizes
if self._logical_device_ids is None:
return xc.HloSharding.iota_tile(
dims=dims, reshape_dims=reshape_dims, transpose_perm=mesh_permutation,
subgroup_types=last_tile_dims)
else:
return xc.HloSharding.subgroup_with_device_ordering(
np.asarray(self._logical_device_ids)
.reshape(reshape_dims).transpose(mesh_permutation)
.reshape(dims), subgroup_types=last_tile_dims)
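# A second worked example (hedged; it just traces the code above): for a
# hypothetical mesh {'x': 2, 'y': 4} and spec P('y') applied to a rank-2 value,
# array_mapping is {'y': 0} and 'x' is replicated, so the call reduces to
#   xc.HloSharding.iota_tile(dims=[4, 1, 2], reshape_dims=(2, 4),
#                            transpose_perm=[1, 0],
#                            subgroup_types=[xc.OpSharding.Type.REPLICATED])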
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
if not array_mapping:
return PartitionSpec()
max_index = -1
reverse_map = collections.defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
if index > max_index:
max_index = index
partitions = []
for i in range(max_index + 1):
axis = reverse_map[i]
if axis:
partitions.append(axis[0] if len(axis) == 1 else tuple(axis))
else:
partitions.append(None)
return PartitionSpec(*partitions)
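# Round-trip sketch (illustrative only): `array_mapping_to_axis_resources` is
# the inverse of `get_array_mapping` above, filling unmapped dimensions with
# None.
#
#   >>> array_mapping_to_axis_resources(
#   ...     collections.OrderedDict([('x', 0), ('y', 2), ('z', 2)]))
#   PartitionSpec('x', None, ('y', 'z'))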
get_single_pspec = lambda p: array_mapping_to_axis_resources(get_array_mapping(p)) # type: ignore
# TODO(yashkatariya): Remove this after jax 0.5.2 release
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
if parsed_pspec is None:
spec = PartitionSpec() if spec is None else spec
parsed_pspec = ParsedPartitionSpec.from_user_input(
spec, "NamedSharding spec", allow_unconstrained_dims=True)
_check_unique_resources(parsed_pspec, "NamedSharding spec", mesh)
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
return parsed_pspec
def check_pspec(mesh, spec, _manual_axes=frozenset()):
_check_unique_resources(spec, "NamedSharding spec", mesh)
_check_mesh_resource_axis(mesh, spec, _manual_axes)
class DuplicateSpecError(Exception):
def __init__(self, message, mesh, pspec):
super().__init__(message)
self.message = message
self.mesh = mesh
self.pspec = pspec
def __str__(self):
return f"{self.message}"
def _check_unique_resources(
pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None,
) -> None:
resource_counts: dict[MeshAxisName, int] = {}
duplicate = False
pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec)
else pspec)
for d in pspec:
if d is PartitionSpec.UNCONSTRAINED or d is None:
continue
d = d if isinstance(d, tuple) else (d,)
for resource in d:
count = resource_counts.get(resource, 0)
if count > 0:
duplicate = True
resource_counts[resource] = count + 1
if duplicate:
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
raise DuplicateSpecError(
message=(
f'A single {arg_name} specification can map every mesh axis to at'
f' most one positional dimension, but {pspec} has duplicate entries'
f' for {mesh_lib.show_axes(multiple_uses)}'),
mesh=mesh, pspec=pspec)
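# For example (illustrative only), a spec such as PartitionSpec('x', None, 'x')
# reuses mesh axis 'x' for two positional dimensions, so _check_unique_resources
# raises a DuplicateSpecError naming the repeated axis.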
@cache(max_size=128, trace_context_in_key=False)
def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec)
else pspec)
for p in pspec:
if p is PartitionSpec.UNCONSTRAINED or p is None:
continue
p = p if isinstance(p, tuple) else (p,)
for r in p:
if r not in mesh.shape:
raise ValueError(
f"Resource axis: {r} of {pspec} "
f"is not found in mesh: {tuple(mesh.shape.keys())}.")
if r in _manual_axes:
raise ValueError(
f"Axis: {r} of {pspec} "
f"is also found in manual_axes: {_manual_axes}.") from None
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
raise ValueError(
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {pspec}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
if (mesh_lib.AxisType.Auto not in mesh._axis_types_dict and
PartitionSpec.UNCONSTRAINED in pspec):
raise ValueError(
f'{pspec} cannot contain'
' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
f' axis_types: {mesh._axis_types_dict}')
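# For example (illustrative only), PartitionSpec('z') on a mesh with axes
# ('x', 'y') fails the "is not found in mesh" check above, and a spec
# containing PartitionSpec.UNCONSTRAINED fails the final check unless at least
# one mesh axis has AxisType.Auto.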