rocm_jax/jax/_src/sharding_impls.py

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from collections import OrderedDict
from collections.abc import Mapping, Sequence
import dataclasses
import functools
import itertools
import math
from typing import Any, NamedTuple, Union, cast
from jax._src import core
from jax._src import config
from jax._src import mesh as mesh_lib
from jax._src import sharding as jsharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import source_info_util
from jax._src import xla_bridge
from jax._src import mesh_utils
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
from jax._src.partition_spec import PartitionSpec
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
import numpy as np
Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = tuple[Device, ...]
# TODO(yashkatariya): Remove this after 3 months of deprecation.
XLACompatibleSharding = jsharding.Sharding
@dataclasses.dataclass(frozen=True)
class TransferToMemoryKind:
memory_kind: str
@util.cache(max_size=128, trace_context_in_key=False)
def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
for p in parsed_pspec:
if p is not None:
for r in p:
if r not in mesh.shape:
raise ValueError(
f"Resource axis: {r} of {parsed_pspec.get_partition_spec()} "
f"is not found in mesh: {tuple(mesh.shape.keys())}.")
if r in _manual_axes:
raise ValueError(
f"Axis: {r} of {parsed_pspec.get_partition_spec()} "
f"is also found in manual_axes: {_manual_axes}.") from None
@util.cache(max_size=128, trace_context_in_key=False)
def _check_axis_type_consistency(mesh, parsed_pspec):
for p in parsed_pspec:
if p is not None:
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
raise ValueError(
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
if mesh_lib.AxisTypes.Hidden not in mesh.axis_types and None in parsed_pspec:
raise ValueError(
f'PartitionSpec {parsed_pspec.get_partition_spec()} cannot contain'
' `P.UNCONSTRAINED` when no mesh axis_types are `Hidden`. Got mesh'
f' axis_types: {mesh.axis_types}')
def hashed_index(x) -> int:
# This works for both `pjit` indices and `pmap` indices (which might
# have an integer instead of a slice).
assert all(v.step is None for v in x if isinstance(v, slice))
return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))
@util.cache(max_size=4096, trace_context_in_key=False)
def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
try:
device_indices_map_fn = sharding.devices_indices_map
except AttributeError:
raise ValueError(
f'Cannot calculate replica ids from sharding: {sharding}. Please '
'create a device to index mapping for your sharding from which replica '
'ids will be calculated.') from None
index_to_replica: dict[int, int] = collections.Counter()
out = {}
for device, index in device_indices_map_fn(global_shape).items():
h_index = hashed_index(index)
replica_id = index_to_replica[h_index]
index_to_replica[h_index] += 1
out[device] = replica_id
return out
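# Illustrative example (a sketch, assuming at least two available devices):
# for a fully replicated NamedSharding, every device maps to the same index,
# so each device is assigned a distinct replica id.
#
#   >>> mesh = mesh_lib.Mesh(np.asarray(jax.devices()[:2]), ('x',))
#   >>> s = NamedSharding(mesh, PartitionSpec())
#   >>> sorted(device_replica_id_map(s, (8,)).values())
#   [0, 1]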
@dataclasses.dataclass
class SdyDimSharding:
axes: Sequence[str]
is_closed: bool
priority: int | None = None
def build(self) -> sdy.DimensionShardingAttr:
return sdy.DimensionShardingAttr.get(
[sdy.AxisRefAttr.get(axis) for axis in self.axes],
is_closed=self.is_closed,
priority=self.priority)
def __repr__(self):
return f'SdyDimSharding({self._custom_repr()})'
def _custom_repr(self):
axes_repr = ', '.join(f"'{a}'" for a in self.axes)
open_repr = ''
if not self.is_closed:
open_repr = ', ?' if self.axes else '?'
priority_repr = '' if self.priority is None else f'p{self.priority}'
return f'{{{axes_repr}{open_repr}}}{priority_repr}'
@dataclasses.dataclass
class SdyArraySharding:
mesh_shape: tuple[tuple[str, int], ...] | None
dimension_shardings: Sequence[SdyDimSharding]
logical_device_ids: tuple[int, ...] | None = None
replicated_axes: tuple[str, ...] = ()
def build(self) -> sdy.TensorShardingAttr:
if self.mesh_shape is None:
mesh_attr = sdy.MeshAttr.get([])
else:
ldi = ([] if self.logical_device_ids is None else
list(self.logical_device_ids))
mesh_attr = sdy.MeshAttr.get(
[sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape],
ldi)
return sdy.TensorShardingAttr.get(
mesh_attr,
[dim_sharding.build() for dim_sharding in self.dimension_shardings],
replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes])
def __repr__(self):
dim_sharding_repr = ', '.join(
d._custom_repr() for d in self.dimension_shardings)
device_id_repr = (f', device_ids={self.logical_device_ids}'
if self.logical_device_ids is not None else '')
rar = (f', replicated_axes={self.replicated_axes}'
if self.replicated_axes else '')
return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"
@dataclasses.dataclass
class SdyArrayShardingList:
shardings: Sequence[SdyArraySharding]
def build(self) -> sdy.TensorShardingPerValueAttr:
return sdy.TensorShardingPerValueAttr.get(
[sharding.build() for sharding in self.shardings])
@util.cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
self, num_dimensions: int) -> xc.HloSharding:
mesh_shape = self.mesh.shape
array_mapping = get_array_mapping(self._parsed_pspec)
mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}
special_axes = {}
mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items()
if t == mesh_lib.AxisTypes.Collective}
manual_axes = self._manual_axes.union(mesh_manual_axes)
if manual_axes:
axis_names = self.mesh.axis_names
for manual_axis in manual_axes:
special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
replicated_mesh_axes = []
for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):
if axis_name not in array_mapping: # type: ignore
replicated_mesh_axes.append((i, axis_val))
if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes:
return xc.HloSharding.replicate()
mesh_permutation = []
new_mesh_shape = [1] * num_dimensions
for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]): # type: ignore
new_mesh_shape[pos] *= mesh_shape[name]
mesh_permutation.append(mesh_axis_pos[name])
last_tile_dims = []
if replicated_mesh_axes:
axes_by_type = collections.defaultdict(list)
size_by_type = collections.defaultdict(lambda: 1) # type: ignore
assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys()))
for i, size in replicated_mesh_axes:
ty = special_axes.get(i, xc.OpSharding.Type.REPLICATED)
axes_by_type[ty].append(i)
size_by_type[ty] *= size
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
last_tile_dims.append(ty)
new_mesh_shape.append(size_by_type[ty])
mesh_permutation.extend(axes)
# Explanation of the parameters of `HloSharding.iota_tile`.
# This is the HloShardingV2 format:
# * dims: How many ways each dimension is sharded.
# Replicated/Manual dims are added at the end.
# * reshape_dims: This is just the shape of the mesh.
# * transpose_perm: This is the order in which mesh axes in PartitionSpec
# appear relative to mesh.axis_names order.
# * subgroup_types: List of type of OpSharding. Type can be REPLICATED and MANUAL.
# Let's see an example:
# Consider input_shape=(8, 4, 2, 2), mesh={'a': 2, 'b': 2, 'c': 2, 'd': 2}
# and partition_spec=P(None, ('d', 'b'), 'c').
# Arguments to iota_tile will be:
# dims = [1, 4, 2, 1, 2] # 'a' is replicated hence `2` is at the end.
# reshape_dims = [2, 2, 2, 2]
# transpose_perm = [3, 1, 2, 0] # 'a' is replicated hence 0 is at the end
# subgroup_types = [xc.OpSharding.Type.REPLICATED]
dims = new_mesh_shape
reshape_dims = self.mesh.axis_sizes
if self._logical_device_ids is None:
return xc.HloSharding.iota_tile(
dims=dims, reshape_dims=reshape_dims, transpose_perm=mesh_permutation,
subgroup_types=last_tile_dims)
else:
return xc.HloSharding.subgroup_with_device_ordering(
np.asarray(self._logical_device_ids)
.reshape(dims).reshape(reshape_dims).transpose(mesh_permutation)
.reshape(dims), subgroup_types=last_tile_dims)
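# Illustrative example (a sketch of the worked example in the comment above):
# the equivalent direct `iota_tile` call for mesh={'a': 2, 'b': 2, 'c': 2,
# 'd': 2} and partition_spec=P(None, ('d', 'b'), 'c') on a rank-4 value.
#
#   >>> hlo_sharding = xc.HloSharding.iota_tile(
#   ...     dims=[1, 4, 2, 1, 2],
#   ...     reshape_dims=[2, 2, 2, 2],
#   ...     transpose_perm=[3, 1, 2, 0],
#   ...     subgroup_types=[xc.OpSharding.Type.REPLICATED])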
@use_cpp_class(xc.NamedSharding)
class NamedSharding(jsharding.Sharding):
r"""A :class:`NamedSharding` expresses sharding using named axes.
A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
:class:`PartitionSpec` which describes how to shard an array across that
mesh.
A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.
A :class:`PartitionSpec` is a tuple, whose elements can be a ``None``,
a mesh axis, or a tuple of mesh axes. Each element describes how an input
dimension is partitioned across zero or more mesh dimensions. For example,
``PartitionSpec('x', 'y')`` says that the first dimension of data
is sharded across ``x`` axis of the mesh, and the second dimension is sharded
across ``y`` axis of the mesh.
The Distributed arrays and automatic parallelization
(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
tutorial has more details and diagrams that explain how
:class:`Mesh` and :class:`PartitionSpec` are used.
Args:
mesh: A :class:`jax.sharding.Mesh` object.
spec: A :class:`jax.sharding.PartitionSpec` object.
Examples:
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
"""
mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh
spec: PartitionSpec
_memory_kind: str | None
_parsed_pspec: ParsedPartitionSpec
_manual_axes: frozenset[MeshAxisName]
_logical_device_ids: tuple[int, ...] | None
@use_cpp_method()
def __init__(
self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *,
memory_kind: str | None = None, _parsed_pspec=None,
_manual_axes=frozenset(), _logical_device_ids=None):
self.mesh = mesh
self.spec = spec
self._memory_kind = memory_kind
self._manual_axes = _manual_axes
self._logical_device_ids = _logical_device_ids
self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec)
def __repr__(self):
mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
ldi = ('' if self._logical_device_ids is None else
f', logical_device_ids={self._logical_device_ids}')
if isinstance(self.mesh, mesh_lib.AbstractMesh):
mesh_repr = f"{self.mesh}"
else:
nv_str = ", ".join(f"'{n}': {v}" for n, v in self.mesh.shape.items())
mesh_repr = f"Mesh({nv_str})"
return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'
def __reduce__(self):
return (type(self), (self.mesh, self.spec),
{'memory_kind': self.memory_kind,
'_manual_axes': self._manual_axes,
'_logical_device_ids': self._logical_device_ids})
@property
def memory_kind(self) -> str | None:
return self._memory_kind
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash(
(self.mesh, self.memory_kind, self._parsed_pspec, self._manual_axes,
self._logical_device_ids))
return self._hash
def __eq__(self, other):
if not isinstance(other, NamedSharding):
return False
if self is other:
return True
if (self._parsed_pspec != other._parsed_pspec
or self.memory_kind != other.memory_kind
or self._manual_axes != other._manual_axes
or self._logical_device_ids != other._logical_device_ids):
return False
return self.mesh is other.mesh or self.mesh == other.mesh
def check_compatible_aval(self, aval_shape: Shape) -> None:
assert self._parsed_pspec is not None
if len(aval_shape) < len(self._parsed_pspec):
extra_msg = (' For scalars the PartitionSpec should be P()'
if len(aval_shape) == 0 else '')
raise ValueError(
f"Sharding {self} is only valid for values of rank at least "
f"{len(self._parsed_pspec)}, but was applied to a value of rank "
f"{len(aval_shape)}.{extra_msg}")
@classmethod
def _from_parsed_pspec(
cls, mesh, parsed_pspec, *, memory_kind=None, _manual_axes=frozenset(),
_logical_device_ids=None,
):
return cls(mesh, parsed_pspec.get_partition_spec(),
memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
_manual_axes=_manual_axes,
_logical_device_ids=_logical_device_ids)
@property
def num_devices(self) -> int:
return self.mesh.size
@property
def device_set(self) -> set[Device]:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError(
'device_set is not implemented for `jax.sharding.AbstractMesh`.')
return self.mesh._flat_devices_set
@property
def _device_assignment(self) -> XLADeviceAssignment:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError('_device_assignment is not implemented for'
' `jax.sharding.AbstractMesh`.')
return self.mesh._flat_devices_tuple
@property
def is_fully_addressable(self) -> bool:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError('is_fully_addressable is not implemented for '
'`jax.sharding.AbstractMesh`.')
# Speed up `is_fully_addressable` since there is a high chance that the
# mesh across multiple NamedSharding objects will be the same.
return not self.mesh.is_multi_process
@property
def addressable_devices(self) -> set[Device]:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
raise ValueError('addressable_devices is not implemented for '
'`jax.sharding.AbstractMesh`.')
# Override addressable devices because there is a high chance that the mesh
# across multiple NamedSharding objects will be the same.
return self.mesh._local_devices_set
@functools.cached_property
def is_fully_replicated(self) -> bool:
if self.mesh.size == 1:
return True
array_mapping = cast(ParsedPartitionSpec, get_array_mapping(self._parsed_pspec))
mesh_shape = self.mesh.shape
num_partitions = 1
for name in array_mapping:
num_partitions *= mesh_shape[name]
return num_partitions == 1
def with_memory_kind(self, kind: str) -> NamedSharding:
return NamedSharding(self.mesh, self.spec, memory_kind=kind)
def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
if not isinstance(spec, PartitionSpec):
spec = PartitionSpec(*spec)
return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
dim_shardings = [SdyDimSharding(axes=[], is_closed=True)
for _ in range(num_dimensions)]
for i, dim_spec in enumerate(self._parsed_pspec):
if dim_spec is None:
dim_shardings[i].is_closed = False
elif not dim_spec:
# Already empty and closed sharding.
pass
else:
dim_shardings[i].axes = dim_spec
return SdyArraySharding(self.mesh.shape_tuple, dim_shardings,
self._logical_device_ids)
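# Illustrative example (a sketch, assuming 8 available devices): PartitionSpec
# entries may be a single mesh axis, a tuple of mesh axes, or None, as
# described in the class docstring above.
#
#   >>> mesh = mesh_lib.Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
#   >>> s = NamedSharding(mesh, PartitionSpec(('x', 'y'), None))
#   >>> s.is_fully_replicated
#   False
#   >>> s.with_spec(PartitionSpec()).is_fully_replicated
#   True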
# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
if mesh._any_axis_hidden:
dim_shardings, used_axes = [], [] # type: ignore
for d in sdy_sharding.dimension_shardings:
# TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open?
dim_shardings.append(SdyDimSharding(axes=[], is_closed=False)
if not d.axes and d.is_closed else d)
used_axes.extend(d.axes)
remaining_axes = set(mesh.axis_names) - set(used_axes)
replicated_axes = tuple(r for r in remaining_axes
if mesh._name_to_type[r] == mesh_lib.AxisTypes.Visible)
return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
sdy_sharding.logical_device_ids, replicated_axes)
return sdy_sharding
@util.cache(max_size=128, trace_context_in_key=False)
def get_replicated_hlo_sharding():
return xc.HloSharding.replicate()
@use_cpp_class(xc.SingleDeviceSharding)
class SingleDeviceSharding(jsharding.Sharding):
"""A :class:`Sharding` that places its data on a single device.
Args:
device: A single :py:class:`Device`.
Examples:
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
... jax.devices()[0])
"""
_device: Device
_memory_kind: str | None
@use_cpp_method()
def __init__(self, device: Device, *, memory_kind: str | None = None):
self._device = device
self._memory_kind = memory_kind
def __reduce__(self):
return type(self), (self._device,), {'memory_kind': self._memory_kind}
def __repr__(self):
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
return f"SingleDeviceSharding(device={self._device!r}{mem})"
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._device, self.memory_kind))
return self._hash
def __eq__(self, other):
if not isinstance(other, SingleDeviceSharding):
return False
if self is other:
return True
return (self._device == other._device and
self.memory_kind == other.memory_kind)
@property
def num_devices(self) -> int:
return len(self.device_set)
@property
def device_set(self) -> set[Device]:
return {self._device}
@property
def memory_kind(self) -> str | None:
return self._memory_kind
def with_memory_kind(self, kind: str) -> SingleDeviceSharding:
return SingleDeviceSharding(self._device, memory_kind=kind)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
return {self._device: (slice(None),) * len(global_shape)}
@property
def _device_assignment(self) -> XLADeviceAssignment:
return (self._device,)
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return get_replicated_hlo_sharding()
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
for _ in range(num_dimensions)]
return SdyArraySharding(None, sdy_dim_sharding)
@property
def is_fully_replicated(self) -> bool:
return True
@property
def is_fully_addressable(self) -> bool:
return True
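# Illustrative example (a sketch): a SingleDeviceSharding places the whole
# array on its one device, so every dimension gets a full slice.
#
#   >>> s = SingleDeviceSharding(jax.devices()[0])
#   >>> list(s.devices_indices_map((4, 2)).values())
#   [(slice(None, None, None), slice(None, None, None))]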
@util.cache(max_size=4096, trace_context_in_key=False)
def pmap_sharding_devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Index]:
self.shard_shape(global_shape) # raises a good error message
indices = sharding_specs.spec_to_indices(global_shape, self.sharding_spec)
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
@use_cpp_class(xc.PmapSharding)
class PmapSharding(jsharding.Sharding):
"""Describes a sharding used by :func:`jax.pmap`."""
devices: np.ndarray
sharding_spec: sharding_specs.ShardingSpec
_internal_device_list: xc.DeviceList
@use_cpp_method()
def __init__(self, devices: Sequence[Device] | np.ndarray,
sharding_spec: sharding_specs.ShardingSpec):
self.devices = np.asarray(devices)
# The sharding spec should be pmap's sharding spec.
self.sharding_spec = sharding_spec
def __reduce__(self):
return (type(self), (self.devices, self.sharding_spec),
{'memory_kind': self.memory_kind})
def __eq__(self, other):
if not isinstance(other, PmapSharding):
return False
if self is other:
return True
return (self.sharding_spec == other.sharding_spec and
self.devices.shape == other.devices.shape and
self._internal_device_list == other._internal_device_list)
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._internal_device_list, self.sharding_spec))
return self._hash
def __str__(self):
device_ids = [d.id for d in self.devices.flat]
return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
f'{device_ids=}, '
f'device_platform={self.devices.flat[0].platform.upper()}, '
f'device_shape={self.devices.shape})')
def __repr__(self):
return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
f'devices={self.devices})')
def is_equivalent_to(self: PmapSharding, other: PmapSharding, # type: ignore
ndim: int) -> bool:
return self == other
# TODO(yashkatariya): Expose `sharded_dim_size` in the API if required.
@classmethod
def default(cls, shape: Shape, sharded_dim: int | None = 0,
devices: Sequence[xc.Device] | None = None) -> PmapSharding:
"""Creates a :class:`PmapSharding` which matches the default placement
used by :func:`jax.pmap`.
Args:
shape: The shape of the input array.
sharded_dim: Dimension the input array is sharded on. Defaults to 0.
devices: Optional sequence of devices to use. If omitted, the implicit
device order used by pmap is used, which is the order of
:func:`jax.local_devices`.
"""
if sharded_dim is None:
if devices is None:
raise ValueError("One of sharded_dim or devices must be set.")
nrep = len(devices)
return cls(np.array(devices),
sharding_specs.pmap_sharding_spec(nrep, nrep, shape, None))
# The dtype doesn't matter here. It's only used for creating the
# sharding_spec.
sharding_spec = sharding_specs.create_pmap_sharding_spec(
tuple(shape), sharded_dim)
num_ways_sharded = None
for s in sharding_spec.sharding:
if isinstance(s, sharding_specs.Unstacked):
assert num_ways_sharded is None
num_ways_sharded = s.size
elif isinstance(s, sharding_specs.Chunked):
assert num_ways_sharded is None
if len(s.chunks) == 1:
num_ways_sharded = s.chunks[0]
else:
raise NotImplementedError(
'Multiple chunks in Chunked dimension not supported.')
if devices is None:
pmap_devices: np.ndarray = np.array(
xla_bridge.local_devices()[:num_ways_sharded])
else:
pmap_devices = np.array(devices)
return cls(pmap_devices, sharding_spec)
@property
def num_devices(self) -> int:
return len(self.device_set)
@functools.cached_property
def device_set(self) -> set[Device]:
return set(self.devices.flat)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
return pmap_sharding_devices_indices_map(self, global_shape)
@functools.cached_property
def _device_assignment(self) -> XLADeviceAssignment:
return tuple(self.devices.flat)
@property
def memory_kind(self) -> str | None:
try:
return self._internal_device_list.default_memory_kind
except:
return None
def with_memory_kind(self, kind: str):
raise NotImplementedError("pmap does not support memories.")
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
raise NotImplementedError("pmap doesn't use SdyArraySharding.")
@functools.cached_property
def is_fully_replicated(self) -> bool:
for s in self.sharding_spec.sharding:
if isinstance(s, (sharding_specs.Unstacked, sharding_specs.Chunked)):
return False
return True
@functools.cached_property
def is_fully_addressable(self) -> bool:
return self._internal_device_list.is_fully_addressable
def shard_shape(self, global_shape: Shape) -> Shape:
sharded_dim = None
sharded_dim_size = None
for i, s in enumerate(self.sharding_spec.sharding):
if isinstance(s, sharding_specs.Unstacked):
sharded_dim = i
sharded_dim_size = s.size
sharded_shape = util.tuple_delete(global_shape, sharded_dim)
break
elif isinstance(s, sharding_specs.Chunked):
sharded_dim = i
assert len(s.chunks) == 1, s.chunks
sharded_dim_size = s.chunks[0]
sharded_shape = util.tuple_update(global_shape, sharded_dim, 1)
break
if sharded_dim is None:
return global_shape
if global_shape[sharded_dim] != sharded_dim_size:
raise ValueError(
f'The sharded dimension must be equal to the number of '
f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
f'the number of devices={len(self._device_assignment)}')
return sharded_shape
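# Illustrative example (a sketch): `PmapSharding.default` mirrors jax.pmap's
# placement, sharding the given dimension across the local devices. Assumes
# the leading dimension equals the local device count.
#
#   >>> n = jax.local_device_count()
#   >>> s = PmapSharding.default((n, 4), sharded_dim=0)
#   >>> s.shard_shape((n, 4))
#   (4,)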
def _op_sharding_to_pos_sharding(
op_sharding: xc.OpSharding | xc.HloSharding,
device_assignment: Sequence[xc.Device],
memory_kind: str | None = None) -> PositionalSharding:
if isinstance(op_sharding, xc.OpSharding):
op_sharding = xc.HloSharding.from_proto(op_sharding)
if op_sharding.is_replicated():
return PositionalSharding(
device_assignment, memory_kind=memory_kind).replicate()
if len(op_sharding.subgroup_types()) > 1:
raise NotImplementedError(
'Unhandled HloSharding type. Please open a bug report!'
)
name = device_assignment[0].platform.upper()
ids = np.array(
[DeviceIdSet(name, i) for i in op_sharding.tile_assignment_devices()]
)
p = PositionalSharding._remake(tuple(device_assignment), ids,
memory_kind=memory_kind)
p = p.reshape(op_sharding.tile_assignment_dimensions())
if op_sharding.replicate_on_last_tile_dim():
p = p.replicate(-1, keepdims=False)
return p
@util.cache(max_size=4096, trace_context_in_key=False)
def _positional_sharding_to_xla_hlo_sharding(
self, num_dimensions: int) -> xc.HloSharding:
if self.shape == (1,) * self.ndim:
return get_replicated_hlo_sharding()
pbuf = xc.OpSharding()
shape = self.shape[self.ndim - num_dimensions:] # 'rank promotion' of val
set_size, = {len(device_set) for device_set in self._ids.flat}
pbuf.type = xc.OpSharding.Type.OTHER
if set_size > 1:
pbuf.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
pbuf.tile_assignment_dimensions = (*shape, set_size)
else:
pbuf.tile_assignment_dimensions = shape
pbuf.tile_assignment_devices = [i for ids in self._ids.flat for i in ids]
product_of_dims = math.prod(pbuf.tile_assignment_dimensions)
num_devices = len(pbuf.tile_assignment_devices)
assert product_of_dims == num_devices, (product_of_dims, num_devices)
return xc.HloSharding.from_proto(pbuf)
class PositionalSharding(jsharding.Sharding):
_devices: tuple[xc.Device, ...]
_memory_kind: str | None
_ids: np.ndarray # dtype DeviceIdSet
def __init__(self, devices: Sequence[xc.Device] | np.ndarray,
*, memory_kind: str | None = None):
super().__init__()
if not isinstance(devices, np.ndarray):
devices = np.array(devices, dtype='object')
if not devices.size:
raise ValueError(f"{self.__class__.__name__}.__init__ requires at least "
f"one device, got {devices}")
self._devices = tuple(devices.flat)
self._memory_kind = memory_kind
name = self._devices[0].platform.upper()
self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
dtype='object').reshape(devices.shape)
self._internal_device_list = xc.DeviceList(self._devices)
self._memory_kind = xc.check_and_canonicalize_memory_kind(
self._memory_kind, self._internal_device_list)
@property
def shape(self):
return self._ids.shape
@property
def ndim(self):
return self._ids.ndim
def __repr__(self) -> str:
cls_name = self.__class__.__name__
ids = self._ids.copy()
platform_name = self._devices[0].platform.upper()
for idx, x in np.ndenumerate(ids):
ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x)) # type: ignore # numpy 2.2
body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
max_line_width=100)
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
return f'{cls_name}({body}{mem}, shape={self.shape})'
def reshape(self, *shape) -> PositionalSharding:
return self._remake(self._devices, self._ids.reshape(*shape),
memory_kind=self.memory_kind)
def transpose(self, *axes) -> PositionalSharding:
return self._remake(self._devices, self._ids.transpose(*axes),
memory_kind=self.memory_kind)
T = property(transpose)
def replicate(self, axis=None, keepdims=True) -> PositionalSharding:
new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union
return self._remake(self._devices, new_ids,
memory_kind=self.memory_kind)
def check_compatible_aval(self, aval_shape: Shape) -> None:
if len(aval_shape) != len(self.shape) and not self.is_fully_replicated:
raise ValueError(
f"Sharding {self} is only valid for values of rank "
f"{len(self.shape)}, but was applied to a value of rank "
f"{len(aval_shape)}")
@classmethod
def _remake(
cls, devices: tuple[xc.Device, ...], ids: np.ndarray,
*, memory_kind: str | None = None) -> PositionalSharding:
sharding = cls(devices, memory_kind=memory_kind)
sharding._ids = ids
return sharding
# Hashable
def __hash__(self) -> int:
if not hasattr(self, '_hash'):
self._hash = hash((self._internal_device_list, self.memory_kind))
return self._hash
def __eq__(self, other) -> bool:
if not isinstance(other, PositionalSharding):
return False
if self is other:
return True
all_ids_equal = np.array_equal(self._ids, other._ids)
mem_kind_equal = self.memory_kind == other.memory_kind
if self._devices is other._devices and mem_kind_equal and all_ids_equal:
return True
return (mem_kind_equal and all_ids_equal and
self._internal_device_list == other._internal_device_list)
# Sharding interface
@property
def num_devices(self) -> int:
return len(self.device_set)
@functools.cached_property
def device_set(self) -> set[xc.Device]:
return set(self._devices)
@property
def memory_kind(self) -> str | None:
return self._memory_kind
def with_memory_kind(self, kind: str) -> PositionalSharding:
return PositionalSharding(self._devices, memory_kind=kind)
@functools.cached_property
def is_fully_replicated(self) -> bool:
return self.shape == (1,) * self.ndim
# jsharding.Sharding interface
@property
def _device_assignment(self) -> XLADeviceAssignment:
return self._devices
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions)
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
raise NotImplementedError(
"PositionalSharding can't be converted to an SdyArraySharding.")
@functools.cached_property
def is_fully_addressable(self) -> bool:
return self._internal_device_list.is_fully_addressable
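# Illustrative example (a sketch, assuming 8 available devices): a
# PositionalSharding is manipulated by reshaping, transposing and replicating
# its grid of device ids.
#
#   >>> s = PositionalSharding(jax.devices()).reshape(4, 2)
#   >>> s.shape
#   (4, 2)
#   >>> s.replicate(axis=1, keepdims=True).shape
#   (4, 1)
#   >>> s.replicate().is_fully_replicated
#   True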
class DeviceIdSet:
_name: str
_ids: frozenset[int]
def __init__(self, name, *ids):
self._name = name
self._ids = frozenset(ids)
def __iter__(self):
return iter(sorted(self._ids))
def __add__(self, other) -> DeviceIdSet:
assert isinstance(other, DeviceIdSet)
return DeviceIdSet(self._name, *(self._ids | other._ids))
def __len__(self) -> int:
return len(self._ids)
def __repr__(self) -> str:
ids = ', '.join(safe_map(str, sorted(self._ids)))
return f'{{{self._name} {ids}}}'
def __hash__(self) -> int:
return hash((self._name, self._ids))
def __eq__(self, other) -> bool:
return (isinstance(other, DeviceIdSet) and self._name == other._name and
self._ids == other._ids)
@use_cpp_class(xc.GSPMDSharding)
class GSPMDSharding(jsharding.Sharding):
_devices: tuple[Device, ...]
_hlo_sharding: xc.HloSharding
_memory_kind: str | None
_device_list: xc.DeviceList | None
_internal_device_list: xc.DeviceList
@use_cpp_method()
def __init__(self, devices: Sequence[Device],
op_sharding: xc.OpSharding | xc.HloSharding,
*, memory_kind: str | None = None,
_device_list: xc.DeviceList | None = None):
self._devices = tuple(devices)
if isinstance(op_sharding, xc.OpSharding):
self._hlo_sharding = xc.HloSharding.from_proto(op_sharding)
else:
self._hlo_sharding = op_sharding
self._memory_kind = memory_kind
def __reduce__(self):
return (type(self), (self._devices, self._hlo_sharding.to_proto()),
{'memory_kind': self._memory_kind})
@functools.cached_property
def _hlo_sharding_hash(self):
if self.is_fully_replicated:
return hash(get_replicated_hlo_sharding())
return hash(self._hlo_sharding)
def __eq__(self, other):
if not isinstance(other, GSPMDSharding):
return False
if self is other:
return True
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
and self.memory_kind == other.memory_kind
and self._internal_device_list == other._internal_device_list)
def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash((self._internal_device_list, self._hlo_sharding_hash,
self.memory_kind))
return self._hash
def __repr__(self):
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
return f'GSPMDSharding({self._hlo_sharding!r}{mem})'
def check_compatible_aval(self, aval_shape: Shape) -> None:
num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding)
if len(aval_shape) < len(num_ways_dim_sharded):
raise ValueError(
f"Sharding {self} is only valid for values of rank at least "
f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
f"{len(aval_shape)}")
@property
def num_devices(self) -> int:
return len(self.device_set)
@functools.cached_property
def device_set(self) -> set[Device]:
return set(self._devices)
@property
def memory_kind(self) -> str | None:
return self._memory_kind
def with_memory_kind(self, kind: str) -> GSPMDSharding:
return GSPMDSharding(self._devices, self._hlo_sharding, memory_kind=kind)
@property
def _device_assignment(self) -> XLADeviceAssignment:
return self._devices
def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return self._hlo_sharding
def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
raise NotImplementedError(
"GSPMDSharding can't be converted to SdyArraySharding.")
@functools.cached_property
def is_fully_replicated(self) -> bool:
return is_op_sharding_replicated(self._hlo_sharding)
@functools.cached_property
def is_fully_addressable(self) -> bool:
return self._internal_device_list.is_fully_addressable
@classmethod
def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
memory_kind=memory_kind)
class AUTO:
def __init__(self, mesh: mesh_lib.Mesh):
self.mesh = mesh
def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding:
dim_shardings = [SdyDimSharding(axes=[], is_closed=False)
for _ in range(ndim)]
return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
class UnspecifiedValue:
def __repr__(self):
return "UnspecifiedValue"
UNSPECIFIED = UnspecifiedValue()
MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDict[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
if not array_mapping:
return PartitionSpec()
max_index = -1
reverse_map = collections.defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
if index > max_index:
max_index = index
partitions = []
for i in range(max_index + 1):
axis = reverse_map[i]
if axis:
partitions.append(axis[0] if len(axis) == 1 else tuple(axis))
else:
partitions.append(None)
return PartitionSpec(*partitions)
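# Illustrative example (a sketch): recovering a PartitionSpec from an
# ArrayMapping, matching the worked example in the ArrayMapping docstring
# above. Axes mapped to the same positional dimension are grouped into a
# tuple, and unmapped leading dimensions become None.
#
#   >>> am = OrderedDict([('x', 1), ('y', 1)])
#   >>> array_mapping_to_axis_resources(am) == PartitionSpec(None, ('x', 'y'))
#   True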
def get_array_mapping(
axis_resources: ParsedPartitionSpec | AUTO | UnspecifiedValue
) -> ArrayMappingOrAutoOrUnspecified:
if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
return axis_resources
return OrderedDict((axis, i)
for i, axes in enumerate(axis_resources)
if axes is not None for axis in axes)
get_single_pspec = lambda p: array_mapping_to_axis_resources(
cast(ArrayMapping, get_array_mapping(p)))
class ParsedPartitionSpec:
__slots__ = ('_user_spec', 'partitions')
def __init__(self, user_spec, partitions):
self._user_spec = user_spec
# None in partitions represents unconstrained dim.
# TODO(yashkatariya): May use a sentinel value.
self.partitions = tuple(partitions)
def get_partition_spec(self) -> PartitionSpec:
if isinstance(self._user_spec, PartitionSpec):
return self._user_spec
else:
return get_single_pspec(self)
def insert_axis_partitions(self, dim, val):
parts = self.partitions
too_short = dim - len(parts)
if too_short > 0:
parts += ((),) * too_short
new_partitions = util.tuple_insert(parts, dim, val)
return ParsedPartitionSpec(None, new_partitions)
@classmethod
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
if entry is None:
return cls(entry, ())
if not isinstance(entry, PartitionSpec):
raise TypeError(f"{arg_name} are expected to be "
f"PartitionSpec instances or None, but got {entry}")
axis_specs = []
for axis_spec in entry:
if axis_spec is None:
axis_spec = ()
elif isinstance(axis_spec, (list, tuple)):
axis_spec = tuple(axis_spec)
elif axis_spec == PartitionSpec.UNCONSTRAINED:
if not allow_unconstrained_dims:
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
axis_spec = None
else:
axis_spec = (axis_spec,)
axis_specs.append(axis_spec)
new_entry = PartitionSpec(
*[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry])
return cls(new_entry, axis_specs)
def __hash__(self):
return hash(self.partitions)
def __eq__(self, other):
if not isinstance(other, ParsedPartitionSpec):
return False
return self.partitions == other.partitions
def __len__(self):
return len(self.partitions)
def __getitem__(self, i):
return self.partitions[i]
def __iter__(self):
return iter(self.partitions)
def __repr__(self):
return f"ParsedPartitionSpec(partitions={self.partitions})"
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
if parsed_pspec is None:
parsed_pspec = prepare_axis_resources(
PartitionSpec() if spec is None else spec,
"NamedSharding spec", allow_unconstrained_dims=True)
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
_check_axis_type_consistency(mesh, parsed_pspec)
return parsed_pspec
def prepare_axis_resources(axis_resources, arg_name,
allow_unconstrained_dims=False):
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_util.tree_flatten(
axis_resources, is_leaf=lambda x: x is None)
what = f"{arg_name} leaf specifications"
new_entries = []
for entry in entries:
if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
new_entries.append(entry)
elif isinstance(entry, jsharding.Sharding):
if isinstance(entry, PmapSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not '
'allowed.')
new_entries.append(entry)
else:
new_entries.append(ParsedPartitionSpec.from_user_input(
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
_check_unique_resources(new_entries, arg_name)
return tree_util.tree_unflatten(treedef, new_entries)
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, jsharding.Sharding)):
continue
constrained_dims = [d for d in arg_axis_resources if d is not None]
resource_counts = collections.Counter(
itertools.chain.from_iterable(constrained_dims))
if not resource_counts: continue
if resource_counts.most_common(1)[0][1] > 1:
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
if multiple_uses:
raise ValueError(
f'A single {arg_name} specification can map every mesh axis to at'
' most one positional dimension, but'
f' {arg_axis_resources.get_partition_spec()} has duplicate entries'
f' for {mesh_lib.show_axes(multiple_uses)}')
# Axis environments
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: tuple[Any, ...]
sizes: tuple[int, ...]
@dataclasses.dataclass(frozen=True)
class SPMDAxisContext:
"""A hardware axis context for parallel computations that use the GSPMD partitioner.
This includes the mesh that will later by used to execute this computation,
as well as a set of mesh axes that are currently lowered in the MANUAL
sharding mode.
"""
mesh: mesh_lib.Mesh
manual_axes: frozenset[MeshAxisName] = frozenset()
@property
def axis_env(self):
# All collectives that touch axis_env should remember to set use_global_device_ids
# when this context is enabled!
return self.unsafe_axis_env
@property
def unsafe_axis_env(self):
return AxisEnv(
nreps=self.mesh.size,
names=self.mesh.axis_names,
sizes=tuple(self.mesh.shape.values()))
def extend_manual(self, axes: frozenset[MeshAxisName]) -> SPMDAxisContext:
return SPMDAxisContext(self.mesh, self.manual_axes | axes)
@dataclasses.dataclass(frozen=True)
class ReplicaAxisContext:
"""A hardware axis context for parallel computations that are partitioned by JAX.
Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to
explicit collectives.
"""
axis_env: AxisEnv
@dataclasses.dataclass(frozen=True)
class ShardingContext:
"""A hardware axis context for parallel computations that use the sharding
interface.
This context also uses the GSPMD partitioner.
"""
num_devices: int
device_assignment: tuple[xc.Device, ...] | None = None
abstract_mesh: mesh_lib.AbstractMesh | None = None
def __post_init__(self):
if self.device_assignment is not None:
assert isinstance(self.device_assignment, tuple)
assert self.num_devices == len(self.device_assignment)
# Similar to SPMDAxisContext, since ShardingContext also uses the GSPMD partitioner.
@property
def axis_env(self):
return AxisEnv(nreps=1, names=(), sizes=())
# -------------------- XLA OpSharding to PartitionSpec --------------------
# Note that OpSharding is more expressive than PartitionSpecs, so it's not
# always possible to convert them, but the code below should at least
# handle all cases when this is possible.
def strides_for_sizes(sizes):
"""Returns an array of strides for major-to-minor sizes."""
return np.cumprod(sizes[::-1])[::-1] // np.asarray(sizes)
def unflatten_array(named_sizes, assignment):
"""Recovers the ordering of axis names based on a device assignment.
The device assignments that this function can convert into axis orders
are of the form::
np.arange(np.prod(named_sizes.values())).transpose(...).flatten()
for some transposition ``...``. This is satisfied by all OpSharding assignments
generated from partition specs.
Arguments:
named_sizes: A dictionary mapping axis names to their sizes.
assignment: A permutation of integers between 0 and the product of all
named sizes.
Returns:
A major-to-minor list of axis names that corresponds to the given assignment.
"""
named_sizes = {name: size for name, size in named_sizes.items() if size != 1}
sizes = np.fromiter(named_sizes.values(), dtype=np.int64)
strides = strides_for_sizes(sizes)
dims = explode_superdims(sizes, unflatten_superdims(assignment))
dim_to_name = {(size, stride): name for size, stride, name in zip(sizes, strides, named_sizes)}
return [dim_to_name[d] for d in dims]
def unflatten_superdims(assignment):
"""Unflatten a list of dimension sizes and their strides that generates assignment.
If this function succeeds for a given ``assignment``, then the following property
should be satisfied::
dims_with_strides = unflatten_superdims(assignment)
base_array = np.arange(map(fst, sorted(dims_with_strides, key=snd, reverse=True)))
assignment == base_array.transpose(argsort(dims_with_strides, key=snd, reverse=True)).flatten()
That is, the returned dimensions list all sizes of the base array (with strides
indicating their initial order). The order of dimensions in the list corresponds
to the permutation that, applied to the base array, generates the assignment.
"""
def check(cond):
if cond: return
raise NotImplementedError("Failed to convert OpSharding into a ShardingSpec. "
"Please open a bug report!")
flat_assignment = np.asarray(assignment, dtype=np.int64)
check(flat_assignment[0] == 0)
dims = []
while flat_assignment.size > 1:
stride = flat_assignment[1]
for i in range(len(flat_assignment)):
if flat_assignment[i] != i * stride: break
else:
# After this loop i should point to an "element after the sequence", so
# we have to increment it if the whole array is a strided sequence.
i += 1
size = i
dims.append((size, stride))
assert size > 1 # Ensure progress
flat_assignment = flat_assignment[::size]
return dims
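# Illustrative example (a sketch): the assignment obtained by flattening
# np.arange(8).reshape(4, 2).transpose(1, 0) decomposes into a size-4
# dimension of stride 2 followed by a size-2 dimension of stride 1.
#
#   >>> unflatten_superdims([0, 2, 4, 6, 1, 3, 5, 7]) == [(4, 2), (2, 1)]
#   True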
def explode_superdims(sizes, dims):
"""Explode superdims to fit a known shape.
The unflattening process might mistakenly generate too few dimensions that are each too large.
For example, ``unflatten_superdims(np.arange(n))`` always returns ``[(n, 1)]``.
This function takes a list of such contiguous super-dimensions and splits them
into smaller dimensions such that::
set(map(fst, explode_superdims(sizes, dims))) == set(sizes)
"""
strides_to_sizes = {stride: size for size, stride in zip(sizes, strides_for_sizes(sizes))}
dims = list(reversed(dims))
final_dims = []
for size, stride in dims:
target_size = strides_to_sizes[stride]
new_dims = []
while size > target_size:
assert target_size > 1 # Ensure progress
assert size % target_size == 0
new_dims.append((target_size, stride))
size //= target_size
stride *= target_size
target_size = strides_to_sizes[stride]
assert size == target_size
new_dims.append((size, stride))
final_dims += reversed(new_dims)
return final_dims
def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding,
mesh: mesh_lib.Mesh) -> Sequence[ParsedPartitionSpec]:
if isinstance(hlo_sharding, xc.OpSharding):
hlo_sharding = xc.HloSharding.from_proto(hlo_sharding)
if hlo_sharding.tuple_elements():
out: list[ParsedPartitionSpec] = []
for s in hlo_sharding.tuple_elements():
out.extend(parse_flatten_op_sharding(s, mesh))
return out
elif hlo_sharding.is_replicated():
return [ParsedPartitionSpec(PartitionSpec(), ())]
elif hlo_sharding.is_tiled():
mesh_shape = mesh.shape
mesh_axis_order = unflatten_array(
mesh.shape, hlo_sharding.tile_assignment_devices()
)
mesh_axis = iter(mesh_axis_order)
shape = hlo_sharding.tile_assignment_dimensions()
partitions = []
for dim_size in shape:
dim_partitions = []
while dim_size > 1:
axis = next(mesh_axis)
axis_size = mesh_shape[axis]
assert dim_size % axis_size == 0
dim_size //= axis_size
dim_partitions.append(axis)
partitions.append(tuple(dim_partitions))
if len(hlo_sharding.subgroup_types()) > 1:
raise NotImplementedError(
'Unhandled HloSharding type. Please open a bug report!'
)
if hlo_sharding.replicate_on_last_tile_dim():
partitions = partitions[:-1]
while partitions and partitions[-1] == ():
partitions.pop()
return [ParsedPartitionSpec(None, partitions)]
else:
raise AssertionError("Unhandled OpSharding type. Please open a bug report!")
def _slice_as_tuple(s: slice):
assert s.step is None
return (s.start, s.stop)
class NonUniformShardingError(ValueError):
"""Raised when sharding is not uniform across processes."""
def get_process_index_and_count(
tensor_sharding: jsharding.Sharding, dim: int, ndims: int) -> tuple[int, int]:
"""Get current process index and number of unique processes for given dimension.
This function facilitates mapping of process-level data to individual
devices. Each process can use its index to obtain the data corresponding
to that index. If process-level data is sharded on multiple dimensions,
this function can be used to build the cross product of indices in
each sharded axis. Processes that need to load the same data will have
the same index. For shardings whose per-process data is not distributed
on a grid, the number of distinct shards will be such that it is possible to
build the target shape while maintaining a "cube" shape of local-process data.
For example, in case of 4 hosts with sharding distributed like so:
1234
2143
For dim 0 (rows): all processes need to access all rows, so we return (0, 1)
For dim 1 (cols):
processes 1 and 2 return index 0 out of 2 (they need cols 0 and 1),
processes 3 and 4 return index 1 out of 2 (they need cols 2 and 3).
On the other hand, for a sharding like:
1212
3434
Dim 0 (rows): processes 1 and 2 return (0, 2), processes 3 and 4 return (1, 2)
Dim 1 (cols): processes 1 and 3 return (0, 2), processes 2 and 4 return (1, 2)
Note: This function requires sharding to be process uniform in dimension
`dim`:
each process has the same number of addressable indices in that
dimension and all index sets across processes are either disjoint or the same.
For a sharding to be process uniform, the addressable shards don't need to
form a contiguous subtensor or even a sparse grid, and for an interleaved
high-dimensional tensor it is possible for the sharding to be process
uniform in some dimensions but not others.
For example:
1111 and 12 and 1212 and 1212
2222 21 2121 1212
are all process uniform in both dimensions. However,
1122
2121
1121
1222
is uniform in dimension 0 (both hosts access all rows), but
is not uniform in dimension 1 (host 1 accesses columns 0, 1, and 3,
while host 2 accesses columns 0, 1, 2, and 3).
Returns:
A tuple of (index, num_distinct_shards) for the given dimension.
It is guaranteed that `index` will cover 0 to `num_distinct_shards - 1`,
across all processes.
Raises:
NonUniformShardingError: if the sharding is not process uniform in dimension
`dim`.
"""
# TODO(sandler, yashkatariya): Consider making this function public.
if (tensor_sharding.is_fully_addressable or
tensor_sharding.is_fully_replicated):
return (0, 1)
# Get the device-to-indices map. We don't care about the concrete
# global shape here; we only need the distribution of shards across the
# tensor, so we use (num_devices, num_devices, ...). This is a universal
# shape that is compatible with any mesh with num_devices devices.
device_map = tensor_sharding.devices_indices_map(
(tensor_sharding.num_devices,) * ndims)
# Get the slices for 'dim' for all devices.
global_slice = {k: v[dim] for k, v in device_map.items()}
# Contains mapping from process_index to a set of slices for that process.
process_to_slice = collections.defaultdict(set)
# Contains global set of slices across all processes.
all_slices = set()
# Compute the set of slices for each process and the global set of slices.
for d, v in global_slice.items():
key = (v.start, v.stop)
process_to_slice[d.process_index].add(key)
all_slices.add(key)
# Get the set of slices for the current process which we will use to compute
# the index of the current process.
current_pid = next(iter(tensor_sharding.addressable_devices)).process_index
addressable_slices = frozenset(process_to_slice[current_pid])
# Verify that all processes have the same number of slices.
slices_per_process = len(addressable_slices)
if any(len(x) != slices_per_process for x in process_to_slice.values()):
raise NonUniformShardingError(
f'{tensor_sharding=} is non-uniform on {dim=} as some processes have '
'different number of slices.'
)
unique_processes = list({frozenset(x) for x in process_to_slice.values()})
# After removing duplicate processes, all unique slices should
# cover the dimension exactly once. If they don't, it means that
# the sharding is not uniform.
if sum(len(h) for h in unique_processes) != len(all_slices):
raise NonUniformShardingError(
f'{tensor_sharding=} is non-uniform on {dim=}'
)
return (unique_processes.index(addressable_slices), len(unique_processes))
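# Illustrative example (a sketch, assuming 8 available devices): fully
# replicated (or fully addressable) shardings take the early-return path
# above, so every process owns the whole dimension.
#
#   >>> mesh = mesh_lib.Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
#   >>> s = NamedSharding(mesh, PartitionSpec())
#   >>> get_process_index_and_count(s, dim=0, ndims=2)
#   (0, 1)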
def local_to_global_shape(
sharding: jsharding.Sharding, local_shape: Shape) -> tuple[int | None, ...]:
"""Computes the global shape given the per process if possible.
The returned shape will have the size of the global tensor in that dimension
or None, if it is not computable. The latter can happen when sharding
is not uniform along that dimension, e.g. different hosts require
different shapes, or if different processes have partial data overlap.
If at most one dimension is sharded, the shape is always computable.
Generally, the global shape is computable for most practical meshes (including
topology-aware meshes such as those returned by mesh_utils.create_device_mesh).
Some examples: Suppose mesh is {'a': 2, 'b': 2, 'c': 2} with 2 devices
per host, 4 hosts total. For different specs we get:
- P():
global_shape = local_shape
- P(('a', 'b', 'c'), None):
global_shape = (4 * local_shape[0], local_shape[1])
Note: per device shape is (local_shape[0] / 2, local_shape[1])
- P(('a', 'b'), None)
global_shape = (4 * local_shape[0], local_shape[1])
# NB: the same global shape as above, since sharding along 'c' dimension
# happens to be within process, and thus doesn't affect the global shape.
# The underlying difference will be in the per *device* shape, which
# would be (local_shape[0], local_shape[1]) in this case.
- P(None, ('a', 'c'))
global_shape = (local_shape[0], 2 * local_shape[1])
# Per device shape is (local_shape[0], local_shape[1] / 2)
- P(('a', 'c'), 'b'):
global_shape = (2 * local_shape[0], 2 * local_shape[1])
# Per device shape is (local_shape[0] / 2, local_shape[1])
- If devices in the Mesh are randomly permuted: For any partition spec
which shards more than 1 axis: e.g. P('a', ('b', 'c')):
global_shape = (None, None)
Args:
local_shape: local (per-process) shape of the tensor.
Returns:
global_shape with Nones in non-uniform dimensions.
"""
global_shape : list[int | None] = [None] * len(local_shape)
for i, local_dim in enumerate(local_shape):
try:
_, shard_count = get_process_index_and_count(
sharding, i, ndims=len(local_shape))
global_shape[i] = local_dim * shard_count
except NonUniformShardingError:
global_shape[i] = None
continue
return tuple(global_shape)
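# Illustrative example (a sketch, assuming 8 devices on a single process):
# with an unsharded spec the per-process shape already is the global shape,
# matching the P() case in the docstring above.
#
#   >>> mesh = mesh_lib.Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
#   >>> local_to_global_shape(NamedSharding(mesh, PartitionSpec()), (4, 2))
#   (4, 2)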
def num_addressable_indices(
tensor_sharding: jsharding.Sharding, dim: int, global_shape: Shape) -> int:
"""Returns the number of indices for given dimension this host has access to.
Each host can have multiple number of devices that are spanning
possibly discontiguous slices of data. This function computes the
total number of unique indices for dimension `dim` that any of its
addressable devices hold.
In most cases the addressable indices form a sparse grid (and in some
cases a subcube), and thus each host will hold the same number of
indices for each dimension. However, it is possible to design a mesh whose
addressable shards form a complicated pattern. In that case, the returned
value is the number of indices that are addressable by at least one device.
For example, suppose the sharding looks like this: (number indicates
the host index)
1221
1221
0000
Then on host 1 and 2, both dim 0 (rows), and dim=1 (cols) will have size 2,
while on host 0, dim 0 will have size 1, and dim 1 will have size 4.
Args:
tensor_sharding: Sharding of the tensor.
dim: dimension along which to compute the number of addressable indices.
global_shape: global shape of the tensor.
Returns:
The number of indices for dimension `dim` that this host holds.
"""
# TODO(sandler, yashkatariya): Consider making this function public.
addressables = tensor_sharding.addressable_devices_indices_map(global_shape)
addressables = cast(Mapping[jsharding.Device, Index], addressables)
num_unique_slices = len({
_slice_as_tuple(addressable[dim]) for addressable in addressables.values()
})
shard_size = tensor_sharding.shard_shape(global_shape)[dim]
return shard_size * num_unique_slices
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
elt_aval = core.physical_element_aval(aval.dtype)
new_op_sharding = hlo_sharding.to_proto().clone()
partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * elt_aval.ndim + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)
def is_single_device_sharding(sharding: jsharding.Sharding) -> bool:
# Special case PmapSharding here because PmapSharding maps away an axis
# and needs to be handled separately.
return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)
def make_key_array_phys_sharding(aval, sharding):
if is_single_device_sharding(sharding):
return sharding
elif isinstance(sharding, PmapSharding):
elt_aval = core.physical_element_aval(aval.dtype)
trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim
phys_sharding_spec = sharding_specs.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=sharding.devices,
sharding_spec=phys_sharding_spec)
elif isinstance(sharding, NamedSharding):
elt_aval = core.physical_element_aval(aval.dtype)
trailing_spec = [None] * elt_aval.ndim
return NamedSharding(
sharding.mesh,
PartitionSpec(*sharding.spec, *trailing_spec))
else:
hlos = sharding._to_xla_hlo_sharding(aval.ndim)
return GSPMDSharding(
sharding._device_assignment, physical_hlo_sharding(aval, hlos))
def physical_sharding(
aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
return make_key_array_phys_sharding(aval, sharding)
def get_logical_gspmd_sharding(aval, phys_sharding):
elt_aval = core.physical_element_aval(aval.dtype)
phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
aval.ndim + elt_aval.ndim)
partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
# Create logical sharding by cutting off the replicated trailing dims.
logical_op_sharding = phys_hlo_sharding.to_proto().clone()
tad = partitions[:-elt_aval.ndim] + suffix
logical_op_sharding.tile_assignment_dimensions = tad
return GSPMDSharding(phys_sharding._device_assignment,
xc.HloSharding.from_proto(logical_op_sharding))
def check_replicated_trailing_dims(sharding: jsharding.Sharding, aval):
if isinstance(sharding, PmapSharding):
return
phys_aval = core.physical_aval(aval)
hlo_s = sharding._to_xla_hlo_sharding(phys_aval.ndim)
partitions, _ = get_num_ways_dim_sharded(hlo_s)
num_trailing_dims = phys_aval.ndim - aval.ndim
if not all(i == 1 for i in partitions[-num_trailing_dims:]):
raise AssertionError(
"The trailing dims of extended dtypes should be replicated. Got"
f" sharding: {sharding}, partitions: {partitions}, "
f"num_trailing_dims: {num_trailing_dims}")
def logical_sharding(aval, phys_sharding) -> jsharding.Sharding:
# The trailing dims should always be replicated.
check_replicated_trailing_dims(phys_sharding, aval)
if is_single_device_sharding(phys_sharding):
return phys_sharding
elif isinstance(phys_sharding, PmapSharding):
elt_aval = core.physical_element_aval(aval.dtype)
logical_sharding_spec = sharding_specs.ShardingSpec(
sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim],
mesh_mapping=phys_sharding.sharding_spec.mesh_mapping)
return PmapSharding(devices=phys_sharding.devices,
sharding_spec=logical_sharding_spec)
elif isinstance(phys_sharding, NamedSharding):
logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
assert isinstance(phys_sharding.mesh, mesh_lib.Mesh)
return _gspmd_to_named_sharding_via_mesh(
logical_gs, phys_sharding.mesh)
else:
return get_logical_gspmd_sharding(aval, phys_sharding)
@util.cache()
def create_mesh_pspec_sharding(
mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
memory_kind: str | None = None) -> NamedSharding:
if pspec is None:
pspec, parsed_pspec = PartitionSpec(), None
return NamedSharding(mesh, pspec, _parsed_pspec=parsed_pspec,
memory_kind=memory_kind)
def _gspmd_to_named_sharding_via_mesh(
out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding:
parsed_pspec = parse_flatten_op_sharding(
out_s._hlo_sharding, mesh)[0]
return create_mesh_pspec_sharding(
mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
out_s.memory_kind)
def flatten_spec(spec):
out = []
for s in spec:
if s is None:
continue
if isinstance(s, tuple):
out.extend(s)
else:
out.append(s)
return out
def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
check_mesh_consistency: bool = True
) -> NamedSharding | None:
if not config.sharding_in_types.value:
return sharding # type: ignore
if sharding is None:
return sharding
if isinstance(sharding, PartitionSpec):
sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding) # type: ignore
else:
if (check_mesh_consistency and
sharding.mesh != mesh_lib.get_abstract_mesh()):
raise ValueError(
f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
f' of sharding {sharding.mesh}. This error occurs at source: '
f' {source_info_util.summarize(source_info_util.current())}')
for s in flatten_spec(sharding.spec):
if sharding.mesh._name_to_type[s] in {
mesh_lib.AxisTypes.Hidden, mesh_lib.AxisTypes.Collective}:
raise ValueError(
'PartitionSpec cannot contain axis names that are of type Hidden or'
f' Collective. Got PartitionSpec: {sharding.spec} with axis name:'
f' {s} or type: {sharding.mesh._name_to_type[s]}')
return sharding
def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
*, devices: Sequence[xc.Device] | None = None,
axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
"""Creates an efficient mesh with the shape and axis names specified.
This function attempts to automatically compute a good mapping from a set of
logical axes to a physical mesh. For example, on a TPU v3 with 8 devices:
>>> mesh = jax.make_mesh((8,), ('x',)) # doctest: +SKIP
>>> [d.id for d in mesh.devices.flat] # doctest: +SKIP
[0, 1, 2, 3, 6, 7, 4, 5]
The above ordering takes into account the physical topology of TPU v3.
It orders the devices into a ring, which yields efficient all-reduces on a
TPU v3.
Now, let's see another example with 16 devices of TPU v3:
>>> mesh = jax.make_mesh((2, 8), ('x', 'y')) # doctest: +SKIP
>>> [d.id for d in mesh.devices.flat] # doctest: +SKIP
[0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13]
>>> mesh = jax.make_mesh((4, 4), ('x', 'y')) # doctest: +SKIP
>>> [d.id for d in mesh.devices.flat] # doctest: +SKIP
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
As you can see, logical axes (`axis_shapes`) affect the ordering of the
devices.
You can use `jax.experimental.mesh_utils.create_device_mesh` if you want to
use the extra arguments it provides like `contiguous_submeshes` and
`allow_split_physical_axes`.
Args:
axis_shapes: Shape of the mesh. For example, axis_shapes=(4, 2)
axis_names: Names of the mesh axes. For example, axis_names=('x', 'y')
devices: Optional keyword-only argument that allows you to specify the
devices you want to create a mesh with.
axis_types: Optional keyword-only argument specifying the ``AxisTypes`` of
the mesh axes; it is forwarded to the :class:`jax.sharding.Mesh` constructor.
Returns:
A `jax.sharding.Mesh` object.
"""
if devices is None:
devices = xla_bridge.devices()
new_axis_shapes = mesh_utils._canonicalize_axis_sizes(axis_shapes)
if new_axis_shapes is None:
raise ValueError(
'`axis_shapes` passed to `make_mesh` should be a sequence of ints.'
f' Got {axis_shapes}')
del axis_shapes
axis_size = math.prod(new_axis_shapes)
if axis_size > len(devices):
raise ValueError(
f'Number of devices {len(devices)} must be >= the product '
f'of mesh_shape {new_axis_shapes}')
elif axis_size < len(devices):
devices = devices[:axis_size]
if devices[0].device_kind in (mesh_utils._TPU_V5_LITE, mesh_utils._TPU_V5E):
allow_split_physical_axes = True
else:
allow_split_physical_axes = False
mesh_devices = mesh_utils.create_device_mesh(
new_axis_shapes, devices,
allow_split_physical_axes=allow_split_physical_axes)
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
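# Illustrative example (a sketch, assuming 8 available devices): `make_mesh`
# is the device-order-aware way to build the Mesh consumed by NamedSharding.
#
#   >>> mesh = make_mesh((2, 4), ('x', 'y'))
#   >>> NamedSharding(mesh, PartitionSpec('x', 'y')).num_devices
#   8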