From 0e7f218eb0d47570c4dd08e3ed9970efc906792b Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 12 Dec 2024 18:38:27 -0800 Subject: [PATCH] Support axis_index inside shard_map(auto=...) by using iota and then calling full_to_shard. PiperOrigin-RevId: 705704369 --- jax/_src/lax/parallel.py | 25 +++++++++++++++++++------ tests/shard_map_test.py | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 6ae2d02f8..46fd82533 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1591,7 +1591,25 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_name, = axis_name if axis_name not in axis_env.names: raise NameError(f"unbound axis name: {axis_name}") + axis_context = ctx.module_context.axis_context axis_pos = list(axis_env.names).index(axis_name) + + # For partial auto, lower using iota. + if (isinstance(axis_context, SPMDAxisContext) and + axis_context.manual_axes and + axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)): + x = hlo.iota(ir.RankedTensorType.get( + [axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0)) + sharding_proto = ( + NamedSharding(axis_context.mesh, P(axis_name)) + ._to_xla_hlo_sharding(1).to_proto()) + aval_in = ShapedArray((axis_env.sizes[axis_pos],), np.int32) + aval_out = ShapedArray((1,), np.int32) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto) + proto = pxla.manual_proto(aval_in, axis_context.manual_axes, axis_context.mesh) + x = mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto) + return hlo.reshape(ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), x) + nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( np.array( @@ -1599,12 +1617,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): ) ) mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32)) - axis_context = ctx.module_context.axis_context - is_spmd = isinstance( - axis_context, - (SPMDAxisContext, ShardingContext), - ) - if is_spmd: + if isinstance(axis_context, (ShardingContext, SPMDAxisContext)): device_id = hlo.partition_id() else: device_id = hlo.replica_id() diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 6d799b52d..d247c8fdb 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2147,6 +2147,22 @@ class ShardMapTest(jtu.JaxTestCase): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) + def test_partial_auto_axis_index(self): + if config.use_shardy_partitioner.value: + self.skipTest('Shardy does not support full-to-shard.') + + mesh = jtu.create_mesh((4, 2), ('i', 'j')) + out_sharding = NamedSharding(mesh, P('i', None)) + + @partial(jax.jit, out_shardings=out_sharding) + def f(): + return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1), + mesh, in_specs=P('i', None), out_specs=P('i', None), + check_rep=False, auto=frozenset({'j'}))() + + self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1)) + + def test_vmap_grad_shmap_spmd_axis_name_residuals(self): # https://github.com/jax-ml/jax/pull/21032 mesh = jtu.create_mesh((4, 2), ('i', 'j'))