Add support for MANUAL lowering of ppermute

PiperOrigin-RevId: 481157480
This commit is contained in:
Adam Paszke 2022-10-14 08:59:05 -07:00 committed by jax authors
parent c848efa11b
commit 746dd5ab13
2 changed files with 22 additions and 1 deletions

View File

@ -843,7 +843,18 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
full_perm[i, j, 0] = grp[src]
full_perm[i, j, 1] = grp[dst]
full_perm = full_perm.reshape((-1, 2))
return mhlo.CollectivePermuteOp(x, mlir.dense_int_elements(full_perm)).results
axis_context = ctx.module_context.axis_context
is_manual = isinstance(axis_context, mlir.SPMDAxisContext) and axis_context.manual_axes
if is_manual:
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=mhlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE))
else:
other_args = {}
return mhlo.CollectivePermuteOp(
x, mlir.dense_int_elements(full_perm), **other_args).results
def _ppermute_transpose_rule(t, x, perm, axis_name):
srcs, dsts = unzip2(perm)

View File

@ -873,6 +873,16 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
in_axes=['i', ...], out_axes=[...], axis_resources={'i': 'x'})(x)
self.assertAllClose(x.sum(0), y)
@jtu.with_mesh([('x', 2)])
def testRepro(self):
n = 2
x = jnp.arange(n * 5, dtype=jnp.float32).reshape(n, 5)
f = xmap(lambda x: lax.ppermute(x, 'i', perm=[(j, (j + 1) % n) for j in range(n)]),
in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
g = pjit(f, in_axis_resources=P('x'), out_axis_resources=P('x'))
self.assertAllClose(g(x), x[::-1])
class NamedNumPyTest(XMapTestCase):