diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index e9f6066b7..bf2e7b5d9 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -1634,6 +1634,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
   if (isinstance(axis_context, SPMDAxisContext) and
       axis_context.manual_axes and
       axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
+    if axis_env.sizes[axis_pos] == 1:
+      return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32)))
     x = hlo.iota(ir.RankedTensorType.get(
         [axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
     sharding_proto = (
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index e4f62a324..ec06f2661 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -2162,7 +2162,22 @@ class ShardMapTest(jtu.JaxTestCase):
                        mesh, in_specs=P('i', None), out_specs=P('i', None),
                        check_rep=False, auto=frozenset({'j'}))()
 
-    self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
+    self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))
+
+  def test_partial_auto_axis_index_degenerated_axis(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Shardy does not support full-to-shard.')
+
+    mesh = jtu.create_mesh((1, 2), ('i', 'j'))
+    out_sharding = NamedSharding(mesh, P('i', None))
+
+    @partial(jax.jit, out_shardings=out_sharding)
+    def f():
+      return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
+                       mesh, in_specs=P('i', None), out_specs=P('i', None),
+                       check_rep=False, auto=frozenset({'j'}))()
+
+    self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))
 
   def test_partial_auto_ppermute(self):
     if xla_extension_version < 302:
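
Note on the change: the edit to _build_axis_index_lowering_hlo short-circuits the manual-sharding lowering of axis_index when the named axis has size 1, returning a constant 0 instead of a one-element sharded iota, and the new test covers that degenerate-axis case under partial-auto shard_map. Below is a minimal standalone sketch of the scenario the test exercises; it assumes a runtime with at least two devices, and the explicit Mesh construction is illustrative (the test itself uses jtu.create_mesh).

from functools import partial

import jax
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1x2 mesh: axis 'i' is degenerate (size 1), axis 'j' stays automatic.
mesh = Mesh(np.array(jax.devices()[:2]).reshape(1, 2), ('i', 'j'))

@partial(jax.jit, out_shardings=NamedSharding(mesh, P('i', None)))
def f():
  # With the patch, axis_index('i') on the size-1 manual axis lowers to a
  # constant 0 rather than an iota over that axis.
  return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
                   mesh, in_specs=P('i', None), out_specs=P('i', None),
                   check_rep=False, auto=frozenset({'j'}))()

print(f())  # expected: [[0]]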