Support axis_index inside shard_map(auto=...) by using iota and then calling full_to_shard.

PiperOrigin-RevId: 705704369
This commit is contained in:
Parker Schuh 2024-12-12 18:38:27 -08:00 committed by jax authors
parent 1453a222d4
commit 0e7f218eb0
2 changed files with 35 additions and 6 deletions

View File

@ -1591,7 +1591,25 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
axis_name, = axis_name
if axis_name not in axis_env.names:
raise NameError(f"unbound axis name: {axis_name}")
axis_context = ctx.module_context.axis_context
axis_pos = list(axis_env.names).index(axis_name)
# For partial auto, lower using iota.
if (isinstance(axis_context, SPMDAxisContext) and
axis_context.manual_axes and
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
x = hlo.iota(ir.RankedTensorType.get(
[axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
sharding_proto = (
NamedSharding(axis_context.mesh, P(axis_name))
._to_xla_hlo_sharding(1).to_proto())
aval_in = ShapedArray((axis_env.sizes[axis_pos],), np.int32)
aval_out = ShapedArray((1,), np.int32)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto)
proto = pxla.manual_proto(aval_in, axis_context.manual_axes, axis_context.mesh)
x = mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto)
return hlo.reshape(ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), x)
nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
div = mlir.ir_constant(
np.array(
@ -1599,12 +1617,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
)
)
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(
axis_context,
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
if isinstance(axis_context, (ShardingContext, SPMDAxisContext)):
device_id = hlo.partition_id()
else:
device_id = hlo.replica_id()

View File

@ -2147,6 +2147,22 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))
def test_partial_auto_axis_index(self):
  # Checks that axis_index over the manual axis 'i' yields 0..3 (one row per
  # shard) when the mesh axis 'j' is passed to shard_map as `auto`.
  if config.use_shardy_partitioner.value:
    self.skipTest('Shardy does not support full-to-shard.')

  mesh = jtu.create_mesh((4, 2), ('i', 'j'))
  row_spec = P('i', None)

  def index_along_i():
    # Reshape the scalar index to (1, 1) so the per-shard results can be
    # laid out along the first dimension by out_specs.
    return jax.lax.axis_index('i').reshape(1, 1)

  def stacked_indices():
    mapped = shard_map(index_along_i, mesh, in_specs=row_spec,
                       out_specs=row_spec, check_rep=False,
                       auto=frozenset({'j'}))
    return mapped()

  jitted = jax.jit(stacked_indices,
                   out_shardings=NamedSharding(mesh, row_spec))
  expected = np.arange(4, dtype=np.int32).reshape(-1, 1)
  self.assertAllClose(jitted(), expected)
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
# https://github.com/jax-ml/jax/pull/21032
mesh = jtu.create_mesh((4, 2), ('i', 'j'))