Expose channel_id in AllToAllOp in both XLA builder and MHLO.

PiperOrigin-RevId: 494334791
This commit is contained in:
Anselm Levskaya 2022-12-09 21:57:46 -08:00 committed by jax authors
parent 0a2d1cd45e
commit ffb4711969

View File

@ -944,12 +944,25 @@ def _all_to_all_lowering(ctx, x, *,
if not all(split_count == len(g) for g in replica_groups):
raise ValueError('Replica groups must be equally sized')
operand = [x] if mlir_api_version >= 38 else x
is_spmd = isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext))
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# channel ID, as otherwise it interprets the devices as replicas instead
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=mhlo.ChannelHandle.get(channel,
mlir.DEVICE_TO_DEVICE_TYPE))
else:
other_args = {}
return mhlo.AllToAllOp(
operand,
split_dimension=mlir.i64_attr(split_axis),
concat_dimension=mlir.i64_attr(concat_axis),
split_count=mlir.i64_attr(split_count),
replica_groups=_replica_groups_mhlo(replica_groups)).results
replica_groups=_replica_groups_mhlo(replica_groups),
**other_args).results
else:
warnings.warn(
"all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "