Delete mesh_sharding_specs from JAX

PiperOrigin-RevId: 595164505
commit 5192fca09b
parent 697f17adf1
Author: Yash Katariya
Date: 2024-01-02 11:13:57 -08:00
Committed by: jax authors
5 changed files with 19 additions and 66 deletions


@@ -69,7 +69,7 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified,
     AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
-    is_unspecified, is_unspecified_or_auto
+    is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources
 )
 from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
@@ -1510,12 +1510,15 @@ def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
   # TODO: Can we short-circuit for replicated values? Probably not.
   aval_in, = ctx.avals_in
   aval_out, = ctx.avals_out
-  sharding_proto = mesh_sharding_specs(
-      mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto().to_proto()
+  sharding_proto = (
+      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      ._to_xla_hlo_sharding(aval_in.ndim).to_proto())
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto, unspecified_dims=unspecified_dims)
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto,
+                                  unspecified_dims=unspecified_dims)
   proto = manual_proto(aval_in, manual_axes, mesh)
-  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto, unspecified_dims=unspecified_dims),
+  return (mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto,
+                                          unspecified_dims=unspecified_dims),)
 
 shard_to_full_p = core.Primitive('shard_to_full')
@@ -1531,10 +1534,13 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping,
   aval_out, = ctx.avals_out
   proto = manual_proto(aval_in, manual_axes, mesh)  # type: ignore
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())  # type: ignore
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto, unspecified_dims=unspecified_dims)
-  sharding_proto = mesh_sharding_specs(
-      mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto().to_proto()
-  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto,
+                                  unspecified_dims=unspecified_dims)
+  sharding_proto = (
+      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      ._to_xla_hlo_sharding(aval_out.ndim).to_proto())
+  return (mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto,
+                                          unspecified_dims),)
 
 @lu.transformation
 def vtile_manual(manual_axes: frozenset[sharding_impls.MeshAxisName],
@@ -3068,30 +3074,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
     _check_aval(v.aval, what_thunk)
 
-
-def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
-  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
-  # NOTE: This takes in the non-sharded avals!
-  def mk_sharding_spec(aval, aval_axes):
-    if aval is core.abstract_token:
-      assert not aval_axes
-      return ShardingSpec([], [Replicated(axis_size) for axis_size in axis_sizes.values()])
-    aval_shape = list(aval.shape)
-    # NOTE: sorted is stable, which is important when multiple resources
-    # map to the same axis.
-    for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
-      if not allow_uneven_axes:
-        if aval_shape[axis] % axis_sizes[name] != 0:
-          raise ValueError(
-              f'The aval shape on dimension {axis} is {aval_shape[axis]} and '
-              f'the size of axis {name} is {axis_sizes[name]}. The aval shape % '
-              'axis size should be zero but got '
-              f'{aval_shape[axis] % axis_sizes[name]}')
-      aval_shape[axis] //= axis_sizes[name]
-    return sharding_specs.make_sharding_spec(
-        axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
-  return mk_sharding_spec
-
 @contextmanager
 def maybe_extend_axis_env(*args, **kwargs):
   with core.extend_axis_env(*args, **kwargs):
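
A minimal sketch, not part of the diff, of the replacement pattern this commit applies: the ArrayMapping that mesh_sharding_specs used to consume, e.g. {'x': 0, 'y': 1}, corresponds to PartitionSpec('x', 'y'), and the same placement expressed as a NamedSharding lowers to the HLO sharding proto via the private _to_xla_hlo_sharding helper. The 4x2 mesh is an assumed setup and needs eight available devices.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed setup for illustration: a 4x2 mesh over eight devices.
mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))

# ArrayMapping {'x': 0, 'y': 1} is PartitionSpec('x', 'y'): array dimension 0
# is sharded over mesh axis 'x', dimension 1 over mesh axis 'y'.
sharding = NamedSharding(mesh, P('x', 'y'))

# _to_xla_hlo_sharding is private API; it yields the HLO sharding that the
# lowering rules above serialize with .to_proto().
print(sharding._to_xla_hlo_sharding(2).to_proto())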


@@ -1417,11 +1417,11 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
   global_in_avals = ctx.avals_in
   vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
 
-  global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
   sharded_global_in_nodes = [
     [mlir.wrap_with_sharding_op(
         ctx, node, aval,
-        global_sharding_spec(aval, aval_axes).sharding_proto().to_proto())]
+        NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
+                      )._to_xla_hlo_sharding(aval.ndim).to_proto())]
     if aval_axes else [node]
     for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
   ]
@@ -1441,7 +1441,8 @@
   sharded_global_out_nodes = [
     mlir.wrap_with_sharding_op(
         ctx, node, aval,
-        global_sharding_spec(aval, aval_axes).sharding_proto().to_proto())
+        NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
+                      )._to_xla_hlo_sharding(aval.ndim).to_proto())
     if aval_axes else node
     for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
   ]
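
A hedged paraphrase of the per-value logic in the two comprehensions above, with a hypothetical helper name: values whose ArrayMapping is empty are fully replicated and pass through with no sharding annotation; the rest get a NamedSharding-derived proto.

from jax.sharding import NamedSharding
from jax._src.sharding_impls import array_mapping_to_axis_resources

def _sharding_protos(mesh, avals, mesh_axes):
  # Hypothetical helper mirroring the comprehensions above: a None entry
  # means the value is replicated and gets no wrap_with_sharding_op call.
  protos = []
  for aval, aval_axes in zip(avals, mesh_axes):
    if not aval_axes:
      protos.append(None)
    else:
      s = NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes))
      protos.append(s._to_xla_hlo_sharding(aval.ndim).to_proto())
  return protos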


@@ -251,24 +251,6 @@ def spec_to_indices(shape: Sequence[int],
   return tuple(spec.indices(shape).flat)  # type: ignore
 
-
-def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
-  mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
-  sharding = [_UNSHARDED_INSTANCE] * num_dimensions
-  next_sharded_axis = 0
-  # NOTE: sorted is stable, which is important when multiple resources
-  # map to the same axis.
-  for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
-    chunked = sharding[axis]
-    if isinstance(chunked, NoSharding):
-      chunked = Chunked([])
-    sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
-    assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
-        "Value mapped to the same mesh axis twice"
-    mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
-    next_sharded_axis += 1
-  return ShardingSpec(sharding, mesh_mapping)
-
 def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
                        map_axis: int | None) -> ShardingSpec:
   """Sharding spec for arguments or results of a pmap.


@@ -22,7 +22,6 @@ from jax._src.interpreters.pxla import (
     global_aval_to_result_handler as global_aval_to_result_handler,
     global_avals_to_results_handler as global_avals_to_results_handler,
     global_result_handlers as global_result_handlers,
-    mesh_sharding_specs as mesh_sharding_specs,
     parallel_callable as parallel_callable,
     shard_arg as shard_arg,
     shard_args as shard_args,
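
Downstream code that pulled the helper through this public shim will now hit an ImportError; a defensive sketch of the migration:

try:
  # Removed from the public re-export surface by this commit.
  from jax.interpreters.pxla import mesh_sharding_specs  # noqa: F401
except ImportError:
  # Express the same placement with jax.sharding instead.
  from jax.sharding import NamedSharding, PartitionSpec  # noqa: F401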


@@ -4281,17 +4281,6 @@ class UtilTest(jtu.JaxTestCase):
     self.assertTrue(all(i == (slice(None),) * aval.ndim
                         for out, aval in safe_zip(out_indices, in_avals) for i in out))
 
-  def test_mesh_sharding_spec(self):
-    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
-    array_mapping = pxla.get_array_mapping(P('x', 'y'))
-    aval = core.ShapedArray((1, 1), jnp.int32)
-    with self.assertRaisesRegex(
-        ValueError,
-        'The aval shape on dimension 0 is 1 and the size of axis x is 4. The '
-        'aval shape % axis size should be zero but got 1'
-    ):
-      pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, array_mapping)
-
   @parameterized.named_parameters(
       ("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
       ("only_unspecified", UNSPECIFIED),