diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index a63c2268c..62eddfd0a 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -69,7 +69,7 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue,
     get_array_mapping as _get_array_mapping, is_auto,
-    is_unspecified, is_unspecified_or_auto
+    is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources
 )
 from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
@@ -1510,12 +1510,15 @@ def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
   # TODO: Can we short-circuit for replicated values? Probably not.
   aval_in, = ctx.avals_in
   aval_out, = ctx.avals_out
-  sharding_proto = mesh_sharding_specs(
-      mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto().to_proto()
+  sharding_proto = (
+      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      ._to_xla_hlo_sharding(aval_in.ndim).to_proto())
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto, unspecified_dims=unspecified_dims)
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto,
+                                  unspecified_dims=unspecified_dims)
   proto = manual_proto(aval_in, manual_axes, mesh)
-  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto, unspecified_dims=unspecified_dims),
+  return (mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto,
+                                          unspecified_dims=unspecified_dims),)
 
 shard_to_full_p = core.Primitive('shard_to_full')
 
@@ -1531,10 +1534,13 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
   aval_out, = ctx.avals_out
   proto = manual_proto(aval_in, manual_axes, mesh)  # type: ignore
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())  # type: ignore
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto, unspecified_dims=unspecified_dims)
-  sharding_proto = mesh_sharding_specs(
-      mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto().to_proto()
-  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto,
+                                  unspecified_dims=unspecified_dims)
+  sharding_proto = (
+      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      ._to_xla_hlo_sharding(aval_out.ndim).to_proto())
+  return (mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto,
+                                          unspecified_dims),)
 
 @lu.transformation
 def vtile_manual(manual_axes: frozenset[sharding_impls.MeshAxisName],
@@ -3068,30 +3074,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
       _check_aval(v.aval, what_thunk)
 
 
-def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
-  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
-  # NOTE: This takes in the non-sharded avals!
-  def mk_sharding_spec(aval, aval_axes):
-    if aval is core.abstract_token:
-      assert not aval_axes
-      return ShardingSpec([], [Replicated(axis_size) for axis_size in axis_sizes.values()])
-    aval_shape = list(aval.shape)
-    # NOTE: sorted is stable, which is important when multiple resources
-    # map to the same axis.
-    for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
-      if not allow_uneven_axes:
-        if aval_shape[axis] % axis_sizes[name] != 0:
-          raise ValueError(
-              f'The aval shape on dimension {axis} is {aval_shape[axis]} and '
-              f'the size of axis {name} is {axis_sizes[name]}. The aval shape % '
-              'axis size should be zero but got '
-              f'{aval_shape[axis] % axis_sizes[name]}')
-      aval_shape[axis] //= axis_sizes[name]
-    return sharding_specs.make_sharding_spec(
-        axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
-  return mk_sharding_spec
-
-
 @contextmanager
 def maybe_extend_axis_env(*args, **kwargs):
   with core.extend_axis_env(*args, **kwargs):
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index ade9fba87..86b8b2d1d 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -1417,11 +1417,11 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
   global_in_avals = ctx.avals_in
   vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
 
-  global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
   sharded_global_in_nodes = [
     [mlir.wrap_with_sharding_op(
         ctx, node, aval,
-        global_sharding_spec(aval, aval_axes).sharding_proto().to_proto())]
+        NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
+                      )._to_xla_hlo_sharding(aval.ndim).to_proto())]
     if aval_axes else [node]
     for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
   ]
@@ -1441,7 +1441,8 @@
   sharded_global_out_nodes = [
     mlir.wrap_with_sharding_op(
         ctx, node, aval,
-        global_sharding_spec(aval, aval_axes).sharding_proto().to_proto())
+        NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
+                      )._to_xla_hlo_sharding(aval.ndim).to_proto())
     if aval_axes else node
     for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
   ]
diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py
index edd2f3710..938ebe868 100644
--- a/jax/_src/sharding_specs.py
+++ b/jax/_src/sharding_specs.py
@@ -251,24 +251,6 @@ def spec_to_indices(shape: Sequence[int],
   return tuple(spec.indices(shape).flat)  # type: ignore
 
 
-def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
-  mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
-  sharding = [_UNSHARDED_INSTANCE] * num_dimensions
-  next_sharded_axis = 0
-  # NOTE: sorted is stable, which is important when multiple resources
-  # map to the same axis.
-  for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
-    chunked = sharding[axis]
-    if isinstance(chunked, NoSharding):
-      chunked = Chunked([])
-    sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
-    assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
-        "Value mapped to the same mesh axis twice"
-    mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
-    next_sharded_axis += 1
-  return ShardingSpec(sharding, mesh_mapping)
-
-
 def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
                        map_axis: int | None) -> ShardingSpec:
   """Sharding spec for arguments or results of a pmap.
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 9f8f57883..a515f2293 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -22,7 +22,6 @@ from jax._src.interpreters.pxla import (
   global_aval_to_result_handler as global_aval_to_result_handler,
   global_avals_to_results_handler as global_avals_to_results_handler,
   global_result_handlers as global_result_handlers,
-  mesh_sharding_specs as mesh_sharding_specs,
   parallel_callable as parallel_callable,
   shard_arg as shard_arg,
   shard_args as shard_args,
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 6bc4ff898..a3f974f69 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4281,17 +4281,6 @@ class UtilTest(jtu.JaxTestCase):
     self.assertTrue(all(i == (slice(None),) * aval.ndim
                         for out, aval in safe_zip(out_indices, in_avals)
                         for i in out))
 
-  def test_mesh_sharding_spec(self):
-    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
-    array_mapping = pxla.get_array_mapping(P('x', 'y'))
-    aval = core.ShapedArray((1, 1), jnp.int32)
-    with self.assertRaisesRegex(
-        ValueError,
-        'The aval shape on dimension 0 is 1 and the size of axis x is 4. The '
-        'aval shape % axis size should be zero but got 1'
-    ):
-      pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, array_mapping)
-
   @parameterized.named_parameters(
       ("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
       ("only_unspecified", UNSPECIFIED),
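
Note on the replacement API (commentary, not part of the diff): the lowering rules above now build the HLO sharding proto from a NamedSharding rather than going through mesh_sharding_specs/ShardingSpec. Below is a minimal sketch of that path, assuming a toy 1x1 mesh, a 2-D aval, and a hand-built ArrayMapping; array_mapping_to_axis_resources and the private _to_xla_hlo_sharding method are the same helpers the diff calls inside _full_to_shard_lowering and _shard_to_full_lowering.

```python
# Sketch only: mirrors the new lowering path used in this diff. The mesh
# shape, the hand-built ArrayMapping, and ndim are illustrative placeholders.
from collections import OrderedDict

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding
from jax._src import sharding_impls

# Toy 1x1 mesh; the real lowering uses the mesh captured by xmap/pjit.
devices = np.asarray(jax.devices()[:1]).reshape((1, 1))
mesh = Mesh(devices, axis_names=('x', 'y'))

# An ArrayMapping as the lowering rules receive it: mesh axis name -> array dim.
axes = OrderedDict([('x', 0), ('y', 1)])
ndim = 2  # stands in for aval.ndim

# Old path (removed here):
#   mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axes).sharding_proto().to_proto()
# New path, as used in _full_to_shard_lowering / _shard_to_full_lowering:
spec = sharding_impls.array_mapping_to_axis_resources(axes)  # -> PartitionSpec('x', 'y')
sharding_proto = NamedSharding(mesh, spec)._to_xla_hlo_sharding(ndim).to_proto()
print(sharding_proto)
```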