rocm_jax/jax/_src/sharding_specs.py

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each array sharded with a `PmapSharding` has a ShardingSpec, and
# ShardingSpecs also describe how to shard inputs to a parallel computation).
# spec_to_indices() encodes exactly how a given ShardingSpec is translated to
# device buffers, i.e. how the sharded array is "laid out" across devices. Given
# a sequence of devices, we shard the data across the devices in row-major
# order, with replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4].
#
# This encoding is assumed by various parts of the system, e.g. generating
# replica groups for collective operations.
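#
# As an illustrative sketch (one spec that reproduces the example above; the
# exact spec a pmap builds may differ), using the helpers defined below:
#
#   spec = ShardingSpec(sharding=(Chunked([4]),),
#                       mesh_mapping=(ShardedAxis(0), Replicated(2)))
#   spec_to_indices((4,), spec)
#   # -> ((slice(0, 1),), (slice(0, 1),), (slice(1, 2),), (slice(1, 2),),
#   #     (slice(2, 3),), (slice(2, 3),), (slice(3, 4),), (slice(3, 4),))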

from __future__ import annotations

from collections.abc import Sequence
import itertools
import math
from typing import Union

import numpy as np

from jax._src import config
from jax._src import util
from jax._src.lib import pmap_lib

unsafe_map, map = map, util.safe_map

NoSharding = pmap_lib.NoSharding
Chunked = pmap_lib.Chunked
Unstacked = pmap_lib.Unstacked
_UNSHARDED_INSTANCE = NoSharding()
ShardedAxis = pmap_lib.ShardedAxis
Replicated = pmap_lib.Replicated
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = pmap_lib.ShardingSpec


def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
  """Returns NumPy-style indices corresponding to a sharding spec.

  Args:
shape: The shape of the logical array being sharded.

  Returns:
    An ndarray with the same shape as the logical mesh (as derived from
`mesh_mapping`). Each entry is a NumPy-style index selecting the subset of
the data array to be placed on a corresponding device. The indices can be
ints, slice objects with step=1, or tuples of those.
"""
assert len(shape) == len(self.sharding), (shape, self.sharding)
axis_indices: list[Sequence[Index]] = []
shard_indices_shape = []
for dim, sharding in enumerate(self.sharding):
axis_size = shape[dim]
if isinstance(sharding, NoSharding):
axis_indices.append([slice(None)])
# NOTE: We don't append unsharded dimensions to shard_indices_shape here,
# because they do not appear in the mesh mapping.
elif isinstance(sharding, Unstacked):
assert axis_size == sharding.size, f'{axis_size} != {sharding.size}'
axis_indices.append(range(axis_size))
shard_indices_shape.append(axis_size)
elif isinstance(sharding, Chunked):
total_chunks = math.prod(sharding.chunks)
shard_size, ragged = divmod(axis_size, total_chunks)
assert not ragged, (axis_size, total_chunks, dim)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
for i in range(total_chunks)])
shard_indices_shape.extend(sharding.chunks)
else:
util.assert_unreachable(sharding)
# shard_indices is an ndarray representing the sharded axes of the logical array,
# with each dimension having size equal to the number of shards across the corresponding
# logical array dimension, and each element containing the multi-dimensional index that
# is used to extract the corresponding shard of the logical array.
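  # For example, for sharding (Chunked([2]), NoSharding()) over a shape (4, 3)
  # array, shard_indices has shape (2,) and its entries are
  # (slice(0, 2), slice(None)) and (slice(2, 4), slice(None)).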
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
for i, idxs in enumerate(itertools.product(*axis_indices)):
shard_indices[i] = idxs
shard_indices = shard_indices.reshape(shard_indices_shape)
# Ensure that each sharded axis is used exactly once in the mesh mapping
num_sharded_dim = len(shard_indices_shape)
sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)]
assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and
len(sharded_dim_perm) == num_sharded_dim)
# Replicate/reorder the indices according to the mesh mapping
replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated))
replica_dim, sharded_dim = itertools.count(0), iter(sharded_dim_perm)
perm = [next(replica_dim) if isinstance(a, Replicated) else
len(replica_sizes) + next(sharded_dim)
for a in self.mesh_mapping]
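  # For example, mesh_mapping=(ShardedAxis(0), Replicated(2)) gives
  # replica_sizes=(2,) and perm=[1, 0], so the replica dimension becomes the
  # last (fastest-varying) dimension of the returned mesh of indices.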
return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape)
.transpose(perm))
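

# Note (illustrative only): unlike spec_to_indices below, the indices method
# returns an ndarray shaped like the logical mesh rather than a flat tuple.
# For the spec from the module-level example above:
#
#   spec = ShardingSpec(sharding=(Chunked([4]),),
#                       mesh_mapping=(ShardedAxis(0), Replicated(2)))
#   spec.indices((4,)).shape
#   # -> (4, 2)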


def _sharding_spec_repr(self):
  return f'ShardingSpec({self.sharding}, {self.mesh_mapping})'


ShardingSpec.indices = _sharding_spec_indices
# mypy raises: error: Cannot assign to a method [assignment]
ShardingSpec.__repr__ = _sharding_spec_repr # type: ignore


Index = Union[int, slice, tuple[Union[int, slice], ...]]


def spec_to_indices(shape: Sequence[int],
                    spec: ShardingSpec) -> tuple[Index, ...]:
  """Returns numpy-style indices corresponding to a sharding spec.

  Each index describes a shard of the array. The order of the indices is the
  same as the device_buffers of an Array sharded using PmapSharding (i.e. the
  data is laid out row-major).

  Args:
shape: The shape of the logical array being sharded.
spec: Describes how the array is sharded and how the shards are assigned to
the logical mesh.

  Returns:
A tuple of length equal to the size of the mesh (inferred as the product of
sharded dimension sizes and all replication factors). Each element is an
int, a slice object with step=1, or a tuple thereof, to be treated as an
index into the full logical array.
"""
return tuple(spec.indices(shape).flat) # type: ignore
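

# A sketch of typical usage (illustrative only; this spec is an assumed
# example, not one produced by a particular pmap):
#
#   spec = ShardingSpec(sharding=(Chunked([2]), Chunked([2])),
#                       mesh_mapping=(ShardedAxis(0), ShardedAxis(1)))
#   spec_to_indices((4, 6), spec)
#   # -> ((slice(0, 2), slice(0, 3)), (slice(0, 2), slice(3, 6)),
#   #     (slice(2, 4), slice(0, 3)), (slice(2, 4), slice(3, 6)))

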
def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
map_axis: int | None) -> ShardingSpec:
"""Sharding spec for arguments or results of a pmap.
Args:
nrep: number of local XLA replicas (product of local axis sizes)
axis_size: local axis size for outer pmap
sharded_aval: the aval of the value inside the outer pmap, an instance of
a ShapedArray.
map_axis: the axis along which the value is mapped in the outer pmap
Returns:
A ShardingSpec.
"""
replication_factor, ragged = divmod(nrep, axis_size)
assert not ragged
pspec = ShardingSpec(sharding=[_UNSHARDED_INSTANCE] * len(sharded_shape),
mesh_mapping=())
maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
if map_axis is not None:
sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
def shift_sharded_axis(a: MeshDimAssignment):
if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
return ShardedAxis(a.axis + 1)
return a
# replication_factor represents the product of inner pmaps, so it goes
# after the outer pmapped axis at index 0
if config.pmap_no_rank_reduction.value:
sharding = util.tuple_update(
pspec.sharding, map_axis, Chunked([axis_size]))
else:
sharding = util.tuple_insert(
pspec.sharding, map_axis, Unstacked(axis_size))
return ShardingSpec(
sharding=sharding,
mesh_mapping=itertools.chain(
[ShardedAxis(sharded_in_axis)], maybe_replicate,
map(shift_sharded_axis, pspec.mesh_mapping)))
else:
return ShardingSpec(
sharding=pspec.sharding,
mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
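

# As a rough example (illustrative only, assuming pmap_no_rank_reduction is
# enabled), pmap_sharding_spec(nrep=8, axis_size=4, sharded_shape=(1, 3),
# map_axis=0) yields a spec equivalent to
#
#   ShardingSpec(sharding=(Chunked([4]), NoSharding()),
#                mesh_mapping=(ShardedAxis(0), Replicated(2)))
#
# i.e. axis 0 is chunked across the 4 pmapped devices and the remaining factor
# of nrep // axis_size == 2 appears as pure replication.

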
def create_pmap_sharding_spec(shape: tuple[int, ...], sharded_dim: int = 0,
sharded_dim_size: int | None = None):
if sharded_dim is not None:
if config.pmap_no_rank_reduction.value:
sharded_shape = util.tuple_update(shape, sharded_dim, 1)
else:
sharded_shape = util.tuple_delete(shape, sharded_dim)
if sharded_dim_size is None:
sharded_dim_size = shape[sharded_dim]
else:
assert sharded_dim_size is not None
sharded_shape = shape
return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, sharded_shape,
sharded_dim)
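

# For instance (an illustrative sketch, not an API guarantee),
# create_pmap_sharding_spec((8, 3), sharded_dim=0) builds the spec for an
# (8, 3) array sharded along axis 0 across 8 devices: it calls
# pmap_sharding_spec(8, 8, (1, 3), 0) when pmap_no_rank_reduction is enabled,
# or pmap_sharding_spec(8, 8, (3,), 0) otherwise.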