mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Change the lowering rule for jax.lax.scan
to avoid emitting a while
loop
when the intent is to fully unroll the loop. PiperOrigin-RevId: 691393597
This commit is contained in:
parent
f1c3109bf5
commit
15a11365e4
@ -418,6 +418,10 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
consts, carry, xs_ = split_list(args, [num_consts, num_carry])
|
||||
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
|
||||
num_trips, remainder = divmod(length, unroll)
|
||||
if unroll != 1 and num_trips == 1 and remainder == 0:
|
||||
# In that case, we explicitly want to fully unroll the loop. Put everything
|
||||
# into the remainder block and avoid lowering to a while loop.
|
||||
num_trips, remainder = 0, length
|
||||
if unroll == 1:
|
||||
xss = xs_
|
||||
yss = _map(partial(_empty_array, (length,)), y_avals)
|
||||
|
@ -2424,6 +2424,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
scan = lambda c, xs: lax.scan(f, c, xs)
|
||||
scan_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=2)
|
||||
scan_fully_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=True)
|
||||
|
||||
# jaxprs should be the same size
|
||||
self.assertEqual(
|
||||
@ -2431,9 +2432,19 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
len(str(jax.make_jaxpr(scan_unrolled)(c, xs))))
|
||||
|
||||
# but HLO should grow due to unrolling
|
||||
self.assertLess(
|
||||
len(str(jax.jit(scan).lower(c, xs).as_text('hlo'))),
|
||||
len(str(jax.jit(scan_unrolled).lower(c, xs).as_text('hlo'))))
|
||||
scan_hlo = str(jax.jit(scan).lower(c, xs).as_text("hlo"))
|
||||
scan_unrolled_hlo = str(jax.jit(scan_unrolled).lower(c, xs).as_text("hlo"))
|
||||
scan_fully_unrolled_hlo = str(
|
||||
jax.jit(scan_fully_unrolled).lower(c, xs).as_text("hlo"))
|
||||
|
||||
self.assertLess(len(scan_hlo), len(scan_unrolled_hlo))
|
||||
self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo))
|
||||
|
||||
# and the lowering should contain a while loop, unless the scan is fully
|
||||
# unrolled
|
||||
self.assertIn("while(", scan_hlo)
|
||||
self.assertIn("while(", scan_unrolled_hlo)
|
||||
self.assertNotIn("while(", scan_fully_unrolled_hlo)
|
||||
|
||||
def test_scan_xs_none(self):
|
||||
def f(h, _):
|
||||
|
Loading…
x
Reference in New Issue
Block a user