Change the lowering rule for jax.lax.scan to avoid emitting a while loop

when the intent is to fully unroll the loop.

PiperOrigin-RevId: 691393597
This commit is contained in:
Benjamin Chetioui 2024-10-30 06:19:54 -07:00 committed by jax authors
parent f1c3109bf5
commit 15a11365e4
2 changed files with 18 additions and 3 deletions

View File

@ -418,6 +418,10 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
num_trips, remainder = divmod(length, unroll)
if unroll != 1 and num_trips == 1 and remainder == 0:
# In that case, we explicitly want to fully unroll the loop. Put everything
# into the remainder block and avoid lowering to a while loop.
num_trips, remainder = 0, length
if unroll == 1:
xss = xs_
yss = _map(partial(_empty_array, (length,)), y_avals)

View File

@ -2424,6 +2424,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
scan = lambda c, xs: lax.scan(f, c, xs)
scan_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=2)
scan_fully_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=True)
# jaxprs should be the same size
self.assertEqual(
@ -2431,9 +2432,19 @@ class LaxControlFlowTest(jtu.JaxTestCase):
len(str(jax.make_jaxpr(scan_unrolled)(c, xs))))
# but HLO should grow due to unrolling
self.assertLess(
len(str(jax.jit(scan).lower(c, xs).as_text('hlo'))),
len(str(jax.jit(scan_unrolled).lower(c, xs).as_text('hlo'))))
scan_hlo = str(jax.jit(scan).lower(c, xs).as_text("hlo"))
scan_unrolled_hlo = str(jax.jit(scan_unrolled).lower(c, xs).as_text("hlo"))
scan_fully_unrolled_hlo = str(
jax.jit(scan_fully_unrolled).lower(c, xs).as_text("hlo"))
self.assertLess(len(scan_hlo), len(scan_unrolled_hlo))
self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo))
# and the lowering should contain a while loop, unless the scan is fully
# unrolled
self.assertIn("while(", scan_hlo)
self.assertIn("while(", scan_unrolled_hlo)
self.assertNotIn("while(", scan_fully_unrolled_hlo)
def test_scan_xs_none(self):
def f(h, _):