mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #19080 from Zantares:tenglu/fix_ut
PiperOrigin-RevId: 597723598
This commit is contained in:
commit
761cf8ba7d
@ -2305,20 +2305,27 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_jit_of_pmap_warning()
|
||||
def test_while_loop_of_pmap(self):
|
||||
# code from jsnoek@
|
||||
# Avoid accuracy issue caused by too many devices.
|
||||
DEVICE_LIMITATION = 4
|
||||
devices = jax.devices()
|
||||
count = jax.device_count()
|
||||
if jax.device_count() >= DEVICE_LIMITATION:
|
||||
devices = devices[:DEVICE_LIMITATION]
|
||||
count = DEVICE_LIMITATION
|
||||
|
||||
# code from jsnoek@
|
||||
def body(i, x):
|
||||
result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), axis_name='i')(x)
|
||||
result = jax.pmap(lambda z: lax.psum(jnp.sin(z), 'i'), devices=devices, axis_name='i')(x)
|
||||
return result + x
|
||||
f_loop = lambda x: lax.fori_loop(0, 3, body, x) # noqa: F821
|
||||
ans = f_loop(jnp.ones(jax.device_count()))
|
||||
ans = f_loop(jnp.ones(count))
|
||||
del body, f_loop
|
||||
|
||||
def body2(i, x):
|
||||
result = jnp.broadcast_to(jnp.sin(x).sum(), x.shape)
|
||||
return result + x
|
||||
g_loop = lambda x: lax.fori_loop(0, 3, body2, x)
|
||||
expected = g_loop(jnp.ones(jax.device_count()))
|
||||
expected = g_loop(jnp.ones(count))
|
||||
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user