mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Implement a trivial ppermute collective batcher
Splitting a single-dimensional ppermute into multiple permutations is a hard problem in general, but not when we're splitting a size-1 dimension. More importantly, this is the case that's triggered by any `xmap` of a `ppermute`, so we better have an implementation ready!
This commit is contained in:
parent
0263663940
commit
a7f9b84bf1
@ -652,9 +652,16 @@ def _ppermute_transpose_rule(t, x, perm, axis_name):
|
||||
return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]
|
||||
|
||||
def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm):
|
||||
assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
|
||||
assert axis_name == frame.name, "ppermute batcher called with wrong axis name"
|
||||
(v,), (d,) = vals_in, dims_in
|
||||
if not isinstance(axis_name, (tuple, list)):
|
||||
axis_name = (axis_name,)
|
||||
remaining_axes = tuple(axis for axis in axis_name if axis != frame.name)
|
||||
if frame.size == 1 and remaining_axes:
|
||||
return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d
|
||||
if remaining_axes:
|
||||
raise NotImplementedError("ppermute batcher only supports a single axis")
|
||||
assert axis_name[0] == frame.name, "ppermute batcher called with a wrong axis!"
|
||||
assert len(perm) == frame.size, "Permutation doesn't match the axis size!"
|
||||
assert d is not batching.not_mapped
|
||||
perm_indices = [None] * frame.size
|
||||
for src, dst in perm:
|
||||
|
@ -112,25 +112,39 @@ class XMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(d, b * 4)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testBasicCollective(self):
|
||||
local_devices = list(jax.local_devices())
|
||||
if len(local_devices) < 4:
|
||||
raise SkipTest("Test requires at least 4 local devices")
|
||||
def f(a, b):
|
||||
return lax.psum(a * 2, 'a'), b * 4
|
||||
devices = np.array(local_devices[:4]).reshape((2, 2))
|
||||
with mesh(devices, ('x', 'y')):
|
||||
fm = xmap(f,
|
||||
in_axes=[['a', 'b', ...], {0: 'c'}],
|
||||
out_axes=[['b', ...], {0: 'c'}],
|
||||
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
|
||||
ashape = (16, 8, 5)
|
||||
a = jnp.arange(np.prod(ashape)).reshape(ashape)
|
||||
bshape = (2, 7)
|
||||
b = jnp.arange(np.prod(bshape)).reshape(bshape)
|
||||
c, d = fm(a, b)
|
||||
self.assertAllClose(c, (a * 2).sum(0))
|
||||
self.assertAllClose(d, b * 4)
|
||||
@with_mesh([('x', 2), ('y', 2)])
|
||||
def testCollectiveReduce(self):
|
||||
fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
|
||||
in_axes=[['a', 'b', ...], {0: 'c'}],
|
||||
out_axes=[['b', ...], {0: 'c'}],
|
||||
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
|
||||
ashape = (16, 8, 5)
|
||||
a = jnp.arange(np.prod(ashape)).reshape(ashape)
|
||||
bshape = (2, 7)
|
||||
b = jnp.arange(np.prod(bshape)).reshape(bshape)
|
||||
c, d = fm(a, b)
|
||||
self.assertAllClose(c, (a * 2).sum(0))
|
||||
self.assertAllClose(d, b * 4)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2), ('y', 2)])
|
||||
def testCollectivePermute2D(self):
|
||||
perm = np.array([3, 1, 2, 0])
|
||||
x = jnp.arange(4).reshape((2, 2))
|
||||
result = xmap(lambda x: lax.pshuffle(x, ('i', 'j'), perm),
|
||||
in_axes=['i', 'j', ...],
|
||||
out_axes=['i', 'j', ...],
|
||||
axis_resources={'i': 'x', 'j': 'y'})(x).reshape((-1,))
|
||||
self.assertAllClose(result, perm)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testCollectivePermute1D(self):
|
||||
perm = np.array([3, 1, 2, 0])
|
||||
x = jnp.arange(4)
|
||||
result = xmap(lambda x: lax.pshuffle(x, 'i', perm),
|
||||
in_axes=['i', ...],
|
||||
out_axes=['i', ...])(x)
|
||||
self.assertAllClose(result, perm)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2), ('y', 2)])
|
||||
@ -323,13 +337,18 @@ class XMapTest(jtu.JaxTestCase):
|
||||
class XMapTestSPMD(XMapTest):
|
||||
"""Re-executes all tests with the SPMD partitioner enabled"""
|
||||
|
||||
skipped_tests = {
|
||||
"NestedMesh", # Nesting xmap calls is not supported in the SPMD lowering yet
|
||||
"CollectivePermute2D" # vmap of multidimensional permute not implemented yet
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest
|
||||
# Nesting xmap calls is not supported in the SPMD lowering yet
|
||||
if "NestedMesh" in self._testMethodName:
|
||||
raise SkipTest
|
||||
for skipped_name in self.skipped_tests:
|
||||
if skipped_name in self._testMethodName:
|
||||
raise SkipTest
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user