diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 39da45079..3871a87a7 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2445,7 +2445,7 @@ class LaxControlFlowTest(jtu.JaxTestCase): assert b.shape == () return c, b - xs = jnp.ones((5, 3)) + xs = jnp.ones((20, 3)) c = jnp.ones(4) scan = lambda c, xs: lax.scan(f, c, xs)