mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add support for MANUAL lowering of ppermute
PiperOrigin-RevId: 481157480
This commit is contained in:
parent
c848efa11b
commit
746dd5ab13
@ -843,7 +843,18 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
|
||||
full_perm[i, j, 0] = grp[src]
|
||||
full_perm[i, j, 1] = grp[dst]
|
||||
full_perm = full_perm.reshape((-1, 2))
|
||||
return mhlo.CollectivePermuteOp(x, mlir.dense_int_elements(full_perm)).results
|
||||
|
||||
axis_context = ctx.module_context.axis_context
|
||||
is_manual = isinstance(axis_context, mlir.SPMDAxisContext) and axis_context.manual_axes
|
||||
if is_manual:
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
channel_handle=mhlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE))
|
||||
else:
|
||||
other_args = {}
|
||||
|
||||
return mhlo.CollectivePermuteOp(
|
||||
x, mlir.dense_int_elements(full_perm), **other_args).results
|
||||
|
||||
def _ppermute_transpose_rule(t, x, perm, axis_name):
|
||||
srcs, dsts = unzip2(perm)
|
||||
|
@ -873,6 +873,16 @@ class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
|
||||
in_axes=['i', ...], out_axes=[...], axis_resources={'i': 'x'})(x)
|
||||
self.assertAllClose(x.sum(0), y)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testRepro(self):
|
||||
n = 2
|
||||
x = jnp.arange(n * 5, dtype=jnp.float32).reshape(n, 5)
|
||||
|
||||
f = xmap(lambda x: lax.ppermute(x, 'i', perm=[(j, (j + 1) % n) for j in range(n)]),
|
||||
in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
|
||||
g = pjit(f, in_axis_resources=P('x'), out_axis_resources=P('x'))
|
||||
self.assertAllClose(g(x), x[::-1])
|
||||
|
||||
|
||||
class NamedNumPyTest(XMapTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user