diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index e0b608d7c..a139ce9b5 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2305,20 +2305,27 @@ class LaxControlFlowTest(jtu.JaxTestCase): @ignore_jit_of_pmap_warning() def test_while_loop_of_pmap(self): - # code from jsnoek@ + # Avoid accuracy issue caused by too many devices. + DEVICE_LIMITATION = 4 + devices = jax.devices() + count = jax.device_count() + if jax.device_count() >= DEVICE_LIMITATION: + devices = devices[:DEVICE_LIMITATION] + count = DEVICE_LIMITATION + # code from jsnoek@ def body(i, x): - result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), axis_name='i')(x) + result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), devices=devices, axis_name='i')(x) return result + x f_loop = lambda x: lax.fori_loop(0, 3, body, x) # noqa: F821 - ans = f_loop(jnp.ones(jax.device_count())) + ans = f_loop(jnp.ones(count)) del body, f_loop def body2(i, x): result = jnp.broadcast_to(jnp.sin(x).sum(), x.shape) return result + x g_loop = lambda x: lax.fori_loop(0, 3, body2, x) - expected = g_loop(jnp.ones(jax.device_count())) + expected = g_loop(jnp.ones(count)) self.assertAllClose(ans, expected, check_dtypes=False)