Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00.
commit 0e7f218eb0 (parent 1453a222d4)

Support axis_index inside shard_map(auto=...) by using iota and then
calling full_to_shard.

PiperOrigin-RevId: 705704369
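In user-facing terms, this change makes jax.lax.axis_index work over a manual
mesh axis even when other mesh axes are left automatic via shard_map's auto=...
argument. A minimal sketch of the newly supported pattern, adapted from the
test added below (it assumes 8 available devices for the 4x2 mesh; the axis
names 'i' and 'j' are just the test's choices):

    import numpy as np
    import jax
    from functools import partial
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    # A 4x2 device mesh: 'i' is handled manually, 'j' is left automatic.
    mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('i', 'j'))

    @partial(jax.jit, out_shardings=NamedSharding(mesh, P('i', None)))
    def f():
      # axis_index('i') inside a partial-auto shard_map now lowers via
      # iota + full_to_shard instead of failing.
      return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
                       mesh, in_specs=P('i', None), out_specs=P('i', None),
                       check_rep=False, auto=frozenset({'j'}))()

    # Each shard along 'i' sees its own index: f() == [[0], [1], [2], [3]].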
@@ -1591,7 +1591,25 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
   axis_name, = axis_name
   if axis_name not in axis_env.names:
     raise NameError(f"unbound axis name: {axis_name}")
+  axis_context = ctx.module_context.axis_context
   axis_pos = list(axis_env.names).index(axis_name)
+
+  # For partial auto, lower using iota.
+  if (isinstance(axis_context, SPMDAxisContext) and
+      axis_context.manual_axes and
+      axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
+    x = hlo.iota(ir.RankedTensorType.get(
+        [axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
+    sharding_proto = (
+        NamedSharding(axis_context.mesh, P(axis_name))
+        ._to_xla_hlo_sharding(1).to_proto())
+    aval_in = ShapedArray((axis_env.sizes[axis_pos],), np.int32)
+    aval_out = ShapedArray((1,), np.int32)
+    sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto)
+    proto = pxla.manual_proto(aval_in, axis_context.manual_axes, axis_context.mesh)
+    x = mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto)
+    return hlo.reshape(ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), x)
+
   nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
   div = mlir.ir_constant(
       np.array(
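At the array level, the new branch works like this: build a global iota
[0, ..., n-1] over the bound axis, attach a sharding that splits it along that
axis, let full_to_shard hand each manual shard the one element equal to its own
position, and reshape that (1,)-shaped shard into a scalar index. A toy NumPy
illustration of the same semantics (axis_index_via_iota is a hypothetical
helper for this note, not JAX API):

    import numpy as np

    def axis_index_via_iota(axis_size: int, shard_id: int) -> np.ndarray:
      # hlo.iota: the global values [0, ..., axis_size - 1].
      global_iota = np.arange(axis_size, dtype=np.int32)
      # full_to_shard with the iota sharded along the axis: shard k keeps [k].
      local_shard = global_iota[shard_id : shard_id + 1]
      # hlo.reshape: collapse the (1,)-shaped shard to a scalar.
      return local_shard.reshape(())

    assert all(axis_index_via_iota(4, k) == k for k in range(4))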
@@ -1599,12 +1617,7 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
       )
   )
   mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
-  axis_context = ctx.module_context.axis_context
-  is_spmd = isinstance(
-      axis_context,
-      (SPMDAxisContext, ShardingContext),
-  )
-  if is_spmd:
+  if isinstance(axis_context, (ShardingContext, SPMDAxisContext)):
    device_id = hlo.partition_id()
  else:
    device_id = hlo.replica_id()
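The pre-existing fallback below this hunk recovers the index arithmetically
from the linear device id (partition_id under SPMD lowering, replica_id
otherwise) using the div and mod constants built above. Roughly, assuming a
single replica and row-major ordering of mesh axes (axis_index_from_device_id
is a hypothetical helper for illustration; the real lowering emits the
corresponding integer divide and remainder ops on device_id):

    # For axis sizes [s0, ..., sk], the index of device d along axis p is
    # (d // prod(sizes[p+1:])) % sizes[p].
    def axis_index_from_device_id(device_id, sizes, axis_pos):
      div = 1
      for s in sizes[axis_pos + 1:]:
        div *= s
      return (device_id // div) % sizes[axis_pos]

    # Example on a (4, 2) mesh: device 5 sits at ('i'=2, 'j'=1).
    assert axis_index_from_device_id(5, [4, 2], 0) == 2
    assert axis_index_from_device_id(5, [4, 2], 1) == 1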
@@ -2147,6 +2147,22 @@ class ShardMapTest(jtu.JaxTestCase):
 
     self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))
+
+  def test_partial_auto_axis_index(self):
+    if config.use_shardy_partitioner.value:
+      self.skipTest('Shardy does not support full-to-shard.')
+
+    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
+    out_sharding = NamedSharding(mesh, P('i', None))
+
+    @partial(jax.jit, out_shardings=out_sharding)
+    def f():
+      return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1),
+                       mesh, in_specs=P('i', None), out_specs=P('i', None),
+                       check_rep=False, auto=frozenset({'j'}))()
+
+    self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
+
 
   def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
     # https://github.com/jax-ml/jax/pull/21032
     mesh = jtu.create_mesh((4, 2), ('i', 'j'))