Merge pull request #19080 from Zantares:tenglu/fix_ut

PiperOrigin-RevId: 597723598
This commit is contained in:
jax authors 2024-01-11 20:26:31 -08:00
commit 761cf8ba7d

View File

@ -2305,20 +2305,27 @@ class LaxControlFlowTest(jtu.JaxTestCase):
@ignore_jit_of_pmap_warning()
def test_while_loop_of_pmap(self):
# code from jsnoek@
# Avoid accuracy issue caused by too many devices.
DEVICE_LIMITATION = 4
devices = jax.devices()
count = jax.device_count()
if jax.device_count() >= DEVICE_LIMITATION:
devices = devices[:DEVICE_LIMITATION]
count = DEVICE_LIMITATION
# code from jsnoek@
def body(i, x):
result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), axis_name='i')(x)
result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), devices=devices, axis_name='i')(x)
return result + x
f_loop = lambda x: lax.fori_loop(0, 3, body, x) # noqa: F821
ans = f_loop(jnp.ones(jax.device_count()))
ans = f_loop(jnp.ones(count))
del body, f_loop
def body2(i, x):
result = jnp.broadcast_to(jnp.sin(x).sum(), x.shape)
return result + x
g_loop = lambda x: lax.fori_loop(0, 3, body2, x)
expected = g_loop(jnp.ones(jax.device_count()))
expected = g_loop(jnp.ones(count))
self.assertAllClose(ans, expected, check_dtypes=False)