Delete sharding spec to HloSharding conversion since it's not used anymore.

PiperOrigin-RevId: 595192496

parent fff5ea579a
commit c0d4653fc9
@@ -2475,6 +2475,10 @@ def _get_layouts_from_executable(
   return new_in_layouts, new_out_layouts  # type: ignore
 
 
+def get_logical_mesh_ids(mesh_shape):
+  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
+
+
 @weakref_lru_cache
 def _cached_compilation(computation, name, mesh, spmd_lowering,
                         tuple_args, auto_spmd_lowering,
@@ -2528,7 +2532,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
     assert mesh is not None
     opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
     opts.auto_spmd_partitioning_mesh_ids = (
-        sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
+        get_logical_mesh_ids(list(mesh.shape.values()))
         .reshape(-1))
   compile_options.parameter_is_tupled_arguments = tuple_args
   opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
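For reference, the helper moved in the first hunk above is a pure NumPy one-liner. A minimal standalone sketch of its behavior (the (2, 3) mesh shape below is an arbitrary example, not taken from the commit):

import math
import numpy as np

def get_logical_mesh_ids(mesh_shape):
  # Lay out logical device ids 0..N-1 over the mesh positions, row-major.
  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)

ids = get_logical_mesh_ids((2, 3))
print(ids)              # [[0 1 2]
                        #  [3 4 5]]
print(ids.reshape(-1))  # [0 1 2 3 4 5] -- the flattened form assigned to
                        # opts.auto_spmd_partitioning_mesh_ids above.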
@@ -29,18 +29,15 @@
 
 from __future__ import annotations
 
-import collections
-from collections.abc import Mapping, Sequence
+from collections.abc import Sequence
 import itertools
 import math
-from typing import Any, Union, cast
+from typing import Union
 
 import numpy as np
 
-from jax._src import op_shardings
 from jax._src import util
 from jax._src.lib import pmap_lib
-from jax._src.lib import xla_client as xc
 
 unsafe_map, map = map, util.safe_map
 
@@ -56,9 +53,6 @@ MeshDimAssignment = Union[ShardedAxis, Replicated]
 
 ShardingSpec = pmap_lib.ShardingSpec
 
-OpShardingType = Any
-
-
 
 def _sharding_spec_mesh_shape(self):
   sharded_axis_sizes = []
@@ -76,79 +70,6 @@ def _sharding_spec_mesh_shape(self):
               for a in self.mesh_mapping)
 
 
-def get_logical_mesh_ids(mesh_shape):
-  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
-
-
-_MeshAxisName = Any
-
-def sharding_spec_sharding_proto(
-    self, special_axes: Mapping[int, OpShardingType] | None = None
-) -> xc.HloSharding:
-  """Converts a ShardingSpec to an OpSharding proto.
-
-  See
-  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
-  for details on the OpSharding proto.
-  Unfortunately the semantics are not very well described in the proto spec, but
-  the code here might help:
-  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
-  """
-  special_axes_dict = {} if special_axes is None else special_axes
-  mesh_shape = cast(tuple[int, ...], self.mesh_shape)
-
-  sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
-  replicated_maxes = []  # lists mesh axis identifiers to replicate over
-  for maxis, assignment in enumerate(self.mesh_mapping):
-    if isinstance(assignment, Replicated):
-      replicated_maxes.append((maxis, assignment.replicas))
-    elif isinstance(assignment, ShardedAxis):
-      sharded_axes[assignment.axis] = maxis
-    else:
-      util.assert_unreachable(assignment)
-
-  if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes_dict:
-    return xc.HloSharding.replicate()
-
-  mesh_permutation = []
-  new_mesh_shape = []
-  next_sharded_axis = 0
-  for axis, sharding in enumerate(self.sharding):
-    if isinstance(sharding, NoSharding):
-      new_mesh_shape.append(1)  # Add a dummy mesh axis we won't be sharding over
-    elif isinstance(sharding, Chunked):
-      for nchunks in sharding.chunks:
-        maxis = sharded_axes[next_sharded_axis]
-        assert mesh_shape[maxis] == nchunks
-        mesh_permutation.append(maxis)
-        next_sharded_axis += 1
-      new_mesh_shape.append(math.prod(sharding.chunks))
-    elif isinstance(sharding, Unstacked):
-      raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
-    else:
-      util.assert_unreachable(sharding)
-
-  # Create a partial sharding proto if tensor is replicated or partitioned
-  # specially over some mesh axes.
-  last_tile_dims = []
-  if replicated_maxes:
-    axes_by_type: dict[OpShardingType, list[_MeshAxisName]] = {}
-    size_by_type: dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
-    assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes_dict.keys()))
-    for axis, size in replicated_maxes:
-      ty = special_axes_dict.get(axis, xc.OpSharding.Type.REPLICATED)
-      axes_by_type.setdefault(ty, []).append(axis)
-      size_by_type[ty] *= size
-    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
-      last_tile_dims.append(ty)
-      new_mesh_shape.append(size_by_type[ty])
-      mesh_permutation.extend(axes)
-
-  return xc.HloSharding.iota_tile(
-      dims=new_mesh_shape, reshape_dims=mesh_shape,
-      transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
-
-
 def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   """Returns NumPy-style indices corresponding to a sharding spec.
 
@@ -163,14 +84,6 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   """
   assert len(shape) == len(self.sharding), (shape, self.sharding)
 
-  has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
-  # Take the op sharding indices generation route for pjit/xmap cases.
-  if not has_unstacked:
-    hlo_sharding = sharding_spec_sharding_proto(self)
-    return op_shardings.op_sharding_to_numpy_indices(
-        hlo_sharding, shape, math.prod(self.mesh_shape)
-    ).reshape(self.mesh_shape)
-
   axis_indices: list[Sequence[Index]] = []
   shard_indices_shape = []
   for dim, sharding in enumerate(self.sharding):
@@ -221,7 +134,6 @@ def _sharding_spec_repr(self):
 
 
 ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
-ShardingSpec.sharding_proto = sharding_spec_sharding_proto
 ShardingSpec.indices = _sharding_spec_indices
 # mypy raises: error: Cannot assign to a method [assignment]
 ShardingSpec.__repr__ = _sharding_spec_repr  # type: ignore
@@ -164,12 +164,13 @@ class PickleTest(jtu.JaxTestCase):
     self.assertEqual(pickle.loads(pickle.dumps(sharding)), sharding)
 
   def testPickleOpSharding(self):
-    sharding = pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
-                                 (pxla.ShardedAxis(0), pxla.ShardedAxis(1)))
-    op_sharding = sharding.sharding_proto().to_proto()
+    op = xc.OpSharding()
+    op.type = xc.OpSharding.Type.OTHER
+    op.tile_assignment_dimensions = [4, 2]
+    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     self.assertTrue(
-        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op_sharding))),
-        xc.HloSharding.from_proto(op_sharding))
+        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op))),
+        xc.HloSharding.from_proto(op))
 
   def test_pickle_single_device_sharding(self):
     s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
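The updated test builds the OpSharding proto by hand instead of going through the deleted ShardingSpec.sharding_proto(). A standalone sketch of the same round trip, assuming xla_client exposes OpSharding/HloSharding as in the test and that HloSharding supports equality comparison:

import pickle
from jax._src.lib import xla_client as xc

# A 4x2 tiling of a rank-2 array over devices 0..7, row-major.
op = xc.OpSharding()
op.type = xc.OpSharding.Type.OTHER
op.tile_assignment_dimensions = [4, 2]
op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]

# The proto pickles cleanly, and both copies convert to the same HloSharding.
restored = pickle.loads(pickle.dumps(op))
assert xc.HloSharding.from_proto(restored) == xc.HloSharding.from_proto(op)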
@@ -2988,25 +2988,12 @@ class ShardArgsTest(jtu.JaxTestCase):
           # unsharded
           [(4, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                      mesh_mapping=())],
-          # partitioned, 1 axis
-          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
-                                     mesh_mapping=(pxla.ShardedAxis(0),))],
-          # partitioned, 2 axes
-          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
-                                     mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
-          # partitioned, 2 axes, permuted
-          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
-                                     mesh_mapping=map(pxla.ShardedAxis, (1, 0)))],
           # replication + sharding
           [(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                                      mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))],
           # replication, no sharding
           [(2, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                      mesh_mapping=(pxla.Replicated(3),))],
-          # multiple replicated axes
-          [(1, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([1]), pxla.Chunked([2])),
-                                     mesh_mapping=(pxla.Replicated(2), pxla.ShardedAxis(0),
-                                                   pxla.Replicated(2), pxla.ShardedAxis(1)))],
           # replicated scalar
           [(), pxla.ShardingSpec(sharding=(),
                                  mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
@@ -3018,14 +3005,7 @@
       raise SkipTest
     x = np.arange(math.prod(shape)).reshape(shape)
     arg = make_arg(x)
-    sharding = None
-    if any(isinstance(s, pxla.Unstacked) for s in spec.sharding):
-      sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
-    else:
-      sharding = jax.sharding.GSPMDSharding(
-          jax.devices()[:nshards],
-          sharding_specs.sharding_spec_sharding_proto(spec))
-
+    sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
     results = pxla.shard_args(
         jax.devices()[:nshards], [indices], [sharding], [arg]
     )
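With the GSPMDSharding branch gone, the test builds a PmapSharding for every remaining case. A minimal sketch of that path, under the assumptions that pxla is the same module the test imports (jax._src.interpreters.pxla at this point in history) and that at least 6 devices are visible (on CPU, e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=8):

import jax
from jax._src.interpreters import pxla  # assumption: matches the test's pxla import

# The "replication + sharding" case kept in the parameter list above:
# axis 0 unstacked into 2 shards, then replicated 3x => 2 * 3 = 6 devices.
spec = pxla.ShardingSpec(
    sharding=(pxla.Unstacked(2), pxla.NoSharding()),
    mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))
sharding = jax.sharding.PmapSharding(jax.devices()[:6], spec)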