add optional 'forward' argument to lax.scan (#2921)

* add optional 'forward' argument to lax.scan

* switch to reverse; revise disable-jit case

* fix jaxpr.rst

* fix loops.py

Co-authored-by: James Bradbury <jekbradbury@gmail.com>
This commit is contained in:
Matthew Johnson 2020-05-04 19:44:22 -07:00 committed by GitHub
parent 3e522373a0
commit 3cd409ee88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 31 deletions

View File

@ -375,8 +375,7 @@ For the example consider the function ``func11`` below
...
>>> print(make_jaxpr(func11)(onp.ones(16), 5.))
{ lambda c ; a b.
let d e = scan[ forward=True
jaxpr={ lambda ; f a b c.
let d e = scan[ jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
@ -384,7 +383,8 @@ For the example consider the function ``func11`` below
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1 ] b 0.0 a c
num_consts=1
reverse=False ] b 0.0 a c
in (d, e) }
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,

View File

@ -488,7 +488,7 @@ class _BoundedLoopBuilder(_LoopBuilder):
arange_val = jnp.arange(self.start, stop=self.stop, step=self.step)
return lax_control_flow.scan_p.bind(*itertools.chain(body_const_vals,
init_vals, [arange_val]),
forward=True, length=arange_val.shape[0],
reverse=False, length=arange_val.shape[0],
jaxpr=body_typed_jaxpr,
num_consts=len(body_const_vals),
num_carry=len(init_vals),

View File

@ -45,7 +45,7 @@ from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,
split_dict, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_leaves,
tree_multimap)
tree_map, tree_multimap)
from jax import ad_util
xops = xla_client.ops
@ -823,7 +823,7 @@ xla.initial_style_translations[cond_p] = _cond_translation_rule
### scan
def scan(f, init, xs, length=None):
def scan(f, init, xs, length=None, reverse=False):
"""Scan a function over leading array axes while carrying along state.
The type signature in brief is
@ -883,6 +883,9 @@ def scan(f, init, xs, length=None):
length: optional integer specifying the number of loop iterations, which
must agree with the sizes of leading axes of the arrays in ``xs`` (but can
be used to perform scans where no input ``xs`` are needed).
reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse, equivalent to reversing the leading
axes of the arrays in both ``xs`` and in ``ys``.
Returns:
A pair of type ``(c, [b])`` where the first element represents the final
@ -921,13 +924,15 @@ def scan(f, init, xs, length=None):
if jax.api._jit_is_disabled():
carry = init
ys = []
for i in range(length):
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(length)):
xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat]
carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
ys.append(y)
stack = lambda y, *ys: (y if core.get_aval(y) is core.abstract_unit
else jax.numpy.stack((y, *ys)))
return carry, tree_multimap(stack, *ys)
ys = tree_multimap(stack, *maybe_reversed(ys))
return carry, ys
carry_avals = tuple(_map(_abstractify, init_flat))
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
@ -944,12 +949,12 @@ def scan(f, init, xs, length=None):
init_tree, carry_avals)
out = scan_p.bind(*itertools.chain(consts, in_flat),
forward=True, length=length, jaxpr=jaxpr,
reverse=reverse, length=length, jaxpr=jaxpr,
num_consts=len(consts), num_carry=len(init_flat),
linear=(False,) * (len(consts) + len(in_flat)))
return tree_unflatten(out_tree, out)
def _scan_impl(*args, forward, length, num_consts, num_carry, jaxpr, linear):
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
consts, init, xs = split_list(args, [num_consts, num_carry])
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
@ -960,7 +965,7 @@ def _scan_impl(*args, forward, length, num_consts, num_carry, jaxpr, linear):
def body_fun(vals):
[i], carry, ys = split_list(vals, [1, num_carry])
i_ = i if forward else length - i - 1
i_ = length - i - 1 if reverse else i
x = _map(partial(_index_array, i_), x_avals, xs)
out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
carry_out, y_updates = split_list(out_flat, [num_carry])
@ -993,13 +998,13 @@ def _update_array(i, aval, xs, x):
else:
return lax.dynamic_update_index_in_dim(xs, x, i, 0)
def _scan_abstract_eval(*args, forward, length, num_consts, num_carry, jaxpr, linear):
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
if aval is not core.abstract_unit else aval for aval in y_avals]
return carry_avals + ys_avals
def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
linear):
num_xs = len(jaxpr.in_avals) - num_carry - num_consts
num_ys = len(jaxpr.out_avals) - num_carry
@ -1043,7 +1048,7 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
out_flat = scan_p.bind(
*(consts + consts_dot + init + init_dot + xs + xs_dot),
forward=forward, length=length, jaxpr=jaxpr_jvp_rearranged,
reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
num_consts=num_consts+len(consts_dot), num_carry=num_carry+len(init_dot),
linear=jaxpr_jvp_linear)
@ -1057,10 +1062,10 @@ def _scan_jvp(primals, tangents, forward, length, jaxpr, num_consts, num_carry,
def _prune_zeros(ts):
return [t for t in ts if t is not ad_util.zero]
def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear):
if trace.master.trace_type is pe.StagingJaxprTrace:
params = {"forward": forward, "length": length, "num_consts": num_consts,
params = {"reverse": reverse, "length": length, "num_consts": num_consts,
"num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
return trace.default_process_primitive(scan_p, tracers, params)
@ -1126,7 +1131,7 @@ def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
[lin or uk for uk, lin
in zip(unknowns[num_consts:], linear[num_consts:])])
out_flat = scan_p.bind(
*in_consts, forward=forward, length=length, jaxpr=jaxpr_1_opt,
*in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1))
out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]
@ -1148,7 +1153,7 @@ def _scan_partial_eval(trace, *tracers, forward, length, num_consts, num_carry,
[False] * len(ext_res_tracers))
eqn = pe.new_eqn_recipe(int_res_tracers + new_tracers + ext_res_tracers,
out_tracers, scan_p,
dict(forward=forward, length=length, jaxpr=jaxpr_2_opt,
dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
num_consts=num_consts_2,
num_carry=num_carry, linear=tuple(linear_2)))
for t in out_tracers: t.recipe = eqn
@ -1160,7 +1165,7 @@ def _promote_aval_rank(sz, aval):
else:
return ShapedArray((sz,) + aval.shape, aval.dtype)
def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, linear):
def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear):
# we've only implemented transposing scans with specific lin/nonlin patterns
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
num_ires = len(consts_lin) - sum(consts_lin)
@ -1194,7 +1199,7 @@ def _scan_transpose(cts, *args, forward, length, num_consts, num_carry, jaxpr, l
[False] * num_eres)
outs = scan_p.bind(
*(ires + ct_consts + ct_carry + ct_ys + eres), forward=not forward,
*(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans))
ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
@ -1231,7 +1236,7 @@ def _make_typed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstract
return core.TypedJaxpr(jaxpr, consts, in_avals, _map(raise_to_shaped, out_avals))
def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
num_carry, linear):
num_ys = len(jaxpr.out_avals) - num_carry
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
@ -1268,22 +1273,22 @@ def _scan_batching_rule(args, dims, forward, length, jaxpr, num_consts,
else x for x, d in zip(xs, xs_bdims)]
new_args = new_consts + new_init + new_xs
outs = scan_p.bind(*new_args, forward=forward, length=length, jaxpr=jaxpr_batched,
outs = scan_p.bind(*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
num_consts=num_consts, num_carry=num_carry, linear=linear)
carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
return outs, carry_bdims + ys_bdims
def _scan_shape_rule(shapes, forward, length, jaxpr,
def _scan_shape_rule(shapes, reverse, length, jaxpr,
num_consts, num_carry, linear):
const_shexprs, init_shexprs, xs_shexprs = split_list(shapes, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
ys_shapes = [(length,) + tuple(y_aval.shape) for y_aval in y_avals]
return init_shexprs + ys_shapes
def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, reverse, length,
jaxpr, num_consts, num_carry, linear):
out_shape = _scan_shape_rule(shape_exprs, forward, length, jaxpr,
out_shape = _scan_shape_rule(shape_exprs, reverse, length, jaxpr,
num_consts, num_carry, linear)
dynamic_length = length.evaluate(shape_envs.logical)
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
@ -1292,7 +1297,7 @@ def _scan_masking_rule(shape_envs, padded_vals, shape_exprs, forward, length,
const_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry])
out_vals = scan_p.bind(
*itertools.chain([dynamic_length] + consts, [0], init, xs),
forward=forward, length=max_length, jaxpr=masked_jaxpr,
reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
num_consts=1 + num_consts, num_carry=1 + num_carry,
linear=tuple([False] + const_linear + [False] + init_linear + xs_linear))
return out_vals[1:], out_shape
@ -1314,7 +1319,7 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
def scan_bind(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
if not core.skip_checks:
assert len(linear) == len(args)
consts, init, xs = split_list(args, [num_consts, num_carry])
@ -1326,7 +1331,7 @@ def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
assert all(_map(typematch, init_avals, carry_avals))
core.check_jaxpr(jaxpr.jaxpr)
return core.Primitive.bind(scan_p, *args, forward=forward, length=length,
return core.Primitive.bind(scan_p, *args, reverse=reverse, length=length,
jaxpr=jaxpr, num_consts=num_consts,
num_carry=num_carry, linear=linear)

View File

@ -1798,8 +1798,7 @@ class JaxprTest(jtu.JaxTestCase):
# TODO(#2640): update docs/jaxpr.rst to reflect new jaxpr
self.assertMultiLineStrippedEqual("""
{ lambda c ; a b.
let d e = scan[ forward=True
jaxpr={ lambda ; f a b c.
let d e = scan[ jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
@ -1807,7 +1806,8 @@ class JaxprTest(jtu.JaxTestCase):
length=16
linear=(False, False, False, False)
num_carry=1
num_consts=1 ] b 0.0 a c
num_consts=1
reverse=False ] b 0.0 a c
in (d, e) }
""", str(jaxpr))

View File

@ -1903,6 +1903,19 @@ class LaxControlFlowTest(jtu.JaxTestCase):
.format(too_big, api.device_count(), jtu.device_under_test())),
lambda: f_loop(np.ones(too_big)))
def test_scan_reverse(self):
def cumsum(x, reverse):
return lax.scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
x = onp.array([3, 1, 4, 1, 5, 9])
self.assertAllClose(onp.cumsum(x), cumsum(x, False), check_dtypes=False)
self.assertAllClose(onp.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
with api.disable_jit():
self.assertAllClose(onp.cumsum(x), cumsum(x, False), check_dtypes=False)
with api.disable_jit():
self.assertAllClose(onp.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
if __name__ == '__main__':
absltest.main()