Add a regression test for a pmap issue that is fixed at head.

Fixes https://github.com/google/jax/issues/5757

PiperOrigin-RevId: 580243825
This commit is contained in:
Peter Hawkins 2023-11-07 11:20:38 -08:00 committed by jax authors
parent e66f4e94c4
commit f4eb3f6d86

View File

@ -625,6 +625,15 @@ class PythonPmapTest(jtu.JaxTestCase):
(concat_axis + 1, 0))
self.assertAllClose(y, ref)
def testNestedPmapAxisSwap(self):
# Regression test for https://github.com/google/jax/issues/5757
if jax.device_count() < 8:
raise SkipTest("test requires at least 8 devices")
f = jax.pmap(jax.pmap(lambda x: x, in_axes=1, out_axes=0), in_axes=0,
out_axes=0)
A = jnp.ones((2, 4, 3))
self.assertAllClose(A.transpose((0, 2, 1)), f(A))
def testNestedBasic(self):
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
f = self.pmap(self.pmap(f, 'i'), 'j')