block-unrolled scan primitive implementation (#3738)

* block-unrolled scan implementation, via optional `_unroll` scan parameter

* index statically in the inlined path of lax.scan

* make `unroll` a required scan parameter, and test that it unrolls
This commit is contained in:
Roy Frostig 2020-07-15 11:00:50 -07:00 committed by GitHub
parent 23c279f033
commit 8a62a9b654
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 273 additions and 68 deletions

View File

@ -425,7 +425,8 @@ For the example consider the function ``func11`` below
linear=(False, False, False, False)
num_carry=1
num_consts=1
reverse=False ] b 0.0 a c
reverse=False
unroll=1 ] b 0.0 a c
in (d, e) }
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,

View File

@ -607,9 +607,10 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
for jaxpr in branches),
linear=(*linear, False)), eqn.source_info))
elif eqn.primitive is lax.scan_p:
num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
eqn.params,
["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length"])
["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
"unroll"])
# We add the token right at the end of carry
nr_const_and_carry = num_consts + num_carry
new_invars = eqn.invars[0:nr_const_and_carry] + [

View File

@ -493,7 +493,8 @@ class _BoundedLoopBuilder(_LoopBuilder):
num_consts=len(body_const_vals),
num_carry=len(init_vals),
linear=(False,) * (len(body_const_vals) +
len(init_vals) + 1))
len(init_vals) + 1),
unroll=1)
class _CondBuilder(_LoopBuilder):

View File

@ -1096,7 +1096,7 @@ core.custom_typechecks[cond_p] = _cond_typecheck
### scan
def scan(f, init, xs, length=None, reverse=False):
def scan(f, init, xs, length=None, reverse=False, unroll=1):
"""Scan a function over leading array axes while carrying along state.
The type signature in brief is
@ -1159,6 +1159,9 @@ def scan(f, init, xs, length=None, reverse=False):
reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse, equivalent to reversing the leading
axes of the arrays in both ``xs`` and in ``ys``.
unroll: optional positive int specifying, in the underlying operation of the
scan primitive, how many scan iterations to unroll within a single
iteration of a loop.
Returns:
A pair of type ``(c, [b])`` where the first element represents the final
@ -1224,13 +1227,32 @@ def scan(f, init, xs, length=None, reverse=False):
out = scan_p.bind(*itertools.chain(consts, in_flat),
reverse=reverse, length=length, jaxpr=jaxpr,
num_consts=len(consts), num_carry=len(init_flat),
linear=(False,) * (len(consts) + len(in_flat)))
linear=(False,) * (len(consts) + len(in_flat)),
unroll=unroll)
return tree_unflatten(out_tree, out)
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
                        f_impl, x_avals, y_avals):
  """Evaluate a scan by fully unrolling it into `length` sequential calls.

  Runs `f_impl` once per iteration in a plain Python loop (honoring `reverse`
  by iterating indices back-to-front), then stacks the per-iteration outputs
  along a fresh leading axis.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])
  carry = init
  per_iter_ys = []
  for step in range(length):
    idx = length - step - 1 if reverse else step
    x_elts = _map(partial(_index_array, idx), x_avals, xs)
    results = f_impl(*consts, *carry, *x_elts)
    carry, y_elts = split_list(results, [num_carry])
    per_iter_ys.append(y_elts)
  if reverse:
    # Outputs were produced back-to-front; restore forward order for stacking.
    per_iter_ys.reverse()
  columns = list(zip(*per_iter_ys))  # transpose: per-iteration -> per-output
  ys = _map(_stack, y_avals, columns)
  return (*carry, *ys)
def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
_, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
_, y_avals = split_list(jaxpr.out_avals, [num_carry])
def cond_fun(vals):
i, *_ = vals
@ -1239,8 +1261,8 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
def body_fun(vals):
[i], carry, ys = split_list(vals, [1, num_carry])
i_ = length - i - 1 if reverse else i
x = _map(partial(_index_array, i_), x_avals, xs)
out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
x = _map(partial(_dynamic_index_array, i_), x_avals, xs)
out_flat = f_impl(*consts, *carry, *x)
carry_out, y_updates = split_list(out_flat, [num_carry])
ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
return [i + 1] + carry_out + ys_out
@ -1253,12 +1275,112 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
_, *outs = while_loop(cond_fun, body_fun, init_val)
return outs
def _index_array(i, aval, x):
def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
                              linear, block_length, f_impl, x_avals, y_avals):
  """Evaluate a scan as an outer loop over blocks, each block fully unrolled.

  Requires `length` to be an exact multiple of `block_length`: the inputs are
  reshaped from (length, ...) to (num_blocks, block_length, ...), the unrolled
  implementation is used as the body of a rolled loop over blocks, and the
  blocked outputs are collapsed back to a (length, ...) leading axis.
  """
  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, leftover = divmod(length, block_length)
  assert leftover == 0
  to_blocks = partial(_partition_leading, num_blocks, block_length)
  blocked_xs = _map(to_blocks, x_avals, xs)
  add_block_dim = partial(_prepend_dim_to_aval, block_length)
  blocked_x_avals = _map(add_block_dim, x_avals)
  blocked_y_avals = _map(add_block_dim, y_avals)
  # One "iteration" of the outer loop consumes a whole block, unrolled.
  block_body = partial(
      _scan_impl_unrolled, reverse=reverse, length=block_length,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
  outs = _scan_impl_loop(
      *consts, *init, *blocked_xs, reverse=reverse, length=num_blocks,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      f_impl=block_body, x_avals=blocked_x_avals, y_avals=blocked_y_avals)
  carry, blocked_ys = split_list(outs, [num_carry])
  from_blocks = partial(_combine_leading, num_blocks, block_length)
  ys = _map(from_blocks, y_avals, blocked_ys)
  return (*carry, *ys)
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
               unroll):
  # Impl rule for scan_p: dispatch between the rolled loop implementation and
  # the block-unrolled one, based on the `unroll` parameter.
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
  f_impl = core.jaxpr_as_fun(jaxpr)
  if unroll == 1:
    # No unrolling requested: run the plain dynamic loop over all iterations.
    return _scan_impl_loop(
        *args, reverse=reverse, length=length, num_consts=num_consts,
        num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
        y_avals=y_avals)
  consts, init, xs = split_list(args, [num_consts, num_carry])
  num_blocks, rem = divmod(length, unroll)
  length_div = num_blocks * unroll
  if rem > 0:
    # `length` is not a multiple of `unroll`: peel off `rem` iterations to run
    # separately. For a reverse scan the remainder is the *front* of xs, since
    # iteration proceeds from the back.
    if reverse:
      split = partial(_split_leading_dim, rem)
      xs_rem, xs = unzip2(_map(split, x_avals, xs))
    else:
      split = partial(_split_leading_dim, length_div)
      xs, xs_rem = unzip2(_map(split, x_avals, xs))
  outs = _scan_impl_block_unrolled(
      *consts, *init, *xs, reverse=reverse, length=length_div,
      num_consts=num_consts, num_carry=num_carry, linear=linear,
      block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
  carry, ys = split_list(outs, [num_carry])
  if rem > 0:
    # Finish the leftover iterations fully unrolled, continuing from the carry
    # produced by the block-unrolled portion, then splice the outputs back
    # together in original (forward) order.
    outs = _scan_impl_unrolled(
        *consts, *carry, *xs_rem, reverse=reverse, length=rem,
        num_consts=num_consts, num_carry=num_carry, linear=linear,
        f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
    carry, ys_rem = split_list(outs, [num_carry])
    if reverse:
      ys = _map(_concatenate, y_avals, ys_rem, ys)
    else:
      ys = _map(_concatenate, y_avals, ys, ys_rem)
  return (*carry, *ys)
def _stack(aval, vals):
  """Stack a list of per-iteration values along a new leading axis.

  Unit avals pass through as the unit value.
  """
  if aval is core.abstract_unit:
    return core.unit
  expanded = [lax.expand_dims(v, (0,)) for v in vals]
  return lax.concatenate(expanded, 0)
def _concatenate(aval, x1, x2):
  """Concatenate two arrays along their leading axis (unit-aware)."""
  if aval is core.abstract_unit:
    return core.unit
  return lax.concatenate([x1, x2], 0)
def _split_leading_dim(i, aval, x):
  """Split `x` along axis 0 at position `i`, returning `(x[:i], x[i:])`."""
  if aval is core.abstract_unit:
    return (core.unit, core.unit)
  assert x.ndim >= 1
  head = lax.slice_in_dim(x, 0, i)
  tail = lax.slice_in_dim(x, i, x.shape[0])
  return (head, tail)
def _dynamic_index_array(i, aval, x):
  """Index `x[i]` along axis 0 with a possibly-traced index, dropping the axis."""
  if aval is core.abstract_unit:
    return core.unit
  return lax.dynamic_index_in_dim(x, i, keepdims=False)
def _index_array(i, aval, x):
  """Index `x[i]` along axis 0 with a static Python index, dropping the axis."""
  if aval is core.abstract_unit:
    return core.unit
  return lax.index_in_dim(x, i, keepdims=False)
def _empty_array(sz, aval):
if aval is core.abstract_unit:
return core.unit
@ -1271,14 +1393,40 @@ def _update_array(i, aval, xs, x):
else:
return lax.dynamic_update_index_in_dim(xs, x, i, 0)
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
def _partition_leading(sz0, sz1, aval, x):
  """Reshape the leading axis of `x` from (sz0 * sz1, ...) to (sz0, sz1, ...)."""
  if aval is core.abstract_unit:
    return core.unit
  assert x.ndim >= 1
  assert x.shape[0] == sz0 * sz1
  return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
def _combine_leading(sz0, sz1, aval, x):
  """Collapse the two leading axes of `x`, (sz0, sz1, ...) -> (sz0 * sz1, ...).

  Inverse of `_partition_leading`.
  """
  if aval is core.abstract_unit:
    return core.unit
  assert x.ndim >= 2
  assert x.shape[:2] == (sz0, sz1)
  return lax.collapse(x, 0, 2)
def _prepend_dim_to_aval(sz, aval):
  """Return `aval` extended with an extra leading axis of size `sz`.

  Unit avals are returned unchanged; anything other than a ShapedArray is an
  error.
  """
  if aval is core.abstract_unit:
    return aval
  if isinstance(aval, ShapedArray):
    return ShapedArray((sz, *aval.shape), aval.dtype)
  raise TypeError(f'Prepending dim {sz} to aval {aval}')
def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
                        linear, unroll):
  """Abstract-eval rule for scan_p: carries pass through, per-iteration
  outputs gain a leading axis of size `length` (unit avals pass unchanged)."""
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = []
  for aval in y_avals:
    if aval is core.abstract_unit:
      ys_avals.append(aval)
    else:
      ys_avals.append(ShapedArray((length,) + aval.shape, aval.dtype))
  return carry_avals + ys_avals
def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
linear):
linear, unroll):
num_xs = len(jaxpr.in_avals) - num_carry - num_consts
num_ys = len(jaxpr.out_avals) - num_carry
nonzeros = [type(t) is not ad_util.Zero for t in tangents]
@ -1322,8 +1470,9 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
out_flat = scan_p.bind(
*(consts + consts_dot + init + init_dot + xs + xs_dot),
reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
num_consts=num_consts+len(consts_dot), num_carry=num_carry+len(init_dot),
linear=jaxpr_jvp_linear)
num_consts=num_consts + len(consts_dot),
num_carry=num_carry + len(init_dot),
linear=jaxpr_jvp_linear, unroll=unroll)
carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
primals_out = carry + ys
@ -1336,10 +1485,11 @@ def _prune_zeros(ts):
return [t for t in ts if type(t) is not ad_util.Zero]
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear):
jaxpr, linear, unroll):
if trace.master.trace_type is pe.StagingJaxprTrace:
params = {"reverse": reverse, "length": length, "num_consts": num_consts,
"num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
params = dict(reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, jaxpr=jaxpr, linear=linear,
unroll=unroll)
return trace.default_process_primitive(scan_p, tracers, params)
num_ys = len(jaxpr.out_avals) - num_carry
@ -1404,7 +1554,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
in zip(unknowns[num_consts:], linear[num_consts:])])
out_flat = scan_p.bind(
*in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1))
num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1),
unroll=unroll)
out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]
@ -1427,7 +1578,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
out_tracers, scan_p,
dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
num_consts=num_consts_2,
num_carry=num_carry, linear=tuple(linear_2)),
num_carry=num_carry, linear=tuple(linear_2),
unroll=unroll),
source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
@ -1438,7 +1590,8 @@ def _promote_aval_rank(sz, aval):
else:
return ShapedArray((sz,) + aval.shape, aval.dtype)
def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear):
def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
# we've only implemented transposing scans with specific lin/nonlin patterns
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
num_ires = len(consts_lin) - sum(consts_lin)
@ -1474,7 +1627,8 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, l
outs = scan_p.bind(
*(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans))
num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
unroll=unroll)
ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
@ -1510,7 +1664,7 @@ def _make_typed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstract
def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
num_carry, linear):
num_carry, linear, unroll):
num_ys = len(jaxpr.out_avals) - num_carry
size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
orig_batched = [d is not batching.not_mapped for d in dims]
@ -1546,14 +1700,15 @@ def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
else x for x, d in zip(xs, xs_bdims)]
new_args = new_consts + new_init + new_xs
outs = scan_p.bind(*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
num_consts=num_consts, num_carry=num_carry, linear=linear)
outs = scan_p.bind(
*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll)
carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
return outs, carry_bdims + ys_bdims
def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
jaxpr, num_consts, num_carry, linear):
jaxpr, num_consts, num_carry, linear, unroll):
dynamic_length, = masking.shape_as_value((length,))
masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
@ -1563,7 +1718,8 @@ def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
*itertools.chain([dynamic_length] + consts, [0], init, xs),
reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
num_consts=1 + num_consts, num_carry=1 + num_carry,
linear=tuple([False] + const_linear + [False] + init_linear + xs_linear))
linear=tuple([False] + const_linear + [False] + init_linear + xs_linear),
unroll=unroll)
return out_vals[1:]
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
@ -1584,10 +1740,16 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
def _scan_typecheck(*avals, reverse, length, num_consts, num_carry, jaxpr,
linear):
linear, unroll):
core.typecheck_assert(
len(linear) == len(avals),
f'scan called with {len(linear)} linear flags for {len(avals)} operands')
core.typecheck_assert(
isinstance(unroll, int),
f'unroll length {unroll} is not an int')
core.typecheck_assert(
unroll > 0,
f'non-positive unroll length {unroll}')
const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(

View File

@ -60,6 +60,12 @@ COND_IMPLS = [
]
SCAN_IMPLS = [
(lax.scan, 'unroll1'),
(partial(lax.scan, unroll=2), 'unroll2'),
]
def while_loop_reference(cond, body, carry):
while cond(carry):
carry = body(carry)
@ -1278,11 +1284,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(out, (7, 10))
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
{"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
jit_scan, jit_f, scan_name),
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
for jit_scan in [False, True]
for jit_f in [False, True])
def testScanImpl(self, jit_scan, jit_f):
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS)
def testScanImpl(self, jit_scan, jit_f, scan):
rng = np.random.RandomState(0)
d = rng.randn(2)
@ -1297,9 +1305,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
scan = api.jit(scan, (0,))
as_ = rng.randn(5, 3)
c = rng.randn(4)
@ -1309,11 +1315,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
{"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
jit_scan, jit_f, scan_name),
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
for jit_scan in [False, True]
for jit_f in [False, True])
def testScanJVP(self, jit_scan, jit_f):
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS)
def testScanJVP(self, jit_scan, jit_f, scan):
rng = np.random.RandomState(0)
d = rng.randn(2)
@ -1328,9 +1336,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
scan = api.jit(scan, (0,))
as_ = rng.randn(5, 3)
c = rng.randn(4)
@ -1343,11 +1349,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"])
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
{"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
jit_scan, jit_f, scan_name),
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
for jit_scan in [False, True]
for jit_f in [False, True])
def testScanLinearize(self, jit_scan, jit_f):
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS)
def testScanLinearize(self, jit_scan, jit_f, scan):
rng = np.random.RandomState(0)
d = rng.randn(2)
@ -1362,9 +1370,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
scan = api.jit(scan, (0,))
as_ = rng.randn(5, 3)
c = rng.randn(4)
@ -1375,12 +1381,14 @@ class LaxControlFlowTest(jtu.JaxTestCase):
rtol={np.float64: 1e-14})
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
"jit_scan": jit_scan, "jit_f": jit_f}
{"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
jit_scan, jit_f, scan_name),
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
for jit_scan in [False, True]
for jit_f in [False, True])
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testScanGrad(self, jit_scan, jit_f):
def testScanGrad(self, jit_scan, jit_f, scan):
rng = np.random.RandomState(0)
d = rng.randn(2)
@ -1395,9 +1403,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, static_argnums=(0,))
else:
scan = lax.scan
scan = api.jit(scan, static_argnums=(0,))
as_ = rng.randn(5, 3)
c = rng.randn(4)
@ -1543,7 +1549,11 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.scan(lambda c, x: (0, x), (1, 2), jnp.arange(5))
def testScanHigherOrderDifferentiation(self):
@parameterized.named_parameters(
{"testcase_name": "_{}".format(scan_name),
"scan": scan_impl}
for scan_impl, scan_name in SCAN_IMPLS)
def testScanHigherOrderDifferentiation(self, scan):
d = 0.75
def f(c, a):
b = jnp.sin(c * jnp.sum(jnp.cos(d * a)))
@ -1553,18 +1563,20 @@ class LaxControlFlowTest(jtu.JaxTestCase):
as_ = jnp.arange(6.).reshape((3, 2))
c = 1.
jtu.check_grads(lambda c, as_: lax.scan(f, c, as_), (c, as_),
jtu.check_grads(lambda c, as_: scan(f, c, as_), (c, as_),
modes=["rev"], order=2, rtol={np.float32: 6e-3})
@parameterized.named_parameters(
{"testcase_name": "_jit_scan={}_jit_f={}_in_axes={}".format(
jit_scan, jit_f, in_axes),
"jit_scan": jit_scan, "jit_f": jit_f, "in_axes": in_axes}
{"testcase_name": "_jit_scan={}_jit_f={}_in_axes={}_impl={}".format(
jit_scan, jit_f, in_axes, scan_name),
"jit_scan": jit_scan, "jit_f": jit_f, "in_axes": in_axes,
"scan": scan_impl}
for jit_scan in [False, True]
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS
for in_axes in itertools.product([None, 0, 1], [None, 0, 1, 2])
if in_axes != (None, None))
def testScanVmap(self, jit_scan, jit_f, in_axes):
def testScanVmap(self, jit_scan, jit_f, in_axes, scan):
rng = np.random.RandomState(0)
d = rng.randn(2)
@ -1579,9 +1591,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
if jit_f:
f = api.jit(f)
if jit_scan:
scan = api.jit(lax.scan, (0,))
else:
scan = lax.scan
scan = api.jit(scan, (0,))
as_shape = [5, 3]
c_shape = [4]
@ -2298,9 +2308,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
.format(too_big, api.device_count(), jtu.device_under_test())),
lambda: f_loop(jnp.ones(too_big)))
def test_scan_reverse(self):
@parameterized.named_parameters(
{"testcase_name": "_{}".format(scan_name),
"scan": scan_impl}
for scan_impl, scan_name in SCAN_IMPLS)
def test_scan_reverse(self, scan):
def cumsum(x, reverse):
return lax.scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
return scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
x = np.array([3, 1, 4, 1, 5, 9])
self.assertAllClose(np.cumsum(x), cumsum(x, False), check_dtypes=False)
@ -2311,6 +2325,32 @@ class LaxControlFlowTest(jtu.JaxTestCase):
with api.disable_jit():
self.assertAllClose(np.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
def test_scan_unroll(self):
  """Unrolling must not change the jaxpr size but must grow the emitted HLO."""
  d = jnp.ones(2)
  def f(c, a):
    assert a.shape == (3,)
    assert c.shape == (4,)
    b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
    c = jnp.sin(c * b)
    assert b.shape == ()
    return c, b
  xs = jnp.ones((5, 3))
  c = jnp.ones(4)
  def scan(c, xs):
    return lax.scan(f, c, xs)
  def scan_unrolled(c, xs):
    return lax.scan(f, c, xs, unroll=2)
  # `unroll` is carried as a scan parameter, so the traced jaxprs print at
  # the same size...
  jaxpr_sizes = [len(str(api.make_jaxpr(fun)(c, xs)))
                 for fun in (scan, scan_unrolled)]
  self.assertEqual(jaxpr_sizes[0], jaxpr_sizes[1])
  # ...but lowering duplicates the body, so the unrolled HLO text is larger.
  hlo_sizes = [len(str(api.xla_computation(fun)(c, xs).as_hlo_text()))
               for fun in (scan, scan_unrolled)]
  self.assertLess(hlo_sizes[0], hlo_sizes[1])
def test_disable_jit_cond_with_vmap(self):
# https://github.com/google/jax/issues/3093
def fn(t):