diff --git a/jax/BUILD b/jax/BUILD index 7b81d4d8d..01fc9b1e2 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -182,6 +182,7 @@ py_library_providing_imports_info( ":pretty_printer", ":profiler", ":sharding", + ":sharding_specs", ":source_info_util", ":traceback_util", ":tree_util", @@ -411,6 +412,16 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "sharding_specs", + srcs = ["_src/sharding_specs.py"], + deps = [ + ":op_shardings", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "source_info_util", srcs = ["_src/source_info_util.py"], diff --git a/jax/_src/api.py b/jax/_src/api.py index 7b9cf8be4..1435586fd 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -45,6 +45,7 @@ from jax._src import dispatch from jax._src import effects from jax._src import array from jax._src import dtypes +from jax._src import sharding_specs from jax._src import source_info_util from jax._src import traceback_util from jax._src import pjit @@ -1782,13 +1783,13 @@ class _PmapFastpathData(NamedTuple): out_handler: Any out_pytree_def: Any # Data needed to handle the inputs. - input_sharding_specs: Sequence[pxla.ShardingSpec] + input_sharding_specs: Sequence[sharding_specs.ShardingSpec] input_devices: Sequence[xc.Device] - input_indices: Sequence[pxla.Index] + input_indices: Sequence[sharding_specs.Index] input_array_shardings: Sequence[Any] # Data needed to build the Array from C++. - out_sharding_specs: Sequence[pxla.ShardingSpec] - out_indices: Sequence[pxla.Index] + out_sharding_specs: Sequence[sharding_specs.ShardingSpec] + out_indices: Sequence[sharding_specs.Index] out_avals: Sequence[Any] out_array_shardings: Sequence[Any] out_committed: Sequence[Any] @@ -2582,7 +2583,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # raise ValueError("the shards passed to device_put_sharded must have " f"consistent shape and dtype, but got {a1} and {a2}.") stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape) - sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval) + sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape) return pxla.batched_device_put( stacked_aval, PmapSharding(np.array(devices), sharding_spec), xs, list(devices)) @@ -2630,7 +2631,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 core.raise_to_shaped(core.get_aval(x))) assert (isinstance(aval, ShapedArray) and len(xla.aval_to_xla_shapes(aval)) == 1) - sharding_spec = pxla._create_pmap_sharding_spec(aval) + sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) buf = device_put(x, devices[0]) return pxla.batched_device_put( aval, PmapSharding(np.array(devices), sharding_spec), diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c674a528b..5f47ee5fd 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -13,21 +13,6 @@ # limitations under the License. """Implementation of pmap and related functionality.""" -# A ShardingSpec describes at a high level how a logical array is sharded across -# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also -# describe how to shard inputs to a parallel computation). spec_to_indices() -# encodes exactly how a given ShardingSpec is translated to device buffers, i.e. -# how the sharded array is "laid out" across devices. Given a sequence of -# devices, we shard the data across the devices in row-major order, with -# replication treated as an extra inner dimension. 
-# -# For example, given the logical data array [1, 2, 3, 4], if we were to -# partition this array 4 ways with a replication factor of 2, for a total of 8 -# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4]. -# -# This encoding is assumed by various parts of the system, e.g. generating -# replica groups for collective operations. - from __future__ import annotations import enum @@ -40,8 +25,8 @@ import logging import math import sys from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, - Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast, - TYPE_CHECKING) + Sequence, Set, Tuple, Type, Union, Iterable, + TYPE_CHECKING, cast) import numpy as np @@ -57,6 +42,7 @@ from jax._src import effects from jax._src import linear_util as lu from jax._src import mesh from jax._src import op_shardings +from jax._src import sharding_specs from jax._src import profiler from jax._src import sharding_impls from jax._src import source_info_util @@ -73,12 +59,10 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.lib import xla_client as xc -from jax._src.lib import pmap_lib from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.util import (unzip3, safe_map, safe_zip, partition_list, - wrap_name, assert_unreachable, - tuple_insert, tuple_delete, distributed_debug_log, + wrap_name, tuple_delete, distributed_debug_log, unzip2, HashableFunction) @@ -100,214 +84,23 @@ logger = logging.getLogger(__name__) Index = Union[int, slice, Tuple[Union[int, slice], ...]] -NoSharding = pmap_lib.NoSharding -Chunked = pmap_lib.Chunked -Unstacked = pmap_lib.Unstacked +NoSharding = sharding_specs.NoSharding +Chunked = sharding_specs.Chunked +Unstacked = sharding_specs.Unstacked -ShardedAxis = pmap_lib.ShardedAxis -Replicated = pmap_lib.Replicated +ShardedAxis = sharding_specs.ShardedAxis +Replicated = sharding_specs.Replicated -_UNSHARDED_INSTANCE = NoSharding() AvalDimSharding = Union[Unstacked, Chunked, NoSharding] Mesh = jax._src.mesh.Mesh MeshAxisName = mesh.MeshAxisName MeshDimAssignment = Union[ShardedAxis, Replicated] -ShardingSpec = pmap_lib.ShardingSpec +ShardingSpec = sharding_specs.ShardingSpec -OpShardingType = Any PartitionSpec = sharding_impls.PartitionSpec -def sharding_spec_mesh_shape(self): - sharded_axis_sizes = [] - for sharding in self.sharding: - if isinstance(sharding, NoSharding): - continue - elif isinstance(sharding, Unstacked): - sharded_axis_sizes.append(sharding.size) - elif isinstance(sharding, Chunked): - sharded_axis_sizes.extend(sharding.chunks) - else: - assert_unreachable(sharding) - return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) else a.replicas - for a in self.mesh_mapping) - - -def _get_logical_mesh_ids(mesh_shape): - return np.arange(np.prod(mesh_shape)).reshape(mesh_shape) - - -def sharding_spec_sharding_proto(self, special_axes: Mapping[int, OpShardingType] = {}): - """Converts a ShardingSpec to an OpSharding proto. - - See - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601 - for details on the OpSharding proto. 
- Unfortunately the semantics are not very well described in the proto spec, but the code here might help: - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py - """ - mesh_shape = cast(Tuple[int, ...], self.mesh_shape) - mesh = _get_logical_mesh_ids(self.mesh_shape) - - sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped - replicated_maxes = [] # lists mesh axis identifiers to replicate over - for maxis, assignment in enumerate(self.mesh_mapping): - if isinstance(assignment, Replicated): - replicated_maxes.append((maxis, assignment.replicas)) - elif isinstance(assignment, ShardedAxis): - sharded_axes[assignment.axis] = maxis - else: - assert_unreachable(assignment) - - proto = xc.OpSharding() - if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes: - proto.type = xc.OpSharding.Type.REPLICATED - return proto - else: - proto.type = xc.OpSharding.Type.OTHER - - mesh_permutation = [] - new_mesh_shape = [] - next_sharded_axis = 0 - for axis, sharding in enumerate(self.sharding): - if isinstance(sharding, NoSharding): - new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over - elif isinstance(sharding, Chunked): - for nchunks in sharding.chunks: - maxis = sharded_axes[next_sharded_axis] - assert mesh_shape[maxis] == nchunks - mesh_permutation.append(maxis) - next_sharded_axis += 1 - new_mesh_shape.append(int(np.prod(sharding.chunks))) - elif isinstance(sharding, Unstacked): - raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding") - else: - assert_unreachable(sharding) - - # Create a partial sharding proto if tensor is replicated or partitioned - # specially over some mesh axes. - if replicated_maxes: - last_tile_dims = [] - axes_by_type: Dict[OpShardingType, List[jax._src.mesh.MeshAxisName]] = {} - size_by_type: Dict[OpShardingType, int] = defaultdict(lambda: 1) - assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys())) - for axis, size in replicated_maxes: - ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED) - axes_by_type.setdefault(ty, []).append(axis) - size_by_type[ty] *= size - for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value): - last_tile_dims.append(ty) - new_mesh_shape.append(size_by_type[ty]) - mesh_permutation.extend(axes) - proto.last_tile_dims = last_tile_dims - - proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape) - proto.tile_assignment_dimensions = list(proto_mesh.shape) - proto.tile_assignment_devices = list(proto_mesh.flat) - return proto - - -def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray: - """Returns NumPy-style indices corresponding to a sharding spec. - - Args: - shape: The shape of the logical array being sharded. - - Returns: - An ndarray with the same shape as the logical mesh (as derived form - `mesh_mapping`). Each entry is a NumPy-style index selecting the subset of - the data array to be placed on a corresponding device. The indices can be - ints, slice objects with step=1, or tuples of those. - """ - assert len(shape) == len(self.sharding), (shape, self.sharding) - - has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding) - # Take the op sharding indices generation route for pjit/xmap cases. 
- if not has_unstacked: - op_sharding_proto = sharding_spec_sharding_proto(self) - return op_shardings.op_sharding_to_numpy_indices( - op_sharding_proto, shape, math.prod(self.mesh_shape) - ).reshape(self.mesh_shape) - - axis_indices: List[Sequence[Index]] = [] - shard_indices_shape = [] - for dim, sharding in enumerate(self.sharding): - axis_size = shape[dim] - if isinstance(sharding, NoSharding): - axis_indices.append([slice(None)]) - # NOTE: We don't append unsharded dimensions to shard_indices_shape here, - # because they do not appear in the mesh mapping. - elif isinstance(sharding, Unstacked): - assert axis_size == sharding.size, f'{axis_size} != {sharding.size}' - axis_indices.append(range(axis_size)) - shard_indices_shape.append(axis_size) - elif isinstance(sharding, Chunked): - total_chunks = int(np.prod(sharding.chunks)) - shard_size, ragged = divmod(axis_size, total_chunks) - assert not ragged, (axis_size, total_chunks, dim) - axis_indices.append([slice(i * shard_size, (i + 1) * shard_size) - for i in range(total_chunks)]) - shard_indices_shape.extend(sharding.chunks) - else: - assert_unreachable(sharding) - - # shard_indices is an ndarray representing the sharded axes of the logical array, - # with each dimension having size equal to the number of shards across the corresponding - # logical array dimension, and each element containing the multi-dimensional index that - # is used to extract the corresponding shard of the logical array. - shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_) - for i, idxs in enumerate(it.product(*axis_indices)): - shard_indices[i] = idxs - shard_indices = shard_indices.reshape(shard_indices_shape) - - # Ensure that each sharded axis is used exactly once in the mesh mapping - num_sharded_dim = len(shard_indices_shape) - sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)] - assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and - len(sharded_dim_perm) == num_sharded_dim) - # Replicate/reorder the indices according to the mesh mapping - replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated)) - replica_dim, sharded_dim = it.count(0), iter(sharded_dim_perm) - perm = [next(replica_dim) if isinstance(a, Replicated) else - len(replica_sizes) + next(sharded_dim) - for a in self.mesh_mapping] - return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape) - .transpose(perm)) - -def sharding_spec_repr(self): - return f'ShardingSpec({self.sharding}, {self.mesh_mapping})' - - -ShardingSpec.mesh_shape = property(sharding_spec_mesh_shape) -ShardingSpec.sharding_proto = sharding_spec_sharding_proto -ShardingSpec.indices = sharding_spec_indices -# mypy raises: error: Cannot assign to a method [assignment] -ShardingSpec.__repr__ = sharding_spec_repr # type: ignore -# Do not pollute the namespace -del sharding_spec_mesh_shape, sharding_spec_indices, sharding_spec_repr - -def spec_to_indices(shape: Sequence[int], - spec: ShardingSpec) -> Tuple[Index, ...]: - """Returns numpy-style indices corresponding to a sharding spec. - - Each index describes a shard of the array. The order of the indices is the - same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out - row-major). - - Args: - shape: The shape of the logical array being sharded. - spec: Describes how the array is sharded and how the shards are assigned to - the logical mesh. 
- - Returns: - A tuple of length equal to the size of the mesh (inferred as the product of - sharded dimension sizes and all replication factors). Each element is an - int, a slice object with step=1, or a tuple thereof, to be treated as an - index into the full logical array. - """ - return tuple(spec.indices(shape).flat) # type: ignore - ### util @@ -554,21 +347,6 @@ global_result_handlers: Dict[Type[core.AbstractValue], PxlaResultHandler] = {} ### lazy device-memory persistence and result handling - -def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None): - if sharded_dim is not None: - sharded_aval = aval.update( - shape=aval.shape[:sharded_dim] + aval.shape[sharded_dim+1:]) - if sharded_dim_size is None: - sharded_dim_size = aval.shape[sharded_dim] - else: - assert sharded_dim_size is not None - sharded_aval = aval - - return _pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None, - sharded_aval, sharded_dim) - - # TODO(yashkatariya, phawkins): Remove this function after March 15, 2023. def make_sharded_device_array( aval: ShapedArray, @@ -594,7 +372,7 @@ def make_sharded_device_array( from jax._src import pjit if sharding_spec is None: - sharding_spec = _create_pmap_sharding_spec(aval) + sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) mesh = jax._src.mesh.thread_resources.env.physical_mesh sharding: sharding_impls.XLACompatibleSharding @@ -602,7 +380,7 @@ def make_sharded_device_array( sharding = sharding_impls.PmapSharding( np.asarray([d.device() for d in device_buffers]), sharding_spec) else: - op_sharding = sharding_spec_sharding_proto(sharding_spec) + op_sharding = sharding_specs.sharding_spec_sharding_proto(sharding_spec) pspec = pjit.parse_flatten_op_sharding( op_sharding, mesh)[0].get_partition_spec() sharding = sharding_impls.NamedSharding(mesh, pspec) @@ -1323,8 +1101,10 @@ class UnloadedPmapExecutable: local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals) input_sharding_specs = [ - _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size, - parts.local_num_partitions, arg_parts, aval, in_axis) + sharding_specs.pmap_sharding_spec( + replicas.num_local_replicas, pci.axis_size, + parts.local_num_partitions, arg_parts, + cast(ShapedArray, aval).shape, in_axis) for aval, arg_parts, in_axis in safe_zip( shards.sharded_avals, local_arg_parts_, pci.in_axes)] in_shardings = _get_pmap_sharding(local_device_assignment, input_sharding_specs) @@ -1342,15 +1122,16 @@ class UnloadedPmapExecutable: if out_axis is not None else aval for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)] out_specs = [ - _pmap_sharding_spec(replicas.num_local_replicas, pci.axis_size, - parts.local_num_partitions, out_parts, aval, out_axis) + sharding_specs.pmap_sharding_spec( + replicas.num_local_replicas, pci.axis_size, + parts.local_num_partitions, out_parts, aval.shape, out_axis) for out_parts, aval, out_axis in safe_zip( local_out_parts, local_out_avals, pci.out_axes)] out_shardings = _get_pmap_sharding(local_device_assignment, out_specs) if hasattr(pci.backend, "compile_replicated"): input_indices = [ - spec_to_indices(aval.shape, spec) # pytype: disable=attribute-error + sharding_specs.spec_to_indices(aval.shape, spec) if spec is not None else None for aval, spec in safe_zip(pci.avals, input_sharding_specs) ] @@ -1384,8 +1165,9 @@ class UnloadedPmapExecutable: input_indices = [] for aval, spec in safe_zip(self.local_input_avals, self.input_shardings): assert isinstance(spec, sharding_impls.PmapSharding), spec - 
input_indices.append(spec_to_indices(aval.shape, spec.sharding_spec) - if spec.sharding_spec is not None else None) + input_indices.append( + sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec) + if spec.sharding_spec is not None else None) handle_outs = local_avals_to_results_handler(self.local_output_avals, self.output_shardings) handle_args = InputsHandler(self.compiled.local_devices(), @@ -1686,7 +1468,8 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0): else: replicated_aval = aval # TODO(skye): figure out how partitioning should work here - sharding_spec = _pmap_sharding_spec(nrep, axis_size, 1, None, aval, in_axis) + sharding_spec = sharding_specs.pmap_sharding_spec( + nrep, axis_size, 1, None, aval.shape, in_axis) buf = jax.device_put(val, devices[0]) sharding = sharding_impls.PmapSharding( @@ -1695,59 +1478,6 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0): devices) -def _pmap_sharding_spec(nrep, axis_size, npart, parts, sharded_aval, - map_axis: Optional[int]) -> ShardingSpec: - """Sharding spec for arguments or results of a pmap. - Args: - nrep: number of local XLA replicas (product of local axis sizes) - axis_size: local axis size for outer pmap - npart: total number of XLA partitions (required by sharded_jit calls) - parts: the partitioning of the value or None - sharded_aval: the aval of the value inside the outer pmap, an instance of - a ShapedArray. - map_axis: the axis along which the value is mapped in the outer pmap - Returns: - A ShardingSpec. - """ - assert isinstance(sharded_aval, ShapedArray), sharded_aval - replication_factor, ragged = divmod(nrep, axis_size) - assert not ragged - # get the sharding spec from inner sharded_jits as if we weren't in a pmap - pspec = partitioned_sharding_spec(npart, parts, sharded_aval) - maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),) - if map_axis is not None: - sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis]) - def shift_sharded_axis(a: MeshDimAssignment): - if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis: - return ShardedAxis(a.axis + 1) - return a - # replication_factor represents the product of inner pmaps, so it goes - # after the outer pmapped axis at index 0 - return ShardingSpec( - sharding=tuple_insert(pspec.sharding, map_axis, Unstacked(axis_size)), - mesh_mapping=it.chain([ShardedAxis(sharded_in_axis)], - maybe_replicate, - map(shift_sharded_axis, pspec.mesh_mapping))) - else: - return ShardingSpec( - sharding=pspec.sharding, - mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping) - -def partitioned_sharding_spec(num_partitions: int, - partitions: Optional[Sequence[int]], - aval) -> ShardingSpec: - if partitions is None: - maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),) - return ShardingSpec( - sharding=[_UNSHARDED_INSTANCE] * len(aval.shape), - mesh_mapping=maybe_replicate) - else: - assert len(partitions) == len(aval.shape) - return ShardingSpec( - # Chunked expects a list of integers - sharding=map(Chunked, [[x] for x in partitions]), - mesh_mapping=map(ShardedAxis, range(len(partitions)))) - class ExecuteReplicated: """The logic to shard inputs, execute a replicated model, returning outputs.""" @@ -2893,18 +2623,18 @@ class UnloadedMeshExecutable: use_auto_spmd_partitioning=auto_spmd_lowering, env_options_overrides=compiler_options, ) + opts = compile_options.executable_build_options if auto_spmd_lowering: 
assert mesh is not None - compile_options.executable_build_options.auto_spmd_partitioning_mesh_shape = \ - list(mesh.shape.values()) - compile_options.executable_build_options.auto_spmd_partitioning_mesh_ids = \ - _get_logical_mesh_ids(list(mesh.shape.values())).reshape(-1) + opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values()) + opts.auto_spmd_partitioning_mesh_ids = ( + sharding_specs.get_logical_mesh_ids(list(mesh.shape.values())) + .reshape(-1)) compile_options.parameter_is_tupled_arguments = tuple_args if _allow_propagation_to_outputs is None: _allow_propagation_to_outputs = [False] * len(out_shardings) - compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \ - _allow_propagation_to_outputs + opts.allow_spmd_sharding_propagation_to_output = _allow_propagation_to_outputs if hasattr(backend, "compile_replicated"): return _compile_replicated_mesh_executable_from_hlo( @@ -3308,28 +3038,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk): _check_aval(v.aval, what_thunk) -def _make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes): - mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()] - sharding = [_UNSHARDED_INSTANCE] * num_dimensions - next_sharded_axis = 0 - # NOTE: sorted is stable, which is important when multiple resources - # map to the same axis. - for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]): - chunked = sharding[axis] - if isinstance(chunked, NoSharding): - chunked = Chunked([]) - sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]]) - assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \ - "Value mapped to the same mesh axis twice" - mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis) - next_sharded_axis += 1 - return ShardingSpec(sharding, mesh_mapping) - - -def new_mesh_sharding_specs(axis_sizes, axis_names): - mesh_axis_pos = {name: i for i, name in enumerate(axis_names)} - return partial(_make_sharding_spec, axis_sizes, mesh_axis_pos) - def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False): mesh_axis_pos = {name: i for i, name in enumerate(axis_names)} @@ -3350,7 +3058,8 @@ def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False): 'axis size should be zero but got ' f'{aval_shape[axis] % axis_sizes[name]}') aval_shape[axis] //= axis_sizes[name] - return _make_sharding_spec(axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes) + return sharding_specs.make_sharding_spec( + axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes) return mk_sharding_spec @@ -3367,3 +3076,13 @@ def device_put(x, devices: Sequence[xc.ArrayImpl], return [jax.device_put(x, device) for device in devices] else: return [jax.device_put(val, device) for val, device in safe_zip(x, devices)] + +# TODO(phawkins): fix external users not to use these functions. 
+def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None): + return sharding_specs.create_pmap_sharding_spec( + aval.shape, sharded_dim, sharded_dim_size) + +def _pmap_sharding_spec(nrep, axis_size, npart, parts, + sharded_aval, map_axis: Optional[int]) -> ShardingSpec: + return sharding_specs.pmap_sharding_spec(nrep, axis_size, npart, parts, + sharded_aval.shape, map_axis) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index dfc458722..bf6270a4d 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import pretty_printer as pp +from jax._src import sharding_specs from jax._src import typing from jax._src.api import jit, vmap from jax._src.config import config @@ -293,8 +294,8 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla): return sharding elif isinstance(sharding, PmapSharding): key_shape = aval.dtype.impl.key_shape - trailing_sharding = [pxla.NoSharding()] * len(key_shape) - phys_sharding_spec = pxla.ShardingSpec( + trailing_sharding = [sharding_specs.NoSharding()] * len(key_shape) + phys_sharding_spec = sharding_specs.ShardingSpec( sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), mesh_mapping=sharding.sharding_spec.mesh_mapping) return PmapSharding(devices=sharding.devices, diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index c8d5665e1..b389566c1 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -20,15 +20,14 @@ import operator as op from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set, FrozenSet, Union, cast) -from jax._src import core from jax._src import mesh as mesh_lib from jax._src import op_shardings from jax._src import sharding +from jax._src import sharding_specs from jax._src import xla_bridge from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.interpreters import mlir -from jax._src.interpreters import pxla import numpy as np @@ -296,7 +295,7 @@ class NamedSharding(XLACompatibleSharding): array_mapping = get_array_mapping(self._parsed_pspec) # TODO(yashkatariya): Move away from sharding spec in NamedSharding # since we don't really need sharding spec. - sharding_spec = pxla.new_mesh_sharding_specs( + sharding_spec = sharding_specs.new_mesh_sharding_specs( self.mesh.shape, self.mesh.axis_names)(num_dimensions, array_mapping) # Used in `with_sharding_constraint`. special_axes = {} @@ -369,11 +368,11 @@ class SingleDeviceSharding(XLACompatibleSharding): @use_cpp_class(xc.PmapSharding) class PmapSharding(XLACompatibleSharding): devices: np.ndarray - sharding_spec: pxla.ShardingSpec + sharding_spec: sharding_specs.ShardingSpec @use_cpp_method() def __init__(self, devices: Union[Sequence[Device], np.ndarray], - sharding_spec: pxla.ShardingSpec): + sharding_spec: sharding_specs.ShardingSpec): self.devices = np.asarray(devices) # The sharding spec should be pmap's sharding spec. self.sharding_spec = sharding_spec @@ -421,12 +420,12 @@ class PmapSharding(XLACompatibleSharding): """ # The dtype doesn't matter here. Its only used for creating the # sharding_spec. 
- aval = core.ShapedArray(shape, np.int32) - sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim) + sharding_spec = sharding_specs.create_pmap_sharding_spec( + tuple(shape), sharded_dim) num_ways_sharded = None for s in sharding_spec.sharding: - if isinstance(s, pxla.Unstacked): + if isinstance(s, sharding_specs.Unstacked): num_ways_sharded = s.size if num_ways_sharded is None: raise NotImplementedError( @@ -444,7 +443,7 @@ class PmapSharding(XLACompatibleSharding): @functools.lru_cache(maxsize=4096) def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: self.shard_shape(global_shape) # raises a good error message - indices = pxla.spec_to_indices(global_shape, self.sharding_spec) + indices = sharding_specs.spec_to_indices(global_shape, self.sharding_spec) return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type] @functools.cached_property @@ -459,7 +458,7 @@ class PmapSharding(XLACompatibleSharding): sharded_dim = None sharded_dim_size = None for i, s in enumerate(self.sharding_spec.sharding): - if isinstance(s, pxla.Unstacked): + if isinstance(s, sharding_specs.Unstacked): sharded_dim = i sharded_dim_size = s.size break diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py new file mode 100644 index 000000000..867c7c418 --- /dev/null +++ b/jax/_src/sharding_specs.py @@ -0,0 +1,346 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A ShardingSpec describes at a high level how a logical array is sharded across +# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also +# describe how to shard inputs to a parallel computation). spec_to_indices() +# encodes exactly how a given ShardingSpec is translated to device buffers, i.e. +# how the sharded array is "laid out" across devices. Given a sequence of +# devices, we shard the data across the devices in row-major order, with +# replication treated as an extra inner dimension. +# +# For example, given the logical data array [1, 2, 3, 4], if we were to +# partition this array 4 ways with a replication factor of 2, for a total of 8 +# devices, the data on each device would be: [1, 1], [2, 2], [3, 3], [4, 4]. +# +# This encoding is assumed by various parts of the system, e.g. generating +# replica groups for collective operations. 
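+#
+# As a rough illustration (a hypothetical spec, not part of this change), the
+# example above could be written as
+#   ShardingSpec(sharding=(Unstacked(4),),
+#                mesh_mapping=(ShardedAxis(0), Replicated(2)))
+# and spec_to_indices((4,), spec) would then assign element i of the array to
+# devices 2*i and 2*i + 1, in row-major device order.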
+ +import collections +import functools +import itertools +import math +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast + +import numpy as np + +from jax._src import op_shardings +from jax._src import util +from jax._src.lib import pmap_lib +from jax._src.lib import xla_client as xc + +unsafe_map, map = map, util.safe_map + +NoSharding = pmap_lib.NoSharding +Chunked = pmap_lib.Chunked +Unstacked = pmap_lib.Unstacked + +_UNSHARDED_INSTANCE = NoSharding() + +ShardedAxis = pmap_lib.ShardedAxis +Replicated = pmap_lib.Replicated +MeshDimAssignment = Union[ShardedAxis, Replicated] + +ShardingSpec = pmap_lib.ShardingSpec + +OpShardingType = Any + + + +def _sharding_spec_mesh_shape(self): + sharded_axis_sizes = [] + for sharding in self.sharding: + if isinstance(sharding, NoSharding): + continue + elif isinstance(sharding, Unstacked): + sharded_axis_sizes.append(sharding.size) + elif isinstance(sharding, Chunked): + sharded_axis_sizes.extend(sharding.chunks) + else: + util.assert_unreachable(sharding) + return tuple(sharded_axis_sizes[a.axis] if isinstance(a, ShardedAxis) + else a.replicas + for a in self.mesh_mapping) + + +def get_logical_mesh_ids(mesh_shape): + return np.arange(np.prod(mesh_shape)).reshape(mesh_shape) + + +_MeshAxisName = Any + +def sharding_spec_sharding_proto( + self, special_axes: Mapping[int, OpShardingType] = {}): + """Converts a ShardingSpec to an OpSharding proto. + + See + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601 + for details on the OpSharding proto. + Unfortunately the semantics are not very well described in the proto spec, but + the code here might help: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py + """ + mesh_shape = cast(Tuple[int, ...], self.mesh_shape) + mesh = get_logical_mesh_ids(self.mesh_shape) + + sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped + replicated_maxes = [] # lists mesh axis identifiers to replicate over + for maxis, assignment in enumerate(self.mesh_mapping): + if isinstance(assignment, Replicated): + replicated_maxes.append((maxis, assignment.replicas)) + elif isinstance(assignment, ShardedAxis): + sharded_axes[assignment.axis] = maxis + else: + util.assert_unreachable(assignment) + + proto = xc.OpSharding() + if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes: + proto.type = xc.OpSharding.Type.REPLICATED + return proto + else: + proto.type = xc.OpSharding.Type.OTHER + + mesh_permutation = [] + new_mesh_shape = [] + next_sharded_axis = 0 + for axis, sharding in enumerate(self.sharding): + if isinstance(sharding, NoSharding): + new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over + elif isinstance(sharding, Chunked): + for nchunks in sharding.chunks: + maxis = sharded_axes[next_sharded_axis] + assert mesh_shape[maxis] == nchunks + mesh_permutation.append(maxis) + next_sharded_axis += 1 + new_mesh_shape.append(int(np.prod(sharding.chunks))) + elif isinstance(sharding, Unstacked): + raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding") + else: + util.assert_unreachable(sharding) + + # Create a partial sharding proto if tensor is replicated or partitioned + # specially over some mesh axes. 
+ if replicated_maxes: + last_tile_dims = [] + axes_by_type: Dict[OpShardingType, List[_MeshAxisName]] = {} + size_by_type: Dict[OpShardingType, int] = collections.defaultdict(lambda: 1) + assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes.keys())) + for axis, size in replicated_maxes: + ty = special_axes.get(axis, xc.OpSharding.Type.REPLICATED) + axes_by_type.setdefault(ty, []).append(axis) + size_by_type[ty] *= size + for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value): + last_tile_dims.append(ty) + new_mesh_shape.append(size_by_type[ty]) + mesh_permutation.extend(axes) + proto.last_tile_dims = last_tile_dims + + proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape) + proto.tile_assignment_dimensions = list(proto_mesh.shape) + proto.tile_assignment_devices = list(proto_mesh.flat) + return proto + + +def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray: + """Returns NumPy-style indices corresponding to a sharding spec. + + Args: + shape: The shape of the logical array being sharded. + + Returns: + An ndarray with the same shape as the logical mesh (as derived form + `mesh_mapping`). Each entry is a NumPy-style index selecting the subset of + the data array to be placed on a corresponding device. The indices can be + ints, slice objects with step=1, or tuples of those. + """ + assert len(shape) == len(self.sharding), (shape, self.sharding) + + has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding) + # Take the op sharding indices generation route for pjit/xmap cases. + if not has_unstacked: + op_sharding_proto = sharding_spec_sharding_proto(self) + return op_shardings.op_sharding_to_numpy_indices( + op_sharding_proto, shape, math.prod(self.mesh_shape) + ).reshape(self.mesh_shape) + + axis_indices: List[Sequence[Index]] = [] + shard_indices_shape = [] + for dim, sharding in enumerate(self.sharding): + axis_size = shape[dim] + if isinstance(sharding, NoSharding): + axis_indices.append([slice(None)]) + # NOTE: We don't append unsharded dimensions to shard_indices_shape here, + # because they do not appear in the mesh mapping. + elif isinstance(sharding, Unstacked): + assert axis_size == sharding.size, f'{axis_size} != {sharding.size}' + axis_indices.append(range(axis_size)) + shard_indices_shape.append(axis_size) + elif isinstance(sharding, Chunked): + total_chunks = int(np.prod(sharding.chunks)) + shard_size, ragged = divmod(axis_size, total_chunks) + assert not ragged, (axis_size, total_chunks, dim) + axis_indices.append([slice(i * shard_size, (i + 1) * shard_size) + for i in range(total_chunks)]) + shard_indices_shape.extend(sharding.chunks) + else: + util.assert_unreachable(sharding) + + # shard_indices is an ndarray representing the sharded axes of the logical array, + # with each dimension having size equal to the number of shards across the corresponding + # logical array dimension, and each element containing the multi-dimensional index that + # is used to extract the corresponding shard of the logical array. 
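+  # For example (illustrative values only), for shape (4, 8) with sharding
+  # (Chunked([2]), Chunked([2])), shard_indices has shape (2, 2) and
+  # shard_indices[i, j] is (slice(2*i, 2*i + 2), slice(4*j, 4*j + 4)); see
+  # SpecToIndicesTest in tests/pmap_test.py for the expected flattened layouts.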
+ shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_) + for i, idxs in enumerate(itertools.product(*axis_indices)): + shard_indices[i] = idxs + shard_indices = shard_indices.reshape(shard_indices_shape) + + # Ensure that each sharded axis is used exactly once in the mesh mapping + num_sharded_dim = len(shard_indices_shape) + sharded_dim_perm = [a.axis for a in self.mesh_mapping if isinstance(a, ShardedAxis)] + assert (set(sharded_dim_perm) == set(range(num_sharded_dim)) and + len(sharded_dim_perm) == num_sharded_dim) + # Replicate/reorder the indices according to the mesh mapping + replica_sizes = tuple(a.replicas for a in self.mesh_mapping if isinstance(a, Replicated)) + replica_dim, sharded_dim = itertools.count(0), iter(sharded_dim_perm) + perm = [next(replica_dim) if isinstance(a, Replicated) else + len(replica_sizes) + next(sharded_dim) + for a in self.mesh_mapping] + return (np.broadcast_to(shard_indices, replica_sizes + shard_indices.shape) + .transpose(perm)) + +def _sharding_spec_repr(self): + return f'ShardingSpec({self.sharding}, {self.mesh_mapping})' + + +ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape) +ShardingSpec.sharding_proto = sharding_spec_sharding_proto +ShardingSpec.indices = _sharding_spec_indices +# mypy raises: error: Cannot assign to a method [assignment] +ShardingSpec.__repr__ = _sharding_spec_repr # type: ignore + + +Index = Union[int, slice, Tuple[Union[int, slice], ...]] + +def spec_to_indices(shape: Sequence[int], + spec: ShardingSpec) -> Tuple[Index, ...]: + """Returns numpy-style indices corresponding to a sharding spec. + + Each index describes a shard of the array. The order of the indices is the + same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out + row-major). + + Args: + shape: The shape of the logical array being sharded. + spec: Describes how the array is sharded and how the shards are assigned to + the logical mesh. + + Returns: + A tuple of length equal to the size of the mesh (inferred as the product of + sharded dimension sizes and all replication factors). Each element is an + int, a slice object with step=1, or a tuple thereof, to be treated as an + index into the full logical array. + """ + return tuple(spec.indices(shape).flat) # type: ignore + + +def partitioned_sharding_spec(num_partitions: int, + partitions: Optional[Sequence[int]], + shape: Sequence[int]) -> ShardingSpec: + if partitions is None: + maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),) + return ShardingSpec( + sharding=[_UNSHARDED_INSTANCE] * len(shape), + mesh_mapping=maybe_replicate) + else: + assert len(partitions) == len(shape) + return ShardingSpec( + # Chunked expects a list of integers + sharding=map(Chunked, [[x] for x in partitions]), + mesh_mapping=map(ShardedAxis, range(len(partitions)))) + + +def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes): + mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()] + sharding = [_UNSHARDED_INSTANCE] * num_dimensions + next_sharded_axis = 0 + # NOTE: sorted is stable, which is important when multiple resources + # map to the same axis. 
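+  # For instance (illustrative values only), with axis_sizes={'x': 2, 'y': 3},
+  # mesh_axis_pos={'x': 0, 'y': 1} and aval_axes={'x': 0, 'y': 0}, array
+  # dimension 0 becomes Chunked([2, 3]) and the mesh mapping becomes
+  # (ShardedAxis(0), ShardedAxis(1)).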
+  for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
+    chunked = sharding[axis]
+    if isinstance(chunked, NoSharding):
+      chunked = Chunked([])
+    sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
+    assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
+      "Value mapped to the same mesh axis twice"
+    mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
+    next_sharded_axis += 1
+  return ShardingSpec(sharding, mesh_mapping)
+
+
+def new_mesh_sharding_specs(axis_sizes, axis_names):
+  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
+  return functools.partial(make_sharding_spec, axis_sizes, mesh_axis_pos)
+
+def pmap_sharding_spec(nrep, axis_size, npart, parts,
+                       sharded_shape: Sequence[int],
+                       map_axis: Optional[int]) -> ShardingSpec:
+  """Sharding spec for arguments or results of a pmap.
+  Args:
+    nrep: number of local XLA replicas (product of local axis sizes)
+    axis_size: local axis size for outer pmap
+    npart: total number of XLA partitions (required by sharded_jit calls)
+    parts: the partitioning of the value or None
+    sharded_shape: the shape of the value inside the outer pmap
+    map_axis: the axis along which the value is mapped in the outer pmap
+  Returns:
+    A ShardingSpec.
+  """
+  replication_factor, ragged = divmod(nrep, axis_size)
+  assert not ragged
+  # get the sharding spec from inner sharded_jits as if we weren't in a pmap
+  pspec = partitioned_sharding_spec(npart, parts, sharded_shape)
+  maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
+  if map_axis is not None:
+    sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
+    def shift_sharded_axis(a: MeshDimAssignment):
+      if isinstance(a, ShardedAxis) and a.axis >= sharded_in_axis:
+        return ShardedAxis(a.axis + 1)
+      return a
+    # replication_factor represents the product of inner pmaps, so it goes
+    # after the outer pmapped axis at index 0
+    return ShardingSpec(
+        sharding=util.tuple_insert(
+            pspec.sharding, map_axis, Unstacked(axis_size)),
+        mesh_mapping=itertools.chain(
+            [ShardedAxis(sharded_in_axis)], maybe_replicate,
+            map(shift_sharded_axis, pspec.mesh_mapping)))
+  else:
+    return ShardingSpec(
+        sharding=pspec.sharding,
+        mesh_mapping=(Replicated(axis_size),) + maybe_replicate + pspec.mesh_mapping)
+
+
+def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
+                              sharded_dim_size: Optional[int] = None):
+  if sharded_dim is not None:
+    sharded_shape = shape[:sharded_dim] + shape[sharded_dim+1:]
+    if sharded_dim_size is None:
+      sharded_dim_size = shape[sharded_dim]
+  else:
+    assert sharded_dim_size is not None
+    sharded_shape = shape
+
+  return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
+                            sharded_shape, sharded_dim)
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index f70889c5b..b0d215720 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -17,7 +17,6 @@ from jax._src.interpreters.pxla import (
   ArrayMapping as ArrayMapping,
   ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
   AvalDimSharding as AvalDimSharding,
-  Chunked as Chunked,
   EmapInfo as EmapInfo,
   ExecuteReplicated as ExecuteReplicated,
   Index as Index,
@@ -27,8 +26,6 @@ from jax._src.interpreters.pxla import (
   MeshComputation as MeshComputation,
   MeshDimAssignment as MeshDimAssignment,
   MeshExecutable as MeshExecutable,
-  NoSharding as NoSharding,
-  OpShardingType as OpShardingType,
   OrderedDictType as
OrderedDictType, ParallelCallableInfo as ParallelCallableInfo, PartitionInfo as PartitionInfo, @@ -37,18 +34,14 @@ from jax._src.interpreters.pxla import ( PmapExecutable as PmapExecutable, PxlaResultHandler as PxlaResultHandler, ReplicaInfo as ReplicaInfo, - Replicated as Replicated, ResultsHandler as ResultsHandler, SPMDBatchTrace as SPMDBatchTrace, ShardInfo as ShardInfo, - ShardedAxis as ShardedAxis, - ShardingSpec as ShardingSpec, TileManual as TileManual, TileVectorize as TileVectorize, TilingMethod as TilingMethod, UnloadedMeshExecutable as UnloadedMeshExecutable, UnloadedPmapExecutable as UnloadedPmapExecutable, - Unstacked as Unstacked, WeakRefList as WeakRefList, _UNSPECIFIED as _UNSPECIFIED, _create_pmap_sharding_spec as _create_pmap_sharding_spec, @@ -76,9 +69,7 @@ from jax._src.interpreters.pxla import ( maybe_extend_axis_env as maybe_extend_axis_env, mesh_sharding_specs as mesh_sharding_specs, multi_host_supported_collectives as multi_host_supported_collectives, - new_mesh_sharding_specs as new_mesh_sharding_specs, parallel_callable as parallel_callable, - partitioned_sharding_spec as partitioned_sharding_spec, reconcile_num_partitions as reconcile_num_partitions, replicate as replicate, resource_typecheck as resource_typecheck, @@ -88,8 +79,6 @@ from jax._src.interpreters.pxla import ( shard_aval as shard_aval, shard_aval_handlers as shard_aval_handlers, shard_to_full_p as shard_to_full_p, - sharding_spec_sharding_proto as sharding_spec_sharding_proto, - spec_to_indices as spec_to_indices, spmd_primitive_batchers as spmd_primitive_batchers, stage_parallel_callable as stage_parallel_callable, tile_aval_nd as tile_aval_nd, @@ -114,6 +103,19 @@ from jax._src.op_shardings import ( op_sharding_to_indices as op_sharding_to_indices, ) +from jax._src.sharding_specs import ( + Chunked as Chunked, + NoSharding as NoSharding, + OpShardingType as OpShardingType, + Replicated as Replicated, + ShardedAxis as ShardedAxis, + ShardingSpec as ShardingSpec, + Unstacked as Unstacked, + new_mesh_sharding_specs as new_mesh_sharding_specs, + sharding_spec_sharding_proto as sharding_spec_sharding_proto, + spec_to_indices as spec_to_indices, +) + # Deprecations from jax._src.mesh import Mesh as _deprecated_Mesh diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 7ad4a40d9..725f8311b 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -40,10 +40,10 @@ from jax._src.lax import parallel from jax._src import api as src_api from jax import random from jax._src import core -from jax._src.core import ShapedArray from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr, linearize, device_put) from jax._src import config as jax_config +from jax._src import sharding_specs from jax._src import xla_bridge from jax._src.lib import xla_extension from jax._src.lib import xla_extension_version @@ -118,13 +118,11 @@ ignore_xmap_warning = partial( def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None, devices=None, sharded_dim_size=None): - dtype = np.int32 - aval = ShapedArray(input_shape, dtype) - if input_data is None: input_data = np.arange(math.prod(input_shape)).reshape(input_shape) - sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size) + sharding_spec = sharding_specs.create_pmap_sharding_spec( + input_shape, in_axes, sharded_dim_size) if devices is None: devices = jax.devices() @@ -2829,7 +2827,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (4, 8) spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])), 
mesh_mapping=map(pxla.ShardedAxis, (0, 1))) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((slice(0,2), slice(0,4)), (slice(0,2), slice(4,8)), (slice(2,4), slice(0,4)), @@ -2839,7 +2837,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (4, 8) spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])), mesh_mapping=map(pxla.ShardedAxis, (1, 0))) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((slice(0,2), slice(0,4)), (slice(2,4), slice(0,4)), (slice(0,2), slice(4,8)), @@ -2851,7 +2849,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): mesh_mapping=(pxla.Replicated(2), pxla.ShardedAxis(1), pxla.ShardedAxis(0))) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((slice(0,2), slice(0,4)), (slice(2,4), slice(0,4)), (slice(0,2), slice(4,8)), @@ -2861,7 +2859,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (4, 8) spec = pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()), mesh_mapping=(pxla.ShardedAxis(0),)) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((slice(0,2), slice(None)), (slice(2,4), slice(None)))) @@ -2869,14 +2867,14 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (4, 8) spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()), mesh_mapping=()) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((slice(None), slice(None)),)) def testUnmaterializedAxis(self): shape = (4, 8) spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(4), pxla.NoSharding()), mesh_mapping=(pxla.ShardedAxis(0),)) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((0, slice(None)), (1, slice(None)), (2, slice(None)), @@ -2885,7 +2883,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (2, 2) spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.Unstacked(2)), mesh_mapping=(pxla.ShardedAxis(0),)) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((slice(None), 0), (slice(None), 1))) @@ -2893,14 +2891,14 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (2, 8) spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()), mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3))) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), tuple([(0, slice(None))] * 3 + [(1, slice(None))] * 3)) def testReplicationPosition2(self): shape = (2, 8) spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])), mesh_mapping=(pxla.ShardedAxis(0), pxla.ShardedAxis(1), pxla.Replicated(3))) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((0, slice(0, 4)), (0, slice(0, 4)), (0, slice(0, 4)), (0, slice(4, 8)), (0, slice(4, 8)), (0, slice(4, 8)), (1, slice(0, 4)), (1, slice(0, 4)), (1, slice(0, 4)), @@ -2910,7 +2908,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (2, 8) spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])), mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3), pxla.ShardedAxis(1))) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((0, 
slice(0, 4)), (0, slice(4, 8)), (0, slice(0, 4)), (0, slice(4, 8)), (0, slice(0, 4)), (0, slice(4, 8)), @@ -2922,7 +2920,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = (2, 8) spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()), mesh_mapping=(pxla.Replicated(3), pxla.ShardedAxis(0))) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), tuple([(0, slice(None)), (1, slice(None))] * 3)) def testMultipleReplications(self): @@ -2933,7 +2931,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): pxla.ShardedAxis(0), pxla.Replicated(2), pxla.ShardedAxis(1))) self.assertEqual( - pxla.spec_to_indices(shape, spec), + sharding_specs.spec_to_indices(shape, spec), ((0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)), (0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)), (1, slice(None), slice(0, 2)), (1, slice(None), slice(2, 4)), @@ -2943,7 +2941,7 @@ class SpecToIndicesTest(jtu.JaxTestCase): shape = () spec = pxla.ShardingSpec(sharding=(), mesh_mapping=(pxla.Replicated(3),)) - self.assertEqual(pxla.spec_to_indices(shape, spec), + self.assertEqual(sharding_specs.spec_to_indices(shape, spec), ((), (), ())) @@ -3003,7 +3001,7 @@ class ShardArgsTest(jtu.JaxTestCase): mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))], ]) def testShardArgs(self, shape, spec, make_arg): - indices = pxla.spec_to_indices(shape, spec) + indices = sharding_specs.spec_to_indices(shape, spec) nshards = len(indices) if jax.device_count() < nshards: raise SkipTest @@ -3014,7 +3012,8 @@ class ShardArgsTest(jtu.JaxTestCase): sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) else: sharding = jax.sharding.GSPMDSharding( - jax.devices()[:nshards], pxla.sharding_spec_sharding_proto(spec)) + jax.devices()[:nshards], + sharding_specs.sharding_spec_sharding_proto(spec)) results = pxla.shard_args( jax.devices()[:nshards], [indices], [sharding], [arg]