mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Expose channel_id in AllToAllOp in both XLA builder and MHLO.
PiperOrigin-RevId: 494334791
This commit is contained in:
parent
0a2d1cd45e
commit
ffb4711969
@ -944,12 +944,25 @@ def _all_to_all_lowering(ctx, x, *,
|
||||
if not all(split_count == len(g) for g in replica_groups):
|
||||
raise ValueError('Replica groups must be equally sized')
|
||||
operand = [x] if mlir_api_version >= 38 else x
|
||||
is_spmd = isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext))
|
||||
if is_spmd:
|
||||
# We want to emit the all-gather with global device IDs and a unique
|
||||
# channel ID, as otherwise it interprets the devices as replicas instead
|
||||
# of partitions - and XLA is configured with only a single replica.
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
channel_handle=mhlo.ChannelHandle.get(channel,
|
||||
mlir.DEVICE_TO_DEVICE_TYPE))
|
||||
else:
|
||||
other_args = {}
|
||||
return mhlo.AllToAllOp(
|
||||
operand,
|
||||
split_dimension=mlir.i64_attr(split_axis),
|
||||
concat_dimension=mlir.i64_attr(concat_axis),
|
||||
split_count=mlir.i64_attr(split_count),
|
||||
replica_groups=_replica_groups_mhlo(replica_groups)).results
|
||||
replica_groups=_replica_groups_mhlo(replica_groups),
|
||||
**other_args).results
|
||||
else:
|
||||
warnings.warn(
|
||||
"all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "
|
||||
|
Loading…
x
Reference in New Issue
Block a user