Implement a trivial ppermute collective batcher

Splitting a single-dimensional ppermute into multiple permutations is a
hard problem in general, but not when we're splitting a size-1
dimension. More importantly, this is the case that's triggered by any
`xmap` of a `ppermute`, so we'd better have an implementation ready!
Adam Paszke 2021-01-28 18:19:36 +00:00
parent 0263663940
commit a7f9b84bf1
2 changed files with 50 additions and 24 deletions
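
Why the size-1 case is trivial: dropping a size-1 factor from the joint index over several named axes leaves every index value unchanged, so a permutation written over the batched axis together with the remaining axes is already a valid permutation over the remaining axes alone. The sketch below restates that forwarding decision outside of JAX internals; the Frame stand-in and the forward_ppermute helper are hypothetical names, not the actual batcher API.

from typing import NamedTuple, Sequence, Tuple

class Frame(NamedTuple):
  # Stand-in for the batching frame: a named axis and its (local) size.
  name: str
  size: int

def forward_ppermute(frame: Frame, axis_name, perm: Sequence[Tuple[int, int]]):
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  remaining = tuple(a for a in axis_name if a != frame.name)
  if frame.size == 1 and remaining:
    # Dropping a size-1 axis leaves every joint index unchanged, so the same
    # (source, destination) pairs are valid over the remaining axes.
    return remaining, perm
  if remaining:
    raise NotImplementedError("splitting an axis of size > 1 is the hard case")
  return (frame.name,), perm  # single-axis case: nothing to split

print(forward_ppermute(Frame(name='i', size=1), ('i', 'j'), [(0, 1), (1, 0)]))
# -> (('j',), [(0, 1), (1, 0)])

Splitting a permutation across an axis of size greater than one, by contrast, would require factoring the permutation into per-axis pieces, which is the hard general problem the commit message mentions.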


@@ -652,9 +652,16 @@ def _ppermute_transpose_rule(t, x, perm, axis_name):
   return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]
 
 def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
-  assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
-  assert axis_name == frame.name, "ppermute batcher called with wrong axis name"
   (v,), (d,) = vals_in, dims_in
+  if not isinstance(axis_name, (tuple, list)):
+    axis_name = (axis_name,)
+  remaining_axes = tuple(axis for axis in axis_name if axis != frame.name)
+  if frame.size == 1 and remaining_axes:
+    return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d
+  if remaining_axes:
+    raise NotImplementedError("ppermute batcher only supports a single axis")
+  assert axis_name[0] == frame.name, "ppermute batcher called with a wrong axis!"
+  assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
   assert d is not batching.not_mapped
   perm_indices = [None] * frame.size
   for src, dst in perm:


@@ -112,25 +112,39 @@ class XMapTest(jtu.JaxTestCase):
     self.assertAllClose(d, b * 4)
 
   @ignore_xmap_warning()
-  def testBasicCollective(self):
-    local_devices = list(jax.local_devices())
-    if len(local_devices) < 4:
-      raise SkipTest("Test requires at least 4 local devices")
-    def f(a, b):
-      return lax.psum(a * 2, 'a'), b * 4
-    devices = np.array(local_devices[:4]).reshape((2, 2))
-    with mesh(devices, ('x', 'y')):
-      fm = xmap(f,
-                in_axes=[['a', 'b', ...], {0: 'c'}],
-                out_axes=[['b', ...], {0: 'c'}],
-                axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
-      ashape = (16, 8, 5)
-      a = jnp.arange(np.prod(ashape)).reshape(ashape)
-      bshape = (2, 7)
-      b = jnp.arange(np.prod(bshape)).reshape(bshape)
-      c, d = fm(a, b)
-      self.assertAllClose(c, (a * 2).sum(0))
-      self.assertAllClose(d, b * 4)
+  @with_mesh([('x', 2), ('y', 2)])
+  def testCollectiveReduce(self):
+    fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
+              in_axes=[['a', 'b', ...], {0: 'c'}],
+              out_axes=[['b', ...], {0: 'c'}],
+              axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
+    ashape = (16, 8, 5)
+    a = jnp.arange(np.prod(ashape)).reshape(ashape)
+    bshape = (2, 7)
+    b = jnp.arange(np.prod(bshape)).reshape(bshape)
+    c, d = fm(a, b)
+    self.assertAllClose(c, (a * 2).sum(0))
+    self.assertAllClose(d, b * 4)
+
+  @ignore_xmap_warning()
+  @with_mesh([('x', 2), ('y', 2)])
+  def testCollectivePermute2D(self):
+    perm = np.array([3, 1, 2, 0])
+    x = jnp.arange(4).reshape((2, 2))
+    result = xmap(lambda x: lax.pshuffle(x, ('i', 'j'), perm),
+                  in_axes=['i', 'j', ...],
+                  out_axes=['i', 'j', ...],
+                  axis_resources={'i': 'x', 'j': 'y'})(x).reshape((-1,))
+    self.assertAllClose(result, perm)
+
+  @ignore_xmap_warning()
+  def testCollectivePermute1D(self):
+    perm = np.array([3, 1, 2, 0])
+    x = jnp.arange(4)
+    result = xmap(lambda x: lax.pshuffle(x, 'i', perm),
+                  in_axes=['i', ...],
+                  out_axes=['i', ...])(x)
+    self.assertAllClose(result, perm)
+
   @ignore_xmap_warning()
   @with_mesh([('x', 2), ('y', 2)])
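
A plain-NumPy restatement (not part of the commit) of what the two new permute tests assert, assuming lax.pshuffle's source-index encoding in which the output at position i is taken from input position perm[i]:

import numpy as np

perm = np.array([3, 1, 2, 0])
x = np.arange(4)
expected = x[perm]               # output[i] = x[perm[i]]
assert (expected == perm).all()  # matches self.assertAllClose(result, perm)
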
@@ -323,13 +337,18 @@ class XMapTest(jtu.JaxTestCase):
 class XMapTestSPMD(XMapTest):
   """Re-executes all tests with the SPMD partitioner enabled"""
 
+  skipped_tests = {
+    "NestedMesh", # Nesting xmap calls is not supported in the SPMD lowering yet
+    "CollectivePermute2D" # vmap of multidimensional permute not implemented yet
+  }
+
   def setUp(self):
     super().setUp()
     if jtu.device_under_test() != "tpu":
       raise SkipTest
-    # Nesting xmap calls is not supported in the SPMD lowering yet
-    if "NestedMesh" in self._testMethodName:
-      raise SkipTest
+    for skipped_name in self.skipped_tests:
+      if skipped_name in self._testMethodName:
+        raise SkipTest
     jax.experimental.maps.make_xmap_callable.cache_clear()
     self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
     jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
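
The last hunk replaces a hard-coded NestedMesh check with a skipped_tests set matched against the running test's method name. A self-contained illustration of the same pattern with generic names (plain unittest, not taken from the commit):

import unittest

class Base(unittest.TestCase):
  def testCollectiveReduce(self):
    pass
  def testCollectivePermute2D(self):
    pass

class Variant(Base):
  # Re-runs every inherited test, except those whose names match an entry here.
  skipped_tests = {"CollectivePermute2D"}

  def setUp(self):
    super().setUp()
    for skipped_name in self.skipped_tests:
      if skipped_name in self._testMethodName:
        raise unittest.SkipTest(f"{skipped_name} not supported in this variant")

if __name__ == "__main__":
  unittest.main()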