Delete sharding spec to HloSharding conversion since it's not used anymore.

PiperOrigin-RevId: 595192496
Yash Katariya 2024-01-02 13:12:44 -08:00 committed by jax authors
parent fff5ea579a
commit c0d4653fc9
4 changed files with 14 additions and 117 deletions


@@ -2475,6 +2475,10 @@ def _get_layouts_from_executable(
   return new_in_layouts, new_out_layouts  # type: ignore
 
 
+def get_logical_mesh_ids(mesh_shape):
+  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
+
+
 @weakref_lru_cache
 def _cached_compilation(computation, name, mesh, spmd_lowering,
                         tuple_args, auto_spmd_lowering,
@@ -2528,7 +2532,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
     assert mesh is not None
     opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
     opts.auto_spmd_partitioning_mesh_ids = (
-        sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
+        get_logical_mesh_ids(list(mesh.shape.values()))
         .reshape(-1))
   compile_options.parameter_is_tupled_arguments = tuple_args
   opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)

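For context, `get_logical_mesh_ids` becomes a local helper in this file instead of an import from `sharding_specs`. A minimal standalone sketch of what it computes (plain NumPy, independent of JAX internals; the example values are illustrative):

import math
import numpy as np

def get_logical_mesh_ids(mesh_shape):
  # Device ids 0..N-1 laid out in row-major order over the logical mesh.
  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)

ids = get_logical_mesh_ids([2, 4])  # [[0 1 2 3] [4 5 6 7]]
flat = ids.reshape(-1)              # [0 1 2 3 4 5 6 7]; this flattened form is
                                    # what feeds auto_spmd_partitioning_mesh_ids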

@@ -29,18 +29,15 @@
 from __future__ import annotations
 
-import collections
-from collections.abc import Mapping, Sequence
+from collections.abc import Sequence
 import itertools
 import math
-from typing import Any, Union, cast
+from typing import Union
 
 import numpy as np
 
-from jax._src import op_shardings
 from jax._src import util
 from jax._src.lib import pmap_lib
-from jax._src.lib import xla_client as xc
 
 unsafe_map, map = map, util.safe_map
@@ -56,9 +53,6 @@ MeshDimAssignment = Union[ShardedAxis, Replicated]
 
 ShardingSpec = pmap_lib.ShardingSpec
 
-OpShardingType = Any
-
-
 def _sharding_spec_mesh_shape(self):
   sharded_axis_sizes = []
@@ -76,79 +70,6 @@ def _sharding_spec_mesh_shape(self):
                for a in self.mesh_mapping)
 
 
-def get_logical_mesh_ids(mesh_shape):
-  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
-
-
-_MeshAxisName = Any
-
-def sharding_spec_sharding_proto(
-    self, special_axes: Mapping[int, OpShardingType] | None = None
-) -> xc.HloSharding:
-  """Converts a ShardingSpec to an OpSharding proto.
-
-  See
-  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
-  for details on the OpSharding proto.
-  Unfortunately the semantics are not very well described in the proto spec, but
-  the code here might help:
-  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
-  """
-  special_axes_dict = {} if special_axes is None else special_axes
-  mesh_shape = cast(tuple[int, ...], self.mesh_shape)
-
-  sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
-  replicated_maxes = []  # lists mesh axis identifiers to replicate over
-  for maxis, assignment in enumerate(self.mesh_mapping):
-    if isinstance(assignment, Replicated):
-      replicated_maxes.append((maxis, assignment.replicas))
-    elif isinstance(assignment, ShardedAxis):
-      sharded_axes[assignment.axis] = maxis
-    else:
-      util.assert_unreachable(assignment)
-
-  if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes_dict:
-    return xc.HloSharding.replicate()
-
-  mesh_permutation = []
-  new_mesh_shape = []
-  next_sharded_axis = 0
-  for axis, sharding in enumerate(self.sharding):
-    if isinstance(sharding, NoSharding):
-      new_mesh_shape.append(1)  # Add a dummy mesh axis we won't be sharding over
-    elif isinstance(sharding, Chunked):
-      for nchunks in sharding.chunks:
-        maxis = sharded_axes[next_sharded_axis]
-        assert mesh_shape[maxis] == nchunks
-        mesh_permutation.append(maxis)
-        next_sharded_axis += 1
-      new_mesh_shape.append(math.prod(sharding.chunks))
-    elif isinstance(sharding, Unstacked):
-      raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
-    else:
-      util.assert_unreachable(sharding)
-
-  # Create a partial sharding proto if tensor is replicated or partitioned
-  # specially over some mesh axes.
-  last_tile_dims = []
-  if replicated_maxes:
-    axes_by_type: dict[OpShardingType, list[_MeshAxisName]] = {}
-    size_by_type: dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
-    assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes_dict.keys()))
-    for axis, size in replicated_maxes:
-      ty = special_axes_dict.get(axis, xc.OpSharding.Type.REPLICATED)
-      axes_by_type.setdefault(ty, []).append(axis)
-      size_by_type[ty] *= size
-    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
-      last_tile_dims.append(ty)
-      new_mesh_shape.append(size_by_type[ty])
-      mesh_permutation.extend(axes)
-
-  return xc.HloSharding.iota_tile(
-      dims=new_mesh_shape, reshape_dims=mesh_shape,
-      transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
-
-
 def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   """Returns NumPy-style indices corresponding to a sharding spec.
@@ -163,14 +84,6 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   """
   assert len(shape) == len(self.sharding), (shape, self.sharding)
 
-  has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
-  # Take the op sharding indices generation route for pjit/xmap cases.
-  if not has_unstacked:
-    hlo_sharding = sharding_spec_sharding_proto(self)
-    return op_shardings.op_sharding_to_numpy_indices(
-        hlo_sharding, shape, math.prod(self.mesh_shape)
-    ).reshape(self.mesh_shape)
-
   axis_indices: list[Sequence[Index]] = []
   shard_indices_shape = []
   for dim, sharding in enumerate(self.sharding):
@@ -221,7 +134,6 @@ def _sharding_spec_repr(self):
 
 ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
-ShardingSpec.sharding_proto = sharding_spec_sharding_proto
 ShardingSpec.indices = _sharding_spec_indices
 # mypy raises: error: Cannot assign to a method [assignment]
 ShardingSpec.__repr__ = _sharding_spec_repr  # type: ignore

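For readers who relied on the deleted converter: its end result can still be built directly with `xc.HloSharding.iota_tile`, the call the function itself bottomed out in. A hedged sketch for one concrete case — a spec equivalent to `(Chunked([2]), NoSharding())` on a 2-device mesh — where the argument values mirror what the deleted loop would have computed; this is an illustration, not an official replacement API:

from jax._src.lib import xla_client as xc

# For sharding (Chunked([2]), NoSharding()) with mesh_mapping (ShardedAxis(0),),
# the deleted code would have derived:
#   new_mesh_shape   -> [2, 1]  (2 chunks on dim 0, dummy axis for dim 1)
#   mesh_shape       -> [2]
#   mesh_permutation -> [0]
#   last_tile_dims   -> []      (nothing replicated)
hlo_sharding = xc.HloSharding.iota_tile(
    dims=[2, 1], reshape_dims=[2], transpose_perm=[0], subgroup_types=[])
# hlo_sharding now tiles dim 0 across devices 0 and 1.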

@@ -164,12 +164,13 @@ class PickleTest(jtu.JaxTestCase):
     self.assertEqual(pickle.loads(pickle.dumps(sharding)), sharding)
 
   def testPickleOpSharding(self):
-    sharding = pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
-                                 (pxla.ShardedAxis(0), pxla.ShardedAxis(1)))
-    op_sharding = sharding.sharding_proto().to_proto()
+    op = xc.OpSharding()
+    op.type = xc.OpSharding.Type.OTHER
+    op.tile_assignment_dimensions = [4, 2]
+    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     self.assertTrue(
-        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op_sharding))),
-        xc.HloSharding.from_proto(op_sharding))
+        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op))),
+        xc.HloSharding.from_proto(op))
 
   def test_pickle_single_device_sharding(self):
     s = jax.sharding.SingleDeviceSharding(jax.devices()[0])

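A minimal standalone sketch of the round-trip the updated test exercises, using the same `xc.OpSharding` fields as above; the direct `==` comparison assumes `HloSharding` defines equality, which jaxlib's `xla_client` provides:

import pickle
from jax._src.lib import xla_client as xc

op = xc.OpSharding()
op.type = xc.OpSharding.Type.OTHER
op.tile_assignment_dimensions = [4, 2]
op.tile_assignment_devices = list(range(8))

# Pickle the proto, restore it, and check both sides parse to the same sharding.
restored = pickle.loads(pickle.dumps(op))
assert xc.HloSharding.from_proto(restored) == xc.HloSharding.from_proto(op)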

@@ -2988,25 +2988,12 @@ class ShardArgsTest(jtu.JaxTestCase):
       # unsharded
       [(4, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                  mesh_mapping=())],
-      # partitioned, 1 axis
-      [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
-                                 mesh_mapping=(pxla.ShardedAxis(0),))],
-      # partitioned, 2 axes
-      [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
-                                 mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
-      # partitioned, 2 axes, permuted
-      [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
-                                 mesh_mapping=map(pxla.ShardedAxis, (1, 0)))],
       # replication + sharding
       [(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                                  mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))],
       # replication, no sharding
       [(2, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                  mesh_mapping=(pxla.Replicated(3),))],
-      # multiple replicated axes
-      [(1, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([1]), pxla.Chunked([2])),
-                                 mesh_mapping=(pxla.Replicated(2), pxla.ShardedAxis(0),
-                                               pxla.Replicated(2), pxla.ShardedAxis(1)))],
       # replicated scalar
       [(), pxla.ShardingSpec(sharding=(),
                              mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
@@ -3018,14 +3005,7 @@ class ShardArgsTest(jtu.JaxTestCase):
       raise SkipTest
     x = np.arange(math.prod(shape)).reshape(shape)
     arg = make_arg(x)
-    sharding = None
-    if any(isinstance(s, pxla.Unstacked) for s in spec.sharding):
-      sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
-    else:
-      sharding = jax.sharding.GSPMDSharding(
-          jax.devices()[:nshards],
-          sharding_specs.sharding_spec_sharding_proto(spec))
+    sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
     results = pxla.shard_args(
        jax.devices()[:nshards], [indices], [sharding], [arg]
     )