always lower all_to_all to AllToAll

Matthew Johnson 2023-04-11 18:21:25 -07:00
parent cf2f182a6c
commit 6ea8a546f6
3 changed files with 22 additions and 65 deletions
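This change makes all_to_all lower to an HLO AllToAll op on every backend, deleting the all_gather-based emulation, its slow-path warning, and the test decorators that silenced that warning. For orientation, here is a minimal sketch (illustrative, not part of this diff) of the lax.all_to_all semantics being lowered: under pmap, with split_axis == concat_axis == 0, it acts as a distributed transpose, which is the basis of pswapaxes in the tests below.

import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
x = jnp.arange(n * n * 2.0).reshape(n, n, 2)  # one (n, 2) block per device

# Each device sends slice j of its block to device j and stacks what it
# receives, so the device axis and the leading local axis are swapped.
f = jax.pmap(lambda v: lax.all_to_all(v, 'i', split_axis=0, concat_axis=0),
             axis_name='i')
assert jnp.allclose(f(x), jnp.swapaxes(x, 0, 1))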


@@ -20,7 +20,6 @@ import itertools
 import math
 import string
 from typing import Sequence, Union
-import warnings
 
 import numpy as np
@@ -960,16 +959,6 @@ def _index_in_group(axis_name, axis_index_groups):
   return lax.squeeze(
       slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0])
-
-def _all_to_all_via_all_gather(x, *, axis_name, split_axis, concat_axis, axis_index_groups):
-  idx = _index_in_group(axis_name, axis_index_groups)
-  full = all_gather(x, axis_name, axis_index_groups=axis_index_groups)
-  axis_size = full.shape[0]
-  tile_size = x.shape[split_axis] // axis_size
-  tile_base_idx = idx * tile_size
-  sliced = slicing.dynamic_slice_in_dim(full, tile_base_idx, tile_size,
-                                        split_axis + 1)
-  return _foldaxis(concat_axis, _moveaxis(0, concat_axis, sliced))
 
 def _all_to_all_lowering(ctx, x, *,
                          split_axis, concat_axis, axis_name, axis_index_groups):
@@ -978,46 +967,29 @@ def _all_to_all_lowering(ctx, x, *,
                                    axis_index_groups)
   if len(replica_groups[0]) == 1:
     return [x]
-  elif ((ctx.module_context.platform == "tpu") or
-        ((ctx.module_context.platform in ("cuda", "rocm"))
-        and (split_axis == 0) and (concat_axis == 0))):
-    split_count = len(replica_groups[0])
-    if not all(split_count == len(g) for g in replica_groups):
-      raise ValueError('Replica groups must be equally sized')
-    is_spmd = isinstance(
-        ctx.module_context.axis_context,
-        (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
-    )
-    if is_spmd:
-      # We want to emit the all-gather with global device IDs and a unique
-      # channel ID, as otherwise it interprets the devices as replicas instead
-      # of partitions - and XLA is configured with only a single replica.
-      channel = ctx.module_context.new_channel()
-      other_args = dict(
-          channel_handle=hlo.ChannelHandle.get(channel,
-                                               mlir.DEVICE_TO_DEVICE_TYPE))
-    else:
-      other_args = {}
-    return hlo.AllToAllOp(
-        x,
-        split_dimension=mlir.i64_attr(split_axis),
-        concat_dimension=mlir.i64_attr(concat_axis),
-        split_count=mlir.i64_attr(split_count),
-        replica_groups=_replica_groups_hlo(replica_groups),
-        **other_args).results
+  split_count = len(replica_groups[0])
+  if not all(split_count == len(g) for g in replica_groups):
+    raise ValueError('Replica groups must be equally sized')
+  is_spmd = isinstance(
+      ctx.module_context.axis_context,
+      (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
+  )
+  if is_spmd:
+    # We want to emit the all-to-all with global device IDs and a unique
+    # channel ID, as otherwise it interprets the devices as replicas instead
+    # of partitions - and XLA is configured with only a single replica.
+    channel = ctx.module_context.new_channel()
+    channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
+    other_args = dict(channel_handle=channel_handle)
   else:
-    warnings.warn(
-        "all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "
-        "split_axis and concat_axis are both 0). All other backends emulate it using a "
-        "very slow and memory intensive algorithm, so expect significant slowdowns."
-    )
-    lowering = mlir.lower_fun(_all_to_all_via_all_gather,
-                              multiple_results=False)
-    return lowering(ctx, x,
-                    axis_name=axis_name,
-                    split_axis=split_axis,
-                    concat_axis=concat_axis,
-                    axis_index_groups=axis_index_groups)
+    other_args = {}
+  return hlo.AllToAllOp(
+      x,
+      split_dimension=mlir.i64_attr(split_axis),
+      concat_dimension=mlir.i64_attr(concat_axis),
+      split_count=mlir.i64_attr(split_count),
+      replica_groups=_replica_groups_hlo(replica_groups),
+      **other_args).results
 
 def _all_to_all_transpose_rule(cts, x, axis_name, split_axis, concat_axis, axis_index_groups):
   return (all_to_all(
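For reference, the deleted _all_to_all_via_all_gather fallback materialized the fully gathered operand on every device before slicing out a single tile, which is what made it "very slow and memory intensive". Below is a standalone NumPy sketch of that algorithm (hypothetical helper name; a single-process stand-in for the per-device computation):

import numpy as np

def emulated_all_to_all(shards, split_axis, concat_axis):
  # `shards[i]` stands in for device i's local array.
  n = len(shards)
  full = np.stack(shards)                  # the all_gather: O(n) extra memory
  tile = shards[0].shape[split_axis] // n
  outs = []
  for i in range(n):                       # what device i computes from `full`
    # Slice out device i's tile of every gathered shard; split_axis is
    # offset by 1 because of the new leading gathered axis.
    sliced = np.take(full, np.arange(i * tile, (i + 1) * tile),
                     axis=split_axis + 1)
    # Move the gathered axis to concat_axis and fold the two together,
    # i.e. concatenate the received tiles (the _moveaxis/_foldaxis steps).
    moved = np.moveaxis(sliced, 0, concat_axis)
    shape = list(moved.shape)
    shape[concat_axis:concat_axis + 2] = [shape[concat_axis] * shape[concat_axis + 1]]
    outs.append(moved.reshape(shape))
  return outs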


@@ -110,9 +110,6 @@ def tearDownModule():
 ignore_jit_of_pmap_warning = partial(
     jtu.ignore_warning, message=".*jit-of-pmap.*")
 
-ignore_slow_all_to_all_warning = partial(
-    jtu.ignore_warning, message="all_to_all.*expect significant slowdowns.*")
-
 ignore_xmap_warning = partial(
     jtu.ignore_warning, message=".*is an experimental.*")
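The deleted helper was a message-matched warning filter; in plain stdlib terms it arranged roughly the following around each decorated test:

import warnings

with warnings.catch_warnings():
  warnings.filterwarnings(
      "ignore", message="all_to_all.*expect significant slowdowns.*")
  ...  # run the test body that used to trigger the lowering warning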
@@ -515,7 +512,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     self.assertAllClose(
         actual, expected[i // 2 * scatter_len:(i // 2 + 1) * scatter_len])
 
-  @ignore_slow_all_to_all_warning()
   def testTrees(self):
     ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0)
     def protate(x, axis_name):
@@ -572,7 +568,6 @@ class PythonPmapTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis}
       for split_axis, concat_axis in it.product(range(2), range(2)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAll(self, split_axis, concat_axis):
     pmap_in_axis = 0
     shape = (jax.device_count(),) * 3
@@ -592,7 +587,6 @@ class PythonPmapTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis}
       for split_axis, concat_axis in it.product(range(2), range(2)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAllSplitAxis(self, split_axis, concat_axis):
     if jax.device_count() < 4:
       raise SkipTest("test requires at least four devices")
@@ -1030,7 +1024,6 @@ class PythonPmapTest(jtu.JaxTestCase):
       ] for name, prim in
       (('Gather', lax.all_gather), ('ReduceScatter', lax.psum_scatter))
       ))
-  @ignore_slow_all_to_all_warning()
   def testGradOf(self, prim, tiled, use_axis_index_groups):
     if jtu.device_under_test() == "gpu":
       raise SkipTest("XLA:GPU with ReduceScatter deadlocks")  # b/264516146
@@ -1549,7 +1542,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
     self.assertAllClose(bz2, bz2, check_dtypes=False)
 
-  @ignore_slow_all_to_all_warning()
   def testPswapaxes(self):
     device_count = jax.device_count()
     shape = (device_count, 3, device_count, 5)
@@ -1559,7 +1551,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     expected = np.swapaxes(x, 0, 2)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
-  @ignore_slow_all_to_all_warning()
   def testGradOfPswapaxes(self):
     device_count = jax.device_count()
     shape = (device_count, 1, device_count)
@@ -1575,7 +1566,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     expected = np.tile(w, reps=device_count).reshape(shape)
     self.assertAllClose(ans, expected, check_dtypes=False)
 
-  @ignore_slow_all_to_all_warning()
   def testAllToAllReplicaGroups(self):
     # If num_devices = 4, these would be the inputs/outputs:
     # input = [[0, 1], [2, 3], [4, 5], [6, 7]]
@@ -1604,7 +1594,6 @@ class PythonPmapTest(jtu.JaxTestCase):
         0, 2).reshape(shape)
     self.assertAllClose(fn(x), expected, check_dtypes=False)
 
-  @ignore_slow_all_to_all_warning()
   def testGradOfAllToAllReplicaGroups(self):
     device_count = jax.device_count()
     if device_count % 2 != 0:
@@ -2322,7 +2311,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
       for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAllInVmap(self, split_axis, concat_axis, vmap_axis):
     def f(x):
       return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
@@ -2391,7 +2379,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
       {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
        "split_axis": split_axis, "concat_axis": concat_axis}
       for split_axis, concat_axis in it.product(range(3), range(3)))
-  @ignore_slow_all_to_all_warning()
   def testAllToAllVsVmap(self, split_axis, concat_axis):
     def f(x):
       return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
@@ -2406,7 +2393,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
        "axes": axes, "split_axis": split_axis, "concat_axis": concat_axis}
       for axes, split_axis, concat_axis
       in it.product([('i', 'j'), ('j', 'i')], range(3), range(3)))
-  @ignore_slow_all_to_all_warning()
   @unittest.skip("multi-axis all_to_all broken after #4835")  # TODO(mattjj,apaszke)
   def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
     if jax.device_count() < 4:
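To make the axis_index_groups behavior that several of these tests exercise concrete, here is a hedged sketch (not test code from this diff; assumes exactly four devices, e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=4):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8).reshape(4, 2)  # one row per device: [[0, 1], [2, 3], [4, 5], [6, 7]]

# Groups confine the exchange: devices {0, 1} trade rows with each other,
# as do devices {2, 3}.
f = jax.pmap(
    lambda v: lax.all_to_all(v, 'i', split_axis=0, concat_axis=0,
                             axis_index_groups=[[0, 1], [2, 3]]),
    axis_name='i')
print(f(x))  # [[0, 2], [1, 3], [4, 6], [5, 7]]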


@@ -170,7 +170,6 @@ class ShardMapTest(jtu.JaxTestCase):
     c = fwd(a)
     self.assertAllClose(c[1, :], a[0, :])
 
-  @jtu.skip_on_devices("cpu")  # all_to_all has a warning on cpu
   def test_all_to_all(self):
     devices = np.array(jax.devices())
     mesh = Mesh(devices, axis_names=('x'))
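With the CPU restriction gone, test_all_to_all can run on every backend. As a closing illustration, here is a hedged sketch (assuming the jax.experimental.shard_map API of this era; this is not the test body from the diff) of what all_to_all does under shard_map: it reshards, here turning a row-sharded matrix into a column-sharded one with unchanged values.

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
n = jax.device_count()

@partial(shard_map, mesh=mesh, in_specs=P('x', None), out_specs=P(None, 'x'))
def reshard(block):  # block: (1, n), one row of the global array
  # Send element j of the local row to device j; receive one element from
  # every device and stack them along axis 0, yielding an (n, 1) column.
  return lax.all_to_all(block, 'x', split_axis=1, concat_axis=0, tiled=True)

a = jnp.arange(n * n).reshape(n, n)
assert jnp.array_equal(a, reshard(a))  # same values, now sharded by columns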