diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 2ae99c850..abe56fd07 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -625,6 +625,15 @@ class PythonPmapTest(jtu.JaxTestCase): (concat_axis + 1, 0)) self.assertAllClose(y, ref) + def testNestedPmapAxisSwap(self): + # Regression test for https://github.com/google/jax/issues/5757 + if jax.device_count() < 8: + raise SkipTest("test requires at least 8 devices") + f = jax.pmap(jax.pmap(lambda x: x, in_axes=1, out_axes=0), in_axes=0, + out_axes=0) + A = jnp.ones((2, 4, 3)) + self.assertAllClose(A.transpose((0, 2, 1)), f(A)) + def testNestedBasic(self): f = lambda x: lax.psum(lax.psum(x, 'i'), 'j') f = self.pmap(self.pmap(f, 'i'), 'j')