mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 04:46:06 +00:00)

commit 6ea8a546f6 (parent cf2f182a6c)
always lower all_to_all to AllToAll
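For readers unfamiliar with the collective touched by this commit: under pmap, lax.all_to_all scatters each device's shard along split_axis across the devices of the named axis and stacks the received pieces along concat_axis; with split_axis == concat_axis == 0 it acts as the device-level transpose that pswapaxes is built on. A minimal sketch of the API (mine, not part of this commit; shapes are illustrative):

# Illustrative only (not from the diff): all_to_all under pmap.
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
x = jnp.arange(n * n * 2).reshape(n, n, 2)  # one (n, 2) shard per device

# split_axis=0, concat_axis=0 swaps the pmapped device axis with the leading
# shard axis, i.e. a swapaxes(0, 1) on the global (n, n, 2) array.
y = jax.pmap(lambda s: lax.all_to_all(s, 'i', split_axis=0, concat_axis=0),
             axis_name='i')(x)
assert y.shape == (n, n, 2)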
@@ -20,7 +20,6 @@ import itertools
 import math
 import string
 from typing import Sequence, Union
-import warnings

 import numpy as np

@@ -960,16 +959,6 @@ def _index_in_group(axis_name, axis_index_groups):
   return lax.squeeze(
       slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0])

-def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis, axis_index_groups):
-  idx = _index_in_group(axis_name, axis_index_groups)
-  full = all_gather(x, axis_name, axis_index_groups=axis_index_groups)
-  axis_size = full.shape[0]
-  tile_size = x.shape[split_axis] // axis_size
-  tile_base_idx = idx * tile_size
-  sliced = slicing.dynamic_slice_in_dim(full, tile_base_idx, tile_size,
-                                        split_axis + 1)
-  return _foldaxis(concat_axis, _moveaxis(0, concat_axis, sliced))
-

 def _all_to_all_lowering(ctx, x, *,
                          split_axis, concat_axis, axis_name, axis_index_groups):
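The deleted helper above emulated all_to_all by gathering every shard and then slicing out the caller's tile, which is why the warning removed in the next hunk describes it as slow and memory-intensive. A rough numpy rendering of what it computed on one device (the function name and the explicit shards list are mine, for illustration only):

import numpy as np

def emulated_all_to_all(shards, idx, split_axis, concat_axis):
  # shards: one array per device, all with the same shape; idx: this device.
  axis_size = len(shards)
  full = np.stack(shards)                       # what all_gather materializes
  tile = shards[idx].shape[split_axis] // axis_size
  # Keep only this device's tile of the split axis from every gathered shard.
  sliced = np.take(full, np.arange(idx * tile, (idx + 1) * tile),
                   axis=split_axis + 1)
  # Move the gathered device axis to concat_axis and fold it into that axis.
  moved = np.moveaxis(sliced, 0, concat_axis)
  shape = list(moved.shape)
  shape[concat_axis:concat_axis + 2] = [shape[concat_axis] * shape[concat_axis + 1]]
  return moved.reshape(shape)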
@@ -978,46 +967,29 @@ def _all_to_all_lowering(ctx, x, *,
                                    axis_index_groups)
   if len(replica_groups[0]) == 1:
     return [x]
-  elif ((ctx.module_context.platform == "tpu") or
-        ((ctx.module_context.platform in ("cuda", "rocm"))
-         and (split_axis == 0) and (concat_axis == 0))):
-    split_count = len(replica_groups[0])
-    if not all(split_count == len(g) for g in replica_groups):
-      raise ValueError('Replica groups must be equally sized')
-    is_spmd = isinstance(
-        ctx.module_context.axis_context,
-        (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
-    )
-    if is_spmd:
-      # We want to emit the all-gather with global device IDs and a unique
-      # channel ID, as otherwise it interprets the devices as replicas instead
-      # of partitions - and XLA is configured with only a single replica.
-      channel = ctx.module_context.new_channel()
-      other_args = dict(
-          channel_handle=hlo.ChannelHandle.get(channel,
-                                               mlir.DEVICE_TO_DEVICE_TYPE))
-    else:
-      other_args = {}
-    return hlo.AllToAllOp(
-        x,
-        split_dimension=mlir.i64_attr(split_axis),
-        concat_dimension=mlir.i64_attr(concat_axis),
-        split_count=mlir.i64_attr(split_count),
-        replica_groups=_replica_groups_hlo(replica_groups),
-        **other_args).results
+  split_count = len(replica_groups[0])
+  if not all(split_count == len(g) for g in replica_groups):
+    raise ValueError('Replica groups must be equally sized')
+  is_spmd = isinstance(
+      ctx.module_context.axis_context,
+      (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
+  )
+  if is_spmd:
+    # We want to emit the all-gather with global device IDs and a unique
+    # channel ID, as otherwise it interprets the devices as replicas instead
+    # of partitions - and XLA is configured with only a single replica.
+    channel = ctx.module_context.new_channel()
+    channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
+    other_args = dict(channel_handle=channel_handle)
   else:
-    warnings.warn(
-        "all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "
-        "split_axis and concat_axis are both 0). All other backends emulate it using a "
-        "very slow and memory intensive algorithm, so expect significant slowdowns."
-    )
-    lowering = mlir.lower_fun(_all_to_all_via_all_gather,
-                              multiple_results=False)
-    return lowering(ctx, x,
-                    axis_name=axis_name,
-                    split_axis=split_axis,
-                    concat_axis=concat_axis,
-                    axis_index_groups=axis_index_groups)
+    other_args = {}
+  return hlo.AllToAllOp(
+      x,
+      split_dimension=mlir.i64_attr(split_axis),
+      concat_dimension=mlir.i64_attr(concat_axis),
+      split_count=mlir.i64_attr(split_count),
+      replica_groups=_replica_groups_hlo(replica_groups),
+      **other_args).results

 def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
   return (all_to_all(
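The hunk above removes the TPU/GPU-only branch and the all_gather fallback, so every backend now takes the hlo.AllToAllOp path (apart from the single-device early return). A quick way to see what gets emitted, again my own snippet rather than anything from the commit:

# Illustrative only: inspect the lowered module for the collective.
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
f = jax.pmap(lambda s: lax.all_to_all(s, 'i', split_axis=0, concat_axis=0),
             axis_name='i')
# With more than one device the lowered module contains an all_to_all
# collective even on CPU, where the all_gather emulation used to kick in.
print(f.lower(jnp.zeros((n, n, 2))).as_text())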
@@ -110,9 +110,6 @@ def tearDownModule():
 ignore_jit_of_pmap_warning = partial(
   jtu.ignore_warning, message=".*jit-of-pmap.*")

-ignore_slow_all_to_all_warning = partial(
-  jtu.ignore_warning, message="all_to_all.*expect significant slowdowns.*")
-
 ignore_xmap_warning = partial(
   jtu.ignore_warning, message=".*is an experimental.*")

@@ -515,7 +512,6 @@ class PythonPmapTest(jtu.JaxTestCase):
       self.assertAllClose(
           actual, expected[i // 2 * scatter_len:(i // 2 + 1) * scatter_len])

-  @ignore_slow_all_to_all_warning()
   def testTrees(self):
     ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0)
     def protate(x, axis_name):
@@ -572,7 +568,6 @@ class PythonPmapTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis}
       for split_axis, concat_axis in it.product(range(2), range(2)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAll(self, split_axis, concat_axis):
     pmap_in_axis = 0
     shape = (jax.device_count(),) * 3
@@ -592,7 +587,6 @@ class PythonPmapTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis}
       for split_axis, concat_axis in it.product(range(2), range(2)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAllSplitAxis(self, split_axis, concat_axis):
     if jax.device_count() < 4:
       raise SkipTest("test requires at least four devices")
@@ -1030,7 +1024,6 @@ class PythonPmapTest(jtu.JaxTestCase):
       ] for name, prim in
       (('Gather', lax.all_gather), ('ReduceScatter', lax.psum_scatter))
   ))
-  @ignore_slow_all_to_all_warning()
   def testGradOf(self, prim, tiled, use_axis_index_groups):
     if jtu.device_under_test() == "gpu":
       raise SkipTest("XLA:GPU with ReduceScatter deadlocks")  # b/264516146
@@ -1549,7 +1542,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
     self.assertAllClose(bz2, bz2, check_dtypes=False)

-  @ignore_slow_all_to_all_warning()
   def testPswapaxes(self):
     device_count = jax.device_count()
     shape = (device_count, 3, device_count, 5)
@@ -1559,7 +1551,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     expected = np.swapaxes(x, 0, 2)
     self.assertAllClose(ans, expected, check_dtypes=False)

-  @ignore_slow_all_to_all_warning()
   def testGradOfPswapaxes(self):
     device_count = jax.device_count()
     shape = (device_count, 1, device_count)
@@ -1575,7 +1566,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     expected = np.tile(w, reps=device_count).reshape(shape)
     self.assertAllClose(ans, expected, check_dtypes=False)

-  @ignore_slow_all_to_all_warning()
   def testAllToAllReplicaGroups(self):
     # If num_devices = 4, these would be the inputs/outputs:
     # input = [[0, 1], [2, 3], [4, 5], [6, 7]]
@@ -1604,7 +1594,6 @@ class PythonPmapTest(jtu.JaxTestCase):
         0, 2).reshape(shape)
     self.assertAllClose(fn(x), expected, check_dtypes=False)

-  @ignore_slow_all_to_all_warning()
   def testGradOfAllToAllReplicaGroups(self):
     device_count = jax.device_count()
     if device_count % 2 != 0:
@@ -2322,7 +2311,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
       for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAllInVmap(self, split_axis, concat_axis, vmap_axis):
     def f(x):
       return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
@@ -2391,7 +2379,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis}
       for split_axis, concat_axis in it.product(range(3), range(3)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAllVsVmap(self, split_axis, concat_axis):
     def f(x):
       return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
@@ -2406,7 +2393,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
        "axes": axes, "split_axis": split_axis, "concat_axis": concat_axis}
       for axes, split_axis, concat_axis
       in it.product([('i', 'j'), ('j', 'i')], range(3), range(3)))
-  @ignore_slow_all_to_all_warning()
   @unittest.skip("multi-axis all_to_all broken after #4835")  # TODO(mattjj,apaszke)
   def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
     if jax.device_count() < 4:
@@ -170,7 +170,6 @@ class ShardMapTest(jtu.JaxTestCase):
     c = fwd(a)
     self.assertAllClose(c[1, :], a[0, :])

-  @jtu.skip_on_devices("cpu")  # all_to_all has a warning on cpu
   def test_all_to_all(self):
     devices = np.array(jax.devices())
     mesh = Mesh(devices, axis_names=('x'))
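The skip removed above existed only because the CPU lowering used to warn; with all_to_all now lowering uniformly, the shard_map test can run on CPU too. For reference, a hedged sketch of all_to_all under shard_map, loosely modeled on test_all_to_all (specs, shapes, and the closing comment are my own assumptions, not copied from the test):

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('x',))
n = devices.size

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('x', None), out_specs=P(None, 'x'))
def fwd(a):
  # Each device holds one row block of `a`; all_to_all trades row chunks so
  # that device i ends up holding the i-th column chunk from every device.
  return lax.all_to_all(a, 'x', split_axis=1, concat_axis=1, tiled=True)

a = jnp.arange(n * n).reshape(n, n)
c = fwd(a)
print(c.shape)  # (1, n * n): the columns of `a` laid out device by device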