mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Add a regression test for a pmap issue that is fixed at head.
Fixes https://github.com/google/jax/issues/5757 PiperOrigin-RevId: 580243825
This commit is contained in:
parent
e66f4e94c4
commit
f4eb3f6d86
@ -625,6 +625,15 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
(concat_axis + 1, 0))
|
||||
self.assertAllClose(y, ref)
|
||||
|
||||
def testNestedPmapAxisSwap(self):
|
||||
# Regression test for https://github.com/google/jax/issues/5757
|
||||
if jax.device_count() < 8:
|
||||
raise SkipTest("test requires at least 8 devices")
|
||||
f = jax.pmap(jax.pmap(lambda x: x, in_axes=1, out_axes=0), in_axes=0,
|
||||
out_axes=0)
|
||||
A = jnp.ones((2, 4, 3))
|
||||
self.assertAllClose(A.transpose((0, 2, 1)), f(A))
|
||||
|
||||
def testNestedBasic(self):
|
||||
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
|
||||
f = self.pmap(self.pmap(f, 'i'), 'j')
|
||||
|
Loading…
x
Reference in New Issue
Block a user