Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
add optional 'forward' argument to lax.scan (#2921)
* add optional 'forward' argument to lax.scan
* switch to reverse; revise disable-jit case
* fix jaxpr.rst
* fix loops.py

Co-authored-by: James Bradbury <jekbradbury@gmail.com>
This commit is contained in:
parent 3e522373a0
commit 3cd409ee88
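Not part of the diff: a minimal usage sketch of the new keyword (after the rename to 'reverse'), mirroring the cumsum test added at the bottom of this commit.

# Usage sketch (not from the commit): the reverse keyword added here,
# exercised the same way as the cumsum test at the bottom of this diff.
import jax.numpy as jnp
from jax import lax

def cumsum(xs, reverse=False):
  # carry a running total and also emit it at every step
  def step(c, x):
    c = c + x
    return c, c
  _, ys = lax.scan(step, 0., xs, reverse=reverse)
  return ys

x = jnp.array([3., 1., 4., 1., 5., 9.])
print(cumsum(x))                 # [ 3.  4.  8.  9. 14. 23.]
print(cumsum(x, reverse=True))   # [23. 20. 19. 15. 14.  9.]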
@@ -375,8 +375,7 @@ For the example consider the function ``func11`` below
    ...
    >>> print(make_jaxpr(func11)(onp.ones(16), 5.))
    { lambda c ; a b.
-     let d e = scan[ forward=True
-                     jaxpr={ lambda ; f a b c.
+     let d e = scan[ jaxpr={ lambda ; f a b c.
                        let  d = mul b c
                             e = add a d
                             g = add e f
@@ -384,7 +383,8 @@ For the example consider the function ``func11`` below
                      length=16
                      linear=(False, False, False, False)
                      num_carry=1
-                     num_consts=1 ] b 0.0 a c
+                     num_consts=1
+                     reverse=False ] b 0.0 a c
      in (d, e) }
 
 The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
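Also not part of the diff: a hedged sketch of how to see the scan parameters in a printed jaxpr. The exact textual format varies across JAX versions; the point of the documentation hunk above is that the scan equation now lists reverse=... instead of forward=... alongside length, num_carry, and num_consts.

# Hedged sketch: print a jaxpr containing a scan to inspect its parameters.
import jax
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
  return lax.scan(lambda c, x: (c + x, c + x), 0., xs)[1]

print(jax.make_jaxpr(cumsum)(jnp.ones(4)))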
@@ -488,7 +488,7 @@ class _BoundedLoopBuilder(_LoopBuilder):
     arange_val = jnp.arange(self.start, stop=self.stop, step=self.step)
     return lax_control_flow.scan_p.bind(*itertools.chain(body_const_vals,
                                                          init_vals, [arange_val]),
-                                        forward=True, length=arange_val.shape[0],
+                                        reverse=False, length=arange_val.shape[0],
                                         jaxpr=body_typed_jaxpr,
                                         num_consts=len(body_const_vals),
                                         num_carry=len(init_vals),
@@ -45,7 +45,7 @@ from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
                       split_dict, cache, extend_name_stack)
 from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                            treedef_children, treedef_tuple, tree_leaves,
-                           tree_multimap)
+                           tree_map, tree_multimap)
 from jax import ad_util
 
 xops = xla_client.ops
@@ -823,7 +823,7 @@ xla.initial_style_translations[cond_p] = _cond_translation_rule
 
 ### scan
 
-def scan(f, init, xs, length=None):
+def scan(f, init, xs, length=None, reverse=False):
   """Scan a function over leading array axes while carrying along state.
 
   The type signature in brief is
@@ -883,6 +883,9 @@ def scan(f, init, xs, length=None):
     length: optional integer specifying the number of loop iterations, which
       must agree with the sizes of leading axes of the arrays in ``xs`` (but can
       be used to perform scans where no input ``xs`` are needed).
+    reverse: optional boolean specifying whether to run the scan iteration
+      forward (the default) or in reverse, equivalent to reversing the leading
+      axes of the arrays in both ``xs`` and in ``ys``.
 
   Returns:
     A pair of type ``(c, [b])`` where the first element represents the final
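A small sketch (not from the commit) of the equivalence the new docstring text describes: reverse=True matches flipping the leading axis of xs, scanning forward, and flipping the stacked ys back. The function f, init, and xs below are arbitrary choices.

# Sketch of the documented equivalence; f, init, and xs are arbitrary choices.
import jax.numpy as jnp
from jax import lax

f = lambda c, x: (c + x, c * x)
init = 1.
xs = jnp.arange(1., 6.)

c_rev, ys_rev = lax.scan(f, init, xs, reverse=True)
c_fwd, ys_fwd = lax.scan(f, init, xs[::-1])   # forward scan over flipped xs

assert jnp.allclose(c_rev, c_fwd)             # same final carry
assert jnp.allclose(ys_rev, ys_fwd[::-1])     # ys come back un-flipped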
@@ -921,13 +924,15 @@ def scan(f, init, xs, length=None):
   if jax.api._jit_is_disabled():
     carry = init
     ys = []
-    for i in range(length):
+    maybe_reversed = reversed if reverse else lambda x: x
+    for i in maybe_reversed(range(length)):
       xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
       carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
       ys.append(y)
     stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
                             else jax.numpy.stack((y, *ys)))
-    return carry, tree_multimap(stack, *ys)
+    ys = tree_multimap(stack, *maybe_reversed(ys))
+    return carry, ys
 
   carry_avals = tuple(_map(_abstractify, init_flat))
   x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
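The disable-jit branch above handles reverse by walking the Python loop from the last index to the first and then flipping the collected outputs back into input order. A plain-Python sketch of that pattern (standalone; the real code also handles pytrees and abstract units):

def eager_scan(f, init, xs, reverse=False):
  # visit indices high-to-low when reverse=True ...
  maybe_reversed = reversed if reverse else lambda it: it
  carry, ys = init, []
  for i in maybe_reversed(range(len(xs))):
    carry, y = f(carry, xs[i])
    ys.append(y)
  # ... then flip the collected outputs back so ys[i] lines up with xs[i]
  return carry, list(maybe_reversed(ys))

print(eager_scan(lambda c, x: (c + x, c + x), 0, [3, 1, 4, 1, 5, 9], reverse=True))
# (23, [23, 20, 19, 15, 14, 9])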
@@ -944,12 +949,12 @@ def scan(f, init, xs, length=None):
                               init_tree, carry_avals)
 
   out = scan_p.bind(*itertools.chain(consts, in_flat),
-                    forward=True, length=length, jaxpr=jaxpr,
+                    reverse=reverse, length=length, jaxpr=jaxpr,
                     num_consts=len(consts), num_carry=len(init_flat),
                     linear=(False,) * (len(consts) + len(in_flat)))
   return tree_unflatten(out_tree, out)
 
-def _scan_impl(*args, forward, length, num_consts, num_carry, jaxpr, linear):
+def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
   consts, init, xs = split_list(args, [num_consts, num_carry])
   _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
   _, y_avals = split_list(jaxpr.out_avals, [num_carry])
@@ -960,7 +965,7 @@ def _scan_impl(*args, forward, length, num_consts, num_carry, jaxpr, linear):
 
   def body_fun(vals):
     [i], carry, ys = split_list(vals, [1, num_carry])
-    i_ = i if forward else length - i - 1
+    i_ = length - i - 1 if reverse else i
     x = _map(partial(_index_array, i_), x_avals, xs)
     out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
     carry_out, y_updates = split_list(out_flat, [num_carry])
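In the compiled path the while_loop counter i still runs 0 through length-1; only the index that body_fun reads is remapped when reverse is set. For length=6:

length = 6
print([i for i in range(length)])               # reverse=False reads indices 0, 1, ..., 5
print([length - i - 1 for i in range(length)])  # reverse=True reads indices 5, 4, ..., 0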
@@ -993,13 +998,13 @@ def _update_array(i, aval, xs, x):
   else:
     return lax.dynamic_update_index_in_dim(xs, x, i, 0)
 
-def _scan_abstract_eval(*args, forward, length, num_consts, num_carry, jaxpr, linear):
+def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
   carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
   ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
               if aval is not core.abstract_unit else aval for aval in y_avals]
   return carry_avals + ys_avals
 
-def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
+def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
               linear):
   num_xs = len(jaxpr.in_avals) - num_carry - num_consts
   num_ys = len(jaxpr.out_avals) - num_carry
@@ -1043,7 +1048,7 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
 
   out_flat = scan_p.bind(
       *(consts + consts_dot + init + init_dot + xs + xs_dot),
-      forward=forward, length=length, jaxpr=jaxpr_jvp_rearranged,
+      reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
       num_consts=num_consts+len(consts_dot), num_carry=num_carry+len(init_dot),
       linear=jaxpr_jvp_linear)
 
@@ -1057,10 +1062,10 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
 def _prune_zeros(ts):
   return [t for t in ts if t is not ad_util.zero]
 
-def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
+def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                        jaxpr, linear):
   if trace.master.trace_type is pe.StagingJaxprTrace:
-    params = {"forward": forward, "length": length, "num_consts": num_consts,
+    params = {"reverse": reverse, "length": length, "num_consts": num_consts,
               "num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
     return trace.default_process_primitive(scan_p, tracers, params)
 
@@ -1126,7 +1131,7 @@ def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
                           [lin or uk for uk, lin
                            in zip(unknowns[num_consts:], linear[num_consts:])])
   out_flat = scan_p.bind(
-      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1_opt,
+      *in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
       num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1))
   out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
   extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]
@@ -1148,7 +1153,7 @@ def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
                                  [False] * len(ext_res_tracers))
   eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers,
                           out_tracers, scan_p,
-                          dict(forward=forward, length=length, jaxpr=jaxpr_2_opt,
+                          dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
                                num_consts=num_consts_2,
                                num_carry=num_carry, linear=tuple(linear_2)))
   for t in out_tracers: t.recipe = eqn
@@ -1160,7 +1165,7 @@ def _promote_aval_rank(sz, aval):
   else:
     return ShapedArray((sz,) + aval.shape, aval.dtype)
 
-def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, linear):
+def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear):
   # we've only implemented transposing scans with specific lin/nonlin patterns
   consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
   num_ires = len(consts_lin) - sum(consts_lin)
@@ -1194,7 +1199,7 @@ def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, l
                   [False] * num_eres)
 
   outs = scan_p.bind(
-      *(ires + ct_consts + ct_carry + ct_ys + eres), forward=not forward,
+      *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
       length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
       num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans))
   ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
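The reverse=not reverse above reflects that the linear transpose of a scan runs in the opposite direction, which is what reverse-mode differentiation needs. A hedged illustration (not from the commit): the gradient of sum(cumsum(x)) is a reverse cumulative sum of ones, so differentiating a forward scan exercises a reversed scan under the hood.

# Sketch: transposing a forward accumulation yields a backward one.
import jax
import jax.numpy as jnp
from jax import lax

def cumsum(xs):
  return lax.scan(lambda c, x: (c + x, c + x), 0., xs)[1]

x = jnp.ones(4)
print(jax.grad(lambda x: cumsum(x).sum())(x))  # [4. 3. 2. 1.]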
@@ -1231,7 +1236,7 @@ def _make_typed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstract
   return core.TypedJaxpr(jaxpr, consts, in_avals, _map(raise_to_shaped, out_avals))
 
 
-def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
+def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
                         num_carry, linear):
   num_ys = len(jaxpr.out_avals) - num_carry
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@@ -1268,22 +1273,22 @@ def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
             else x for x, d in zip(xs, xs_bdims)]
   new_args = new_consts + new_init + new_xs
 
-  outs = scan_p.bind(*new_args, forward=forward, length=length, jaxpr=jaxpr_batched,
+  outs = scan_p.bind(*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
                      num_consts=num_consts, num_carry=num_carry, linear=linear)
   carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
   ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
   return outs, carry_bdims + ys_bdims
 
-def _scan_shape_rule(shapes, forward, length, jaxpr,
+def _scan_shape_rule(shapes, reverse, length, jaxpr,
                      num_consts, num_carry, linear):
   const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry])
   _, y_avals = split_list(jaxpr.out_avals, [num_carry])
   ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals]
   return init_shexprs + ys_shapes
 
-def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
+def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, reverse, length,
                        jaxpr, num_consts, num_carry, linear):
-  out_shape = _scan_shape_rule(shape_exprs, forward, length, jaxpr,
+  out_shape = _scan_shape_rule(shape_exprs, reverse, length, jaxpr,
                                num_consts, num_carry, linear)
   dynamic_length = length.evaluate(shape_envs.logical)
   masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
@@ -1292,7 +1297,7 @@ def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
   const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
   out_vals = scan_p.bind(
       *itertools.chain([dynamic_length] + consts, [0], init, xs),
-      forward=forward, length=max_length, jaxpr=masked_jaxpr,
+      reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
       num_consts=1 + num_consts, num_carry=1 + num_carry,
       linear=tuple([False] + const_linear + [False] + init_linear + xs_linear))
   return out_vals[1:], out_shape
@@ -1314,7 +1319,7 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
   const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
   return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
 
-def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
+def scan_bind(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
   if not core.skip_checks:
     assert len(linear) == len(args)
     consts, init, xs = split_list(args, [num_consts, num_carry])
@@ -1326,7 +1331,7 @@ def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
     carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
     assert all(_map(typematch, init_avals, carry_avals))
     core.check_jaxpr(jaxpr.jaxpr)
-  return core.Primitive.bind(scan_p, *args, forward=forward, length=length,
+  return core.Primitive.bind(scan_p, *args, reverse=reverse, length=length,
                              jaxpr=jaxpr, num_consts=num_consts,
                              num_carry=num_carry, linear=linear)
 
@@ -1798,8 +1798,7 @@ class JaxprTest(jtu.JaxTestCase):
     # TODO(#2640): update docs/jaxpr.rst to reflect new jaxpr
     self.assertMultiLineStrippedEqual("""
 { lambda c ; a b.
-  let d e = scan[ forward=True
-                  jaxpr={ lambda ; f a b c.
+  let d e = scan[ jaxpr={ lambda ; f a b c.
                     let  d = mul b c
                          e = add a d
                          g = add e f
@@ -1807,7 +1806,8 @@ class JaxprTest(jtu.JaxTestCase):
                   length=16
                   linear=(False, False, False, False)
                   num_carry=1
-                  num_consts=1 ] b 0.0 a c
+                  num_consts=1
+                  reverse=False ] b 0.0 a c
   in (d, e) }
     """, str(jaxpr))
 
@@ -1903,6 +1903,19 @@ class LaxControlFlowTest(jtu.JaxTestCase):
                            .format(too_big, api.device_count(), jtu.device_under_test())),
         lambda: f_loop(np.ones(too_big)))
 
+  def test_scan_reverse(self):
+    def cumsum(x, reverse):
+      return lax.scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
+
+    x = onp.array([3, 1, 4, 1, 5, 9])
+    self.assertAllClose(onp.cumsum(x), cumsum(x, False), check_dtypes=False)
+    self.assertAllClose(onp.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
+
+    with api.disable_jit():
+      self.assertAllClose(onp.cumsum(x), cumsum(x, False), check_dtypes=False)
+    with api.disable_jit():
+      self.assertAllClose(onp.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
+
 
 if __name__ == '__main__':
   absltest.main()