# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
import collections
import dataclasses
import functools
from typing import Any, Union

from jax._src import config
from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
from jax._src import sharding as JSharding
from jax._src import xla_bridge as xb
import numpy as np

Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]


class AUTO:

  def __init__(self, mesh: mesh_lib.Mesh):
    self.mesh = mesh

  def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding:
    dim_shardings = [SdyDimSharding(axes=[], is_closed=False)
                     for _ in range(ndim)]
    return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
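
# Illustrative note (not normative): `AUTO._to_sdy_sharding` produces fully
# open dimension shardings, leaving the per-dimension partitioning decision to
# the compiler. Assuming `mesh` is any Mesh, for a rank-2 value:
#
#   AUTO(mesh)._to_sdy_sharding(2)
#   -> SdyArraySharding([{?}, {?}])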


class UnspecifiedValue:
  def __repr__(self):
    return "UnspecifiedValue"
UNSPECIFIED = UnspecifiedValue()


MeshAxisName = Any

"""
ArrayMapping specifies how an ndarray should map to mesh axes.

Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.

For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = collections.OrderedDict[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
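
# Illustrative example, mirroring the docstring above: with a mesh of shape
# {'x': 2, 'y': 3}, the order of entries picks the major-to-minor order of the
# mesh axes mapped onto the repeated positional axis:
#
#   collections.OrderedDict([('x', 1), ('y', 1)])  # 'x' major, 'y' minor
#   collections.OrderedDict([('y', 1), ('x', 1)])  # 'y' major, 'x' minor
#
# Both map mesh axes 'x' and 'y' onto positional axis 1, but the second one
# requires transposing the mesh's device ndarray before flattening.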


@use_cpp_class(xc.NamedSharding)
class NamedSharding(JSharding.Sharding):
  r"""A :class:`NamedSharding` expresses sharding using named axes.

  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and a
  :class:`PartitionSpec` which describes how to shard an array across that
  mesh.

  A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
  where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.

  A :class:`PartitionSpec` is a tuple whose elements can be a ``None``,
  a mesh axis, or a tuple of mesh axes. Each element describes how an input
  dimension is partitioned across zero or more mesh dimensions. For example,
  ``PartitionSpec('x', 'y')`` says that the first dimension of data
  is sharded across the ``x`` axis of the mesh, and the second dimension is
  sharded across the ``y`` axis of the mesh.

  The Distributed arrays and automatic parallelization
  (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
  tutorial has more details and diagrams that explain how
  :class:`Mesh` and :class:`PartitionSpec` are used.

  Args:
    mesh: A :class:`jax.sharding.Mesh` object.
    spec: A :class:`jax.sharding.PartitionSpec` object.

  Examples:

    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> spec = P('x', 'y')
    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
  """

  mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh
  spec: PartitionSpec
  _memory_kind: str | None
  _manual_axes: frozenset[MeshAxisName]
  _logical_device_ids: tuple[int, ...] | None

  @use_cpp_method()
  def __init__(
      self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *,
      memory_kind: str | None = None, _manual_axes=frozenset(),
      _logical_device_ids=None):
    self.mesh = mesh
    self.spec = spec
    self._memory_kind = memory_kind
    self._manual_axes = _manual_axes
    self._logical_device_ids = _logical_device_ids
    check_pspec(self.mesh, self.spec, self._manual_axes)

  def __repr__(self):
    mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
    ldi = ('' if self._logical_device_ids is None else
           f', logical_device_ids={self._logical_device_ids}')
    mesh_repr = f"{str(self.mesh)}"
    return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'

  def __reduce__(self):
    return (type(self), (self.mesh, self.spec),
            {'memory_kind': self.memory_kind,
             '_manual_axes': self._manual_axes,
             '_logical_device_ids': self._logical_device_ids})

  @property
  def memory_kind(self) -> str | None:
    return self._memory_kind

  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash(
          (self.mesh, self.memory_kind, self.spec, self._manual_axes,
           self._logical_device_ids))
    return self._hash

  def __eq__(self, other):
    if not isinstance(other, NamedSharding):
      return False
    if self is other:
      return True
    if (self.spec != other.spec
        or self.memory_kind != other.memory_kind
        or self._manual_axes != other._manual_axes
        or self._logical_device_ids != other._logical_device_ids):
      return False
    return self.mesh is other.mesh or self.mesh == other.mesh

  def check_compatible_aval(self, aval_shape: Shape) -> None:
    if len(aval_shape) < len(self.spec):
      extra_msg = (' For scalars the PartitionSpec should be P()'
                   if len(aval_shape) == 0 else '')
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(self.spec)}, but was applied to a value of rank "
          f"{len(aval_shape)}.{extra_msg}")

  @property
  def num_devices(self) -> int:
    return self.mesh.size

  @property
  def device_set(self) -> set[Device]:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError(
          'device_set is not implemented for `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_set

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('_device_assignment is not implemented for'
                       ' `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_tuple

  @property
  def is_fully_addressable(self) -> bool:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('is_fully_addressable is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # Speed up `is_fully_addressable` since there is a high chance that the
    # mesh across multiple NamedSharding objects will be the same.
    if config.enable_empty_arrays.value:
      client = self._internal_device_list[0].client
      return (len(self.mesh._process_indices) == 1 and
              next(iter(self.mesh._process_indices)) ==
              xb.process_index(client))
    return not self.mesh.is_multi_process

  @property
  def _is_concrete(self) -> bool:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      return False
    return True

  @property
  def addressable_devices(self) -> set[Device]:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('addressable_devices is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # Override addressable devices because there is a high chance that the mesh
    # across multiple NamedSharding objects will be the same.
    return self.mesh._local_devices_set

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    if self.mesh.size == 1:
      return True
    array_mapping = get_array_mapping(self.spec)
    mesh_shape = self.mesh.shape
    num_partitions = 1
    for name in array_mapping:  # type: ignore
      num_partitions *= mesh_shape[name]
    return num_partitions == 1

  def with_memory_kind(self, kind: str) -> NamedSharding:
    return NamedSharding(self.mesh, self.spec, memory_kind=kind)

  def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
    if not isinstance(spec, PartitionSpec):
      spec = PartitionSpec(*spec)
    return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
    dim_shardings = [SdyDimSharding(axes=[], is_closed=True)
                     for _ in range(num_dimensions)]
    for i, dim_spec in enumerate(self.spec):
      if dim_spec is PartitionSpec.UNCONSTRAINED:
        dim_shardings[i].is_closed = False
      elif dim_spec is None:
        # Already empty and closed sharding.
        pass
      else:
        dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,)
        dim_shardings[i].axes = dim_spec
    return SdyArraySharding(self.mesh.shape_tuple, dim_shardings,
                            self._logical_device_ids)
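
# Illustrative usage sketch (assuming `mesh` is a 2x4 Mesh with axes
# ('x', 'y') as in the class docstring, and P is PartitionSpec):
#
#   s = NamedSharding(mesh, P('x', 'y'))
#   s.with_spec(P('y', None))          # same mesh and memory kind, new spec
#   s.with_memory_kind('pinned_host')  # same mesh and spec, new memory kind
#                                      # (where the backend supports that kind)
#   NamedSharding(mesh, P()).is_fully_replicated    # -> True
#   NamedSharding(mesh, P('x')).is_fully_replicated  # -> False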


def get_array_mapping(
    axis_resources: PartitionSpec | AUTO | UnspecifiedValue
) -> ArrayMappingOrAutoOrUnspecified:
  if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
    return axis_resources
  d = collections.OrderedDict()
  for i, axes in enumerate(axis_resources):
    if axes is None or axes is PartitionSpec.UNCONSTRAINED:
      continue
    axes = axes if isinstance(axes, tuple) else (axes,)
    for axis in axes:
      d[axis] = i
  return d
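
# Illustrative example (hypothetical axis names): each mesh axis that appears
# in the spec is mapped to the positional dimension it shards; `None` and
# UNCONSTRAINED entries are skipped:
#
#   get_array_mapping(PartitionSpec('x', None, ('y', 'z')))
#   -> OrderedDict([('x', 0), ('y', 2), ('z', 2)])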


@dataclasses.dataclass
class SdyDimSharding:
  axes: Sequence[str]
  is_closed: bool
  priority: int | None = None

  def build(self) -> sdy.DimensionShardingAttr:
    return sdy.DimensionShardingAttr.get(
        [sdy.AxisRefAttr.get(axis) for axis in self.axes],
        is_closed=self.is_closed,
        priority=self.priority)

  def __repr__(self):
    return f'SdyDimSharding({self._custom_repr()})'

  def _custom_repr(self):
    axes_repr = ', '.join(f"'{a}'" for a in self.axes)
    open_repr = ''
    if not self.is_closed:
      open_repr = ', ?' if self.axes else '?'
    priority_repr = '' if self.priority is None else f'p{self.priority}'
    return f'{{{axes_repr}{open_repr}}}{priority_repr}'
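
# Illustrative reprs (hypothetical axis names), following `_custom_repr`:
#
#   SdyDimSharding(axes=['x'], is_closed=True)   # -> SdyDimSharding({'x'})
#   SdyDimSharding(axes=['x'], is_closed=False)  # -> SdyDimSharding({'x', ?})
#   SdyDimSharding(axes=[], is_closed=False, priority=0)
#   # -> SdyDimSharding({?}p0)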


@dataclasses.dataclass
class SdyArraySharding:
  mesh_shape: tuple[tuple[str, int], ...] | None
  dimension_shardings: Sequence[SdyDimSharding]
  logical_device_ids: tuple[int, ...] | None = None
  replicated_axes: tuple[str, ...] = ()

  def build(self) -> sdy.TensorShardingAttr:
    if self.mesh_shape is None:
      mesh_attr = sdy.MeshAttr.get([])
    else:
      ldi = ([] if self.logical_device_ids is None else
             list(self.logical_device_ids))
      mesh_attr = sdy.MeshAttr.get(
          [sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape],
          ldi)
    return sdy.TensorShardingAttr.get(
        mesh_attr,
        [dim_sharding.build() for dim_sharding in self.dimension_shardings],
        replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes])

  def __repr__(self):
    dim_sharding_repr = ', '.join(
        d._custom_repr() for d in self.dimension_shardings)
    device_id_repr = (f', device_ids={self.logical_device_ids}'
                      if self.logical_device_ids is not None else '')
    rar = (f', replicated_axes={self.replicated_axes}'
           if self.replicated_axes else '')
    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"


# TODO(yashkatariya): Remove this after jax 0.5.2 release
class ParsedPartitionSpec:
  __slots__ = ('_user_spec', 'partitions')

  _user_spec: PartitionSpec | None
  partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...]

  def __init__(self, user_spec, partitions):
    self._user_spec = user_spec
    assert None not in partitions, partitions
    self.partitions = tuple(partitions)

  def get_partition_spec(self) -> PartitionSpec:
    if isinstance(self._user_spec, PartitionSpec):
      return self._user_spec
    else:
      return get_single_pspec(self)

  def insert_axis_partitions(self, dim, val):
    parts = self.partitions
    too_short = dim - len(parts)
    if too_short > 0:
      parts += ((),) * too_short
    new_partitions = tuple_insert(parts, dim, val)
    return ParsedPartitionSpec(None, new_partitions)

  @classmethod
  def from_user_input(
      cls,
      entry: PartitionSpec | None,
      arg_name: str,
      allow_unconstrained_dims: bool = False,
  ) -> ParsedPartitionSpec:
    if entry is None:
      return cls(entry, ())
    if not isinstance(entry, PartitionSpec):
      raise TypeError(f"{arg_name} are expected to be "
                      f"PartitionSpec instances or None, but got {entry}")
    axis_specs = []
    for axis_spec in entry:
      if axis_spec is None:
        axis_spec = ()
      elif isinstance(axis_spec, (list, tuple)):
        axis_spec = tuple(axis_spec)
      elif axis_spec is PartitionSpec.UNCONSTRAINED:
        if not allow_unconstrained_dims:
          raise ValueError(f"Unconstrained dims are not allowed: {entry}")
        axis_spec = PartitionSpec.UNCONSTRAINED
      else:
        axis_spec = (axis_spec,)
      axis_specs.append(axis_spec)
    new_entry = PartitionSpec(
        *[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry])
    return cls(new_entry, axis_specs)

  def __hash__(self):
    return hash(self.partitions)

  def __eq__(self, other):
    if not isinstance(other, ParsedPartitionSpec):
      return False
    return self.partitions == other.partitions

  def __len__(self):
    return len(self.partitions)

  def __getitem__(self, i):
    return self.partitions[i]

  def __iter__(self):
    return iter(self.partitions)

  def __repr__(self):
    return f"ParsedPartitionSpec(partitions={self.partitions})"


@cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
    self, num_dimensions: int) -> xc.HloSharding:
  mesh_shape = self.mesh.shape
  array_mapping = get_array_mapping(self.spec)
  mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}

  special_axes = {}
  mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items()
                      if t == mesh_lib.AxisType.Manual}
  manual_axes = self._manual_axes.union(mesh_manual_axes)
  if manual_axes:
    axis_names = self.mesh.axis_names
    for manual_axis in manual_axes:
      special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL

  replicated_mesh_axes = []
  for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):
    if axis_name not in array_mapping:  # type: ignore
      replicated_mesh_axes.append((i, axis_val))

  if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes:
    return xc.HloSharding.replicate()

  mesh_permutation = []
  new_mesh_shape = [1] * num_dimensions
  for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]):  # type: ignore
    new_mesh_shape[pos] *= mesh_shape[name]
    mesh_permutation.append(mesh_axis_pos[name])

  last_tile_dims = []
  if replicated_mesh_axes:
    axes_by_type = collections.defaultdict(list)
    size_by_type = collections.defaultdict(lambda: 1)  # type: ignore
    assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys()))
    for i, size in replicated_mesh_axes:
      ty = special_axes.get(i, xc.OpSharding.Type.REPLICATED)
      axes_by_type[ty].append(i)
      size_by_type[ty] *= size
    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
      last_tile_dims.append(ty)
      new_mesh_shape.append(size_by_type[ty])
      mesh_permutation.extend(axes)

  # Explanation of the parameters of `HloSharding.iota_tile`.
  # This is the HloShardingV2 format:
  # * dims: How many ways each dimension is sharded.
  #   Replicated/Manual dims are added at the end.
  # * reshape_dims: This is just the shape of the mesh.
  # * transpose_perm: This is the order in which mesh axes in the PartitionSpec
  #   appear relative to mesh.axis_names order.
  # * subgroup_types: List of OpSharding types. A type can be REPLICATED or MANUAL.
  # Let's see an example:
  # Consider input_shape=(8, 4, 2, 2), mesh={'a': 2, 'b': 2, 'c': 2, 'd': 2}
  # and partition_spec=P(None, ('d', 'b'), 'c').
  # Arguments to iota_tile will be:
  #   dims = [1, 4, 2, 1, 2]  # 'a' is replicated hence `2` is at the end.
  #   reshape_dims = [2, 2, 2, 2]
  #   transpose_perm = [3, 1, 2, 0]  # 'a' is replicated hence 0 is at the end.
  #   subgroup_types = [xc.OpSharding.Type.REPLICATED]
  dims = new_mesh_shape
  reshape_dims = self.mesh.axis_sizes
  if self._logical_device_ids is None:
    return xc.HloSharding.iota_tile(
        dims=dims, reshape_dims=reshape_dims, transpose_perm=mesh_permutation,
        subgroup_types=last_tile_dims)
  else:
    return xc.HloSharding.subgroup_with_device_ordering(
        np.asarray(self._logical_device_ids)
        .reshape(dims).reshape(reshape_dims).transpose(mesh_permutation)
        .reshape(dims), subgroup_types=last_tile_dims)
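
# A second, smaller walk-through (hypothetical mesh; follows the code above):
# for mesh={'x': 2, 'y': 2} and spec=P('x', None) with num_dimensions=2,
# 'y' does not appear in the spec and is therefore replicated, so:
#
#   dims           = [2, 1, 2]  # trailing 2 is the replicated subgroup
#   reshape_dims   = (2, 2)
#   transpose_perm = [0, 1]
#   subgroup_types = [xc.OpSharding.Type.REPLICATED]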


def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
  if not array_mapping:
    return PartitionSpec()
  max_index = -1
  reverse_map = collections.defaultdict(list)
  for axis, index in array_mapping.items():
    reverse_map[index].append(axis)
    if index > max_index:
      max_index = index
  partitions = []
  for i in range(max_index + 1):
    axis = reverse_map[i]
    if axis:
      partitions.append(axis[0] if len(axis) == 1 else tuple(axis))
    else:
      partitions.append(None)
  return PartitionSpec(*partitions)
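
# Illustrative round trip (hypothetical axis names): this acts as the inverse
# of `get_array_mapping`, with unmapped positional dimensions coming back as
# `None`:
#
#   array_mapping_to_axis_resources(
#       collections.OrderedDict([('x', 0), ('y', 2), ('z', 2)]))
#   -> PartitionSpec('x', None, ('y', 'z'))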


get_single_pspec = lambda p: array_mapping_to_axis_resources(get_array_mapping(p))  # type: ignore


# TODO(yashkatariya): Remove this after jax 0.5.2 release
def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
  if parsed_pspec is None:
    spec = PartitionSpec() if spec is None else spec
    parsed_pspec = ParsedPartitionSpec.from_user_input(
        spec, "NamedSharding spec", allow_unconstrained_dims=True)
  _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh)
  _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
  return parsed_pspec


def check_pspec(mesh, spec, _manual_axes=frozenset()):
  _check_unique_resources(spec, "NamedSharding spec", mesh)
  _check_mesh_resource_axis(mesh, spec, _manual_axes)


class DuplicateSpecError(Exception):
  def __init__(self, message, mesh, pspec):
    super().__init__(message)
    self.message = message
    self.mesh = mesh
    self.pspec = pspec

  def __str__(self):
    return f"{self.message}"


def _check_unique_resources(
    pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None,
) -> None:
  resource_counts: dict[MeshAxisName, int] = {}
  duplicate = False
  pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec)
           else pspec)
  for d in pspec:
    if d is PartitionSpec.UNCONSTRAINED or d is None:
      continue
    d = d if isinstance(d, tuple) else (d,)
    for resource in d:
      count = resource_counts.get(resource, 0)
      if count > 0:
        duplicate = True
      resource_counts[resource] = count + 1
  if duplicate:
    multiple_uses = [r for r, c in resource_counts.items() if c > 1]
    raise DuplicateSpecError(
        message=(
            f'A single {arg_name} specification can map every mesh axis to at'
            f' most one positional dimension, but {pspec} has duplicate entries'
            f' for {mesh_lib.show_axes(multiple_uses)}'),
        mesh=mesh, pspec=pspec)
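
# Illustrative failure (hypothetical axis name): using the same mesh axis for
# two different positional dimensions raises `DuplicateSpecError`:
#
#   _check_unique_resources(PartitionSpec('x', 'x'), "NamedSharding spec")
#   # -> DuplicateSpecError (duplicate entries for axis 'x')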


@cache(max_size=128, trace_context_in_key=False)
def _check_mesh_resource_axis(mesh, pspec, _manual_axes):
  pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec)
           else pspec)
  for p in pspec:
    if p is PartitionSpec.UNCONSTRAINED or p is None:
      continue
    p = p if isinstance(p, tuple) else (p,)
    for r in p:
      if r not in mesh.shape:
        raise ValueError(
            f"Resource axis: {r} of {pspec} "
            f"is not found in mesh: {tuple(mesh.shape.keys())}.")
      if r in _manual_axes:
        raise ValueError(
            f"Axis: {r} of {pspec} "
            f"is also found in manual_axes: {_manual_axes}.") from None
    if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
      raise ValueError(
          'AxisTypes should be the same in a tuple subset of PartitionSpec:'
          f' {pspec}. Got subset {p} with axis'
          f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
  if (mesh_lib.AxisType.Auto not in mesh._axis_types_dict and
      PartitionSpec.UNCONSTRAINED in pspec):
    raise ValueError(
        f'{pspec} cannot contain'
        ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
        f' axis_types: {mesh._axis_types_dict}')
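
# Illustrative failures (hypothetical names): the checks above reject specs
# that reference axes missing from the mesh, axes that are already manual, and
# `P.UNCONSTRAINED` when the mesh has no `Auto` axis types. For example, for a
# mesh with axes ('x', 'y'):
#
#   _check_mesh_resource_axis(mesh, PartitionSpec('z'), frozenset())
#   # -> ValueError (resource axis 'z' is not found in the mesh)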