Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
block-unrolled scan primitive implementation (#3738)
* block-unrolled scan implementation, via optional `_unroll` scan parameter
* index statically in the inlined path of lax.scan
* make `unroll` a required scan parameter, and test that it unrolls
parent 23c279f033
commit 8a62a9b654
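For orientation, here is a minimal usage sketch of the new `unroll` keyword of `lax.scan`; the cumulative-sum step function and the array sizes are illustrative, not taken from this commit. With any `unroll` value the results match the default `unroll=1`; only the lowered HLO changes:

    import jax.numpy as jnp
    from jax import lax

    def cumsum_step(carry, x):
      total = carry + x
      return total, total   # (new carry, per-step output)

    xs = jnp.arange(8.)
    # Run two scan steps per iteration of the underlying loop.
    final, partials = lax.scan(cumsum_step, 0.0, xs, unroll=2)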
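The implementation strategy, visible in the `_scan_impl` changes in the diff below, is: compute `num_blocks, rem = divmod(length, unroll)`, reshape each `xs` leaf to a leading `(num_blocks, unroll, ...)` shape, run the scan body fully unrolled over each block inside the existing `while_loop`, and handle the `rem` leftover iterations with a fully unrolled scan whose outputs are concatenated onto the blocked outputs. A rough pure-Python sketch of that control flow (illustrative only, forward scan, not the library code):

    def blocked_scan(f, init, xs, unroll):
      length = len(xs)
      num_blocks, rem = divmod(length, unroll)
      carry, ys = init, []
      for b in range(num_blocks):          # remains a loop in the lowering
        for j in range(unroll):            # these steps are unrolled into its body
          carry, y = f(carry, xs[b * unroll + j])
          ys.append(y)
      for i in range(num_blocks * unroll, length):   # remainder: rem extra steps
        carry, y = f(carry, xs[i])
        ys.append(y)
      return carry, ys

    # e.g. length=5, unroll=2 -> two blocks of two steps plus a remainder of one step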
@@ -425,7 +425,8 @@ For the example consider the function ``func11`` below
                       linear=(False, False, False, False)
                       num_carry=1
                       num_consts=1
-                      reverse=False ] b 0.0 a c
+                      reverse=False
+                      unroll=1 ] b 0.0 a c
       in (d, e) }
 
 The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
@@ -607,9 +607,10 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                        for jaxpr in branches),
           linear=(*linear, False)), eqn.source_info))
   elif eqn.primitive is lax.scan_p:
-    num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
+    num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
       eqn.params,
-      ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length"])
+      ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
+       "unroll"])
     # We add the token right at the end of carry
     nr_const_and_carry = num_consts + num_carry
     new_invars = eqn.invars[0:nr_const_and_carry] + [
@@ -493,7 +493,8 @@ class _BoundedLoopBuilder(_LoopBuilder):
                             num_consts=len(body_const_vals),
                             num_carry=len(init_vals),
                             linear=(False,) * (len(body_const_vals) +
-                                               len(init_vals) + 1))
+                                               len(init_vals) + 1),
+                            unroll=1)
 
 
 class _CondBuilder(_LoopBuilder):
@@ -1096,7 +1096,7 @@ core.custom_typechecks[cond_p] = _cond_typecheck
 
 ### scan
 
-def scan(f, init, xs, length=None, reverse=False):
+def scan(f, init, xs, length=None, reverse=False, unroll=1):
   """Scan a function over leading array axes while carrying along state.
 
   The type signature in brief is
@@ -1159,6 +1159,9 @@ def scan(f, init, xs, length=None, reverse=False):
     reverse: optional boolean specifying whether to run the scan iteration
       forward (the default) or in reverse, equivalent to reversing the leading
       axes of the arrays in both ``xs`` and in ``ys``.
+    unroll: optional positive int specifying, in the underlying operation of the
+      scan primitive, how many scan iterations to unroll within a single
+      iteration of a loop.
 
   Returns:
     A pair of type ``(c, [b])`` where the first element represents the final
@@ -1224,13 +1227,32 @@ def scan(f, init, xs, length=None, reverse=False):
   out = scan_p.bind(*itertools.chain(consts, in_flat),
                     reverse=reverse, length=length, jaxpr=jaxpr,
                     num_consts=len(consts), num_carry=len(init_flat),
-                    linear=(False,) * (len(consts) + len(in_flat)))
+                    linear=(False,) * (len(consts) + len(in_flat)),
+                    unroll=unroll)
   return tree_unflatten(out_tree, out)
 
-def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
+def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
+                        f_impl, x_avals, y_avals):
+  consts, init, xs = split_list(args, [num_consts, num_carry])
+
+  carry = init
+  ys = []
+
+  for i in range(length):
+    i_ = length - i - 1 if reverse else i
+    x = _map(partial(_index_array, i_), x_avals, xs)
+    out = f_impl(*consts, *carry, *x)
+    carry, y = split_list(out, [num_carry])
+    ys.append(y)
+
+  ys = list(reversed(ys)) if reverse else ys
+  ys = list(zip(*ys))
+  ys = _map(_stack, y_avals, ys)
+  return (*carry, *ys)
+
+def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear,
+                    f_impl, x_avals, y_avals):
   consts, init, xs = split_list(args, [num_consts, num_carry])
-  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
-  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
 
   def cond_fun(vals):
     i, *_ = vals
@@ -1239,8 +1261,8 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
   def body_fun(vals):
     [i], carry, ys = split_list(vals, [1, num_carry])
     i_ = length - i - 1 if reverse else i
-    x = _map(partial(_index_array, i_), x_avals, xs)
-    out_flat = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x))
+    x = _map(partial(_dynamic_index_array, i_), x_avals, xs)
+    out_flat = f_impl(*consts, *carry, *x)
     carry_out, y_updates = split_list(out_flat, [num_carry])
     ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates)
     return [i + 1] + carry_out + ys_out
@@ -1253,12 +1275,112 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
   _, *outs = while_loop(cond_fun, body_fun, init_val)
   return outs
 
-def _index_array(i, aval, x):
+def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry,
+                              linear, block_length, f_impl, x_avals, y_avals):
+  consts, init, xs = split_list(args, [num_consts, num_carry])
+
+  num_blocks, rem = divmod(length, block_length)
+  assert rem == 0
+
+  partition = partial(_partition_leading, num_blocks, block_length)
+  xs_block = _map(partition, x_avals, xs)
+
+  prepend_aval = partial(_prepend_dim_to_aval, block_length)
+  x_block_avals = _map(prepend_aval, x_avals)
+  y_block_avals = _map(prepend_aval, y_avals)
+
+  f_impl_block = partial(
+      _scan_impl_unrolled, reverse=reverse, length=block_length,
+      num_consts=num_consts, num_carry=num_carry, linear=linear,
+      f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
+
+  outs = _scan_impl_loop(
+      *consts, *init, *xs_block, reverse=reverse, length=num_blocks,
+      num_consts=num_consts, num_carry=num_carry, linear=linear,
+      f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals)
+
+  carry, ys_blocks = split_list(outs, [num_carry])
+  combine = partial(_combine_leading, num_blocks, block_length)
+  ys = _map(combine, y_avals, ys_blocks)
+  return (*carry, *ys)
+
+def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
+               unroll):
+  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
+  _, y_avals = split_list(jaxpr.out_avals, [num_carry])
+  f_impl = core.jaxpr_as_fun(jaxpr)
+
+  if unroll == 1:
+    return _scan_impl_loop(
+        *args, reverse=reverse, length=length, num_consts=num_consts,
+        num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals,
+        y_avals=y_avals)
+
+  consts, init, xs = split_list(args, [num_consts, num_carry])
+  num_blocks, rem = divmod(length, unroll)
+  length_div = num_blocks * unroll
+
+  if rem > 0:
+    if reverse:
+      split = partial(_split_leading_dim, rem)
+      xs_rem, xs = unzip2(_map(split, x_avals, xs))
+    else:
+      split = partial(_split_leading_dim, length_div)
+      xs, xs_rem = unzip2(_map(split, x_avals, xs))
+
+  outs = _scan_impl_block_unrolled(
+      *consts, *init, *xs, reverse=reverse, length=length_div,
+      num_consts=num_consts, num_carry=num_carry, linear=linear,
+      block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
+
+  carry, ys = split_list(outs, [num_carry])
+
+  if rem > 0:
+    outs = _scan_impl_unrolled(
+        *consts, *carry, *xs_rem, reverse=reverse, length=rem,
+        num_consts=num_consts, num_carry=num_carry, linear=linear,
+        f_impl=f_impl, x_avals=x_avals, y_avals=y_avals)
+    carry, ys_rem = split_list(outs, [num_carry])
+    if reverse:
+      ys = _map(_concatenate, y_avals, ys_rem, ys)
+    else:
+      ys = _map(_concatenate, y_avals, ys, ys_rem)
+
+  return (*carry, *ys)
+
+def _stack(aval, vals):
+  if aval is core.abstract_unit:
+    return core.unit
+  else:
+    vals = [lax.expand_dims(x, (0,)) for x in vals]
+    return lax.concatenate(vals, 0)
+
+def _concatenate(aval, x1, x2):
+  if aval is core.abstract_unit:
+    return core.unit
+  else:
+    return lax.concatenate([x1, x2], 0)
+
+def _split_leading_dim(i, aval, x):
+  if aval is core.abstract_unit:
+    return (core.unit, core.unit)
+  else:
+    assert x.ndim >= 1
+    return (lax.slice_in_dim(x, 0, i),
+            lax.slice_in_dim(x, i, x.shape[0]))
+
+def _dynamic_index_array(i, aval, x):
   if aval is core.abstract_unit:
     return core.unit
   else:
     return lax.dynamic_index_in_dim(x, i, keepdims=False)
 
+def _index_array(i, aval, x):
+  if aval is core.abstract_unit:
+    return core.unit
+  else:
+    return lax.index_in_dim(x, i, keepdims=False)
+
 def _empty_array(sz, aval):
   if aval is core.abstract_unit:
     return core.unit
@@ -1271,14 +1393,40 @@ def _update_array(i, aval, xs, x):
   else:
     return lax.dynamic_update_index_in_dim(xs, x, i, 0)
 
-def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear):
+def _partition_leading(sz0, sz1, aval, x):
+  if aval is core.abstract_unit:
+    return core.unit
+  else:
+    assert x.ndim >= 1
+    assert x.shape[0] == sz0 * sz1
+    return lax.reshape(x, (sz0, sz1, *x.shape[1:]))
+
+def _combine_leading(sz0, sz1, aval, x):
+  if aval is core.abstract_unit:
+    return core.unit
+  else:
+    assert x.ndim >= 2
+    assert x.shape[0] == sz0
+    assert x.shape[1] == sz1
+    return lax.collapse(x, 0, 2)
+
+def _prepend_dim_to_aval(sz, aval):
+  if aval is core.abstract_unit:
+    return aval
+  elif isinstance(aval, ShapedArray):
+    return ShapedArray((sz, *aval.shape), aval.dtype)
+  else:
+    raise TypeError(f'Prepending dim {sz} to aval {aval}')
+
+def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr,
+                        linear, unroll):
   carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
   ys_avals = [ShapedArray((length,) + aval.shape, aval.dtype)
               if aval is not core.abstract_unit else aval for aval in y_avals]
   return carry_avals + ys_avals
 
 def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
-              linear):
+              linear, unroll):
   num_xs = len(jaxpr.in_avals) - num_carry - num_consts
   num_ys = len(jaxpr.out_avals) - num_carry
   nonzeros = [type(t) is not ad_util.Zero for t in tangents]
@@ -1322,8 +1470,9 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry,
   out_flat = scan_p.bind(
       *(consts + consts_dot + init + init_dot + xs + xs_dot),
       reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged,
-      num_consts=num_consts+len(consts_dot), num_carry=num_carry+len(init_dot),
-      linear=jaxpr_jvp_linear)
+      num_consts=num_consts + len(consts_dot),
+      num_carry=num_carry + len(init_dot),
+      linear=jaxpr_jvp_linear, unroll=unroll)
 
   carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys])
   primals_out = carry + ys
@@ -1336,10 +1485,11 @@ def _prune_zeros(ts):
   return [t for t in ts if type(t) is not ad_util.Zero]
 
 def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
-                       jaxpr, linear):
+                       jaxpr, linear, unroll):
   if trace.master.trace_type is pe.StagingJaxprTrace:
-    params = {"reverse": reverse, "length": length, "num_consts": num_consts,
-              "num_carry": num_carry, "jaxpr": jaxpr, "linear": linear}
+    params = dict(reverse=reverse, length=length, num_consts=num_consts,
+                  num_carry=num_carry, jaxpr=jaxpr, linear=linear,
+                  unroll=unroll)
     return trace.default_process_primitive(scan_p, tracers, params)
 
   num_ys = len(jaxpr.out_avals) - num_carry
@@ -1404,7 +1554,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
                      in zip(unknowns[num_consts:], linear[num_consts:])])
   out_flat = scan_p.bind(
       *in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
-      num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1))
+      num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1),
+      unroll=unroll)
   out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
   extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]
 
@@ -1427,7 +1578,8 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
       out_tracers, scan_p,
       dict(reverse=reverse, length=length, jaxpr=jaxpr_2_opt,
            num_consts=num_consts_2,
-           num_carry=num_carry, linear=tuple(linear_2)),
+           num_carry=num_carry, linear=tuple(linear_2),
+           unroll=unroll),
       source_info_util.current())
   for t in out_tracers: t.recipe = eqn
   return out_tracers
@@ -1438,7 +1590,8 @@ def _promote_aval_rank(sz, aval):
   else:
     return ShapedArray((sz,) + aval.shape, aval.dtype)
 
-def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear):
+def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
+                    linear, unroll):
   # we've only implemented transposing scans with specific lin/nonlin patterns
   consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
   num_ires = len(consts_lin) - sum(consts_lin)
@@ -1474,7 +1627,8 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, l
   outs = scan_p.bind(
       *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse,
       length=length, jaxpr=jaxpr_trans, num_consts=num_ires,
-      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans))
+      num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans),
+      unroll=unroll)
   ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry])
   return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
 
@@ -1510,7 +1664,7 @@ def _make_typed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.Abstract
 
 
 def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
-                        num_carry, linear):
+                        num_carry, linear, unroll):
   num_ys = len(jaxpr.out_avals) - num_carry
   size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
   orig_batched = [d is not batching.not_mapped for d in dims]
@@ -1546,14 +1700,15 @@ def _scan_batching_rule(args, dims, reverse, length, jaxpr, num_consts,
             else x for x, d in zip(xs, xs_bdims)]
   new_args = new_consts + new_init + new_xs
 
-  outs = scan_p.bind(*new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
-                     num_consts=num_consts, num_carry=num_carry, linear=linear)
+  outs = scan_p.bind(
+      *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched,
+      num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll)
   carry_bdims = [0 if b else batching.not_mapped for b in carry_batched]
   ys_bdims = [1 if b else batching.not_mapped for b in ys_batched]
   return outs, carry_bdims + ys_bdims
 
 def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
-                       jaxpr, num_consts, num_carry, linear):
+                       jaxpr, num_consts, num_carry, linear, unroll):
   dynamic_length, = masking.shape_as_value((length,))
   masked_jaxpr = _masked_scan_jaxpr(jaxpr, num_consts, num_carry)
   consts, init, xs = split_list(padded_vals, [num_consts, num_carry])
@@ -1563,7 +1718,8 @@ def _scan_masking_rule(padded_vals, logical_shapes, reverse, length,
       *itertools.chain([dynamic_length] + consts, [0], init, xs),
       reverse=reverse, length=max_length, jaxpr=masked_jaxpr,
       num_consts=1 + num_consts, num_carry=1 + num_carry,
-      linear=tuple([False] + const_linear + [False] + init_linear + xs_linear))
+      linear=tuple([False] + const_linear + [False] + init_linear + xs_linear),
+      unroll=unroll)
   return out_vals[1:]
 
 def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
@@ -1584,10 +1740,16 @@ def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
   return _make_typed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
 
 def _scan_typecheck(*avals, reverse, length, num_consts, num_carry, jaxpr,
-                    linear):
+                    linear, unroll):
   core.typecheck_assert(
       len(linear) == len(avals),
       f'scan called with {len(linear)} linear flags for {len(avals)} operands')
+  core.typecheck_assert(
+      isinstance(unroll, int),
+      f'unroll length {unroll} is not an int')
+  core.typecheck_assert(
+      unroll > 0,
+      f'non-positive unroll length {unroll}')
 
   const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
   const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list(
@@ -60,6 +60,12 @@ COND_IMPLS = [
 ]
 
 
+SCAN_IMPLS = [
+    (lax.scan, 'unroll1'),
+    (partial(lax.scan, unroll=2), 'unroll2'),
+]
+
+
 def while_loop_reference(cond, body, carry):
   while cond(carry):
     carry = body(carry)
@@ -1278,11 +1284,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     self.assertEqual(out, (7, 10))
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
-       "jit_scan": jit_scan, "jit_f": jit_f}
+      {"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
+          jit_scan, jit_f, scan_name),
+       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
       for jit_scan in [False, True]
-      for jit_f in [False, True])
-  def testScanImpl(self, jit_scan, jit_f):
+      for jit_f in [False, True]
+      for scan_impl, scan_name in SCAN_IMPLS)
+  def testScanImpl(self, jit_scan, jit_f, scan):
     rng = np.random.RandomState(0)
 
     d = rng.randn(2)
@@ -1297,9 +1305,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     if jit_f:
       f = api.jit(f)
     if jit_scan:
-      scan = api.jit(lax.scan, (0,))
-    else:
-      scan = lax.scan
+      scan = api.jit(scan, (0,))
 
     as_ = rng.randn(5, 3)
     c = rng.randn(4)
@@ -1309,11 +1315,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     self.assertAllClose(ans, expected, check_dtypes=False)
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
-       "jit_scan": jit_scan, "jit_f": jit_f}
+      {"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
+          jit_scan, jit_f, scan_name),
+       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
      for jit_scan in [False, True]
-      for jit_f in [False, True])
-  def testScanJVP(self, jit_scan, jit_f):
+      for jit_f in [False, True]
+      for scan_impl, scan_name in SCAN_IMPLS)
+  def testScanJVP(self, jit_scan, jit_f, scan):
     rng = np.random.RandomState(0)
 
     d = rng.randn(2)
@@ -1328,9 +1336,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     if jit_f:
       f = api.jit(f)
     if jit_scan:
-      scan = api.jit(lax.scan, (0,))
-    else:
-      scan = lax.scan
+      scan = api.jit(scan, (0,))
 
     as_ = rng.randn(5, 3)
     c = rng.randn(4)
@@ -1343,11 +1349,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"])
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
-       "jit_scan": jit_scan, "jit_f": jit_f}
+      {"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
+          jit_scan, jit_f, scan_name),
+       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
      for jit_scan in [False, True]
-      for jit_f in [False, True])
-  def testScanLinearize(self, jit_scan, jit_f):
+      for jit_f in [False, True]
+      for scan_impl, scan_name in SCAN_IMPLS)
+  def testScanLinearize(self, jit_scan, jit_f, scan):
     rng = np.random.RandomState(0)
 
     d = rng.randn(2)
@@ -1362,9 +1370,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     if jit_f:
       f = api.jit(f)
     if jit_scan:
-      scan = api.jit(lax.scan, (0,))
-    else:
-      scan = lax.scan
+      scan = api.jit(scan, (0,))
 
     as_ = rng.randn(5, 3)
     c = rng.randn(4)
@@ -1375,12 +1381,14 @@ class LaxControlFlowTest(jtu.JaxTestCase):
                     rtol={np.float64: 1e-14})
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_scan={}_jit_f={}".format(jit_scan, jit_f),
-       "jit_scan": jit_scan, "jit_f": jit_f}
+      {"testcase_name": "_jit_scan={}_jit_f={}_impl={}".format(
+          jit_scan, jit_f, scan_name),
+       "jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
       for jit_scan in [False, True]
-      for jit_f in [False, True])
+      for jit_f in [False, True]
+      for scan_impl, scan_name in SCAN_IMPLS)
   @jtu.skip_on_flag("jax_skip_slow_tests", True)
-  def testScanGrad(self, jit_scan, jit_f):
+  def testScanGrad(self, jit_scan, jit_f, scan):
     rng = np.random.RandomState(0)
 
     d = rng.randn(2)
@@ -1395,9 +1403,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     if jit_f:
      f = api.jit(f)
    if jit_scan:
-      scan = api.jit(lax.scan, static_argnums=(0,))
-    else:
-      scan = lax.scan
+      scan = api.jit(scan, static_argnums=(0,))
 
     as_ = rng.randn(5, 3)
     c = rng.randn(4)
@@ -1543,7 +1549,11 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       lax.scan(lambda c, x: (0, x), (1, 2), jnp.arange(5))
 
 
-  def testScanHigherOrderDifferentiation(self):
+  @parameterized.named_parameters(
+      {"testcase_name": "_{}".format(scan_name),
+       "scan": scan_impl}
+      for scan_impl, scan_name in SCAN_IMPLS)
+  def testScanHigherOrderDifferentiation(self, scan):
     d = 0.75
     def f(c, a):
       b = jnp.sin(c * jnp.sum(jnp.cos(d * a)))
@@ -1553,18 +1563,20 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     as_ = jnp.arange(6.).reshape((3, 2))
     c = 1.
 
-    jtu.check_grads(lambda c, as_: lax.scan(f, c, as_), (c, as_),
+    jtu.check_grads(lambda c, as_: scan(f, c, as_), (c, as_),
                     modes=["rev"], order=2, rtol={np.float32: 6e-3})
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_scan={}_jit_f={}_in_axes={}".format(
-          jit_scan, jit_f, in_axes),
-       "jit_scan": jit_scan, "jit_f": jit_f, "in_axes": in_axes}
+      {"testcase_name": "_jit_scan={}_jit_f={}_in_axes={}_impl={}".format(
+          jit_scan, jit_f, in_axes, scan_name),
+       "jit_scan": jit_scan, "jit_f": jit_f, "in_axes": in_axes,
+       "scan": scan_impl}
       for jit_scan in [False, True]
       for jit_f in [False, True]
+      for scan_impl, scan_name in SCAN_IMPLS
       for in_axes in itertools.product([None, 0, 1], [None, 0, 1, 2])
       if in_axes != (None, None))
-  def testScanVmap(self, jit_scan, jit_f, in_axes):
+  def testScanVmap(self, jit_scan, jit_f, in_axes, scan):
     rng = np.random.RandomState(0)
 
     d = rng.randn(2)
@@ -1579,9 +1591,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     if jit_f:
       f = api.jit(f)
     if jit_scan:
-      scan = api.jit(lax.scan, (0,))
-    else:
-      scan = lax.scan
+      scan = api.jit(scan, (0,))
 
     as_shape = [5, 3]
     c_shape = [4]
@@ -2298,9 +2308,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
                .format(too_big, api.device_count(), jtu.device_under_test())),
         lambda: f_loop(jnp.ones(too_big)))
 
-  def test_scan_reverse(self):
+  @parameterized.named_parameters(
+      {"testcase_name": "_{}".format(scan_name),
+       "scan": scan_impl}
+      for scan_impl, scan_name in SCAN_IMPLS)
+  def test_scan_reverse(self, scan):
     def cumsum(x, reverse):
-      return lax.scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
+      return scan(lambda c, x: (c + x, c + x), 0, x, reverse=reverse)[1]
 
     x = np.array([3, 1, 4, 1, 5, 9])
     self.assertAllClose(np.cumsum(x), cumsum(x, False), check_dtypes=False)
@@ -2311,6 +2325,32 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     with api.disable_jit():
       self.assertAllClose(np.cumsum(x[::-1])[::-1], cumsum(x, True), check_dtypes=False)
 
+  def test_scan_unroll(self):
+    d = jnp.ones(2)
+    def f(c, a):
+      assert a.shape == (3,)
+      assert c.shape == (4,)
+      b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
+      c = jnp.sin(c * b)
+      assert b.shape == ()
+      return c, b
+
+    xs = jnp.ones((5, 3))
+    c = jnp.ones(4)
+
+    scan = lambda c, xs: lax.scan(f, c, xs)
+    scan_unrolled = lambda c, xs: lax.scan(f, c, xs, unroll=2)
+
+    # jaxprs should be the same size
+    self.assertEqual(
+        len(str(api.make_jaxpr(scan)(c, xs))),
+        len(str(api.make_jaxpr(scan_unrolled)(c, xs))))
+
+    # but HLO should grow due to unrolling
+    self.assertLess(
+        len(str(api.xla_computation(scan)(c, xs).as_hlo_text())),
+        len(str(api.xla_computation(scan_unrolled)(c, xs).as_hlo_text())))
+
   def test_disable_jit_cond_with_vmap(self):
     # https://github.com/google/jax/issues/3093
     def fn(t):