Merge pull request #27001 from mattjj:yash-scan

PiperOrigin-RevId: 734685031
This commit is contained in:
jax authors 2025-03-07 14:14:30 -08:00
commit d849779689

View File

@ -457,8 +457,11 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
else:
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals)
if num_trips:
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals)
else:
yss = _map(partial(_empty_array, (num_trips * unroll,), (None,)), y_avals)
def inner(n, carry, xs):
ys = []
@ -493,7 +496,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
if num_trips:
i = lax._const(num_trips, 0)
_, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss))
if unroll != 1:
if unroll != 1 and num_trips != 0:
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
else:
ys = yss