mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
fix the degenerated case
This commit is contained in:
parent
e87a2a5929
commit
3ff000ee3e
@ -1634,6 +1634,8 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
|
||||
if (isinstance(axis_context, SPMDAxisContext) and
|
||||
axis_context.manual_axes and
|
||||
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
|
||||
if axis_env.sizes[axis_pos] == 1:
|
||||
return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32)))
|
||||
x = hlo.iota(ir.RankedTensorType.get(
|
||||
[axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
|
||||
sharding_proto = (
|
||||
|
@ -2162,7 +2162,22 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
mesh, in_specs=P('i', None), out_specs=P('i', None),
|
||||
check_rep=False, auto=frozenset({'j'}))()
|
||||
|
||||
self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
|
||||
self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))
|
||||
|
||||
def test_partial_auto_axis_index_degenerated_axis(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest('Shardy does not support full-to-shard.')
|
||||
|
||||
mesh = jtu.create_mesh((1, 2), ('i', 'j'))
|
||||
out_sharding = NamedSharding(mesh, P('i', None))
|
||||
|
||||
@partial(jax.jit, out_shardings=out_sharding)
|
||||
def f():
|
||||
return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
|
||||
mesh, in_specs=P('i', None), out_specs=P('i', None),
|
||||
check_rep=False, auto=frozenset({'j'}))()
|
||||
|
||||
self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))
|
||||
|
||||
def test_partial_auto_ppermute(self):
|
||||
if xla_extension_version < 302:
|
||||
|
Loading…
x
Reference in New Issue
Block a user