mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #27001 from mattjj:yash-scan
PiperOrigin-RevId: 734685031
This commit is contained in:
commit
d849779689
@ -457,8 +457,11 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
xs_, xs_rem = unzip2(_map(partial(_split_leading, num_trips*unroll), xs_))
|
||||
else:
|
||||
xs_rem, xs_ = unzip2(_map(partial(_split_leading, remainder), xs_))
|
||||
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
|
||||
yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals)
|
||||
if num_trips:
|
||||
xss = [lax.reshape(x, (num_trips, unroll, *x.shape[1:])) for x in xs_]
|
||||
yss = _map(partial(_empty_array, (num_trips, unroll), (None, None)), y_avals)
|
||||
else:
|
||||
yss = _map(partial(_empty_array, (num_trips * unroll,), (None,)), y_avals)
|
||||
|
||||
def inner(n, carry, xs):
|
||||
ys = []
|
||||
@ -493,7 +496,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
|
||||
if num_trips:
|
||||
i = lax._const(num_trips, 0)
|
||||
_, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss))
|
||||
if unroll != 1:
|
||||
if unroll != 1 and num_trips != 0:
|
||||
ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss]
|
||||
else:
|
||||
ys = yss
|
||||
|
Loading…
x
Reference in New Issue
Block a user