fix the degenerated case

This commit is contained in:
Yunlong Liu 2024-12-31 00:24:02 +00:00
parent e87a2a5929
commit 3ff000ee3e
2 changed files with 18 additions and 1 deletions

View File

@ -1634,6 +1634,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
if (isinstance(axis_context, SPMDAxisContext) and
axis_context.manual_axes and
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
if axis_env.sizes[axis_pos] == 1:
return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32)))
x = hlo.iota(ir.RankedTensorType.get(
[axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
sharding_proto = (

View File

@ -2162,7 +2162,22 @@ class ShardMapTest(jtu.JaxTestCase):
mesh, in_specs=P('i', None), out_specs=P('i', None),
check_rep=False, auto=frozenset({'j'}))()
self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))
def test_partial_auto_axis_index_degenerated_axis(self):
if config.use_shardy_partitioner.value:
self.skipTest('Shardy does not support full-to-shard.')
mesh = jtu.create_mesh((1, 2), ('i', 'j'))
out_sharding = NamedSharding(mesh, P('i', None))
@partial(jax.jit, out_shardings=out_sharding)
def f():
return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
mesh, in_specs=P('i', None), out_specs=P('i', None),
check_rep=False, auto=frozenset({'j'}))()
self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))
def test_partial_auto_ppermute(self):
if xla_extension_version < 302: