Optimize lax.associative_scan, reimplement cumsum, etc. on top of associative_scan.

Add support for an axis= parameter to associative_scan.

We previously had two associative scan implementations: lax.associative_scan, and the parallel prefix scan used to implement cumsum, cumprod, etc.

lax.associative_scan was more efficient in some ways because, unlike the cumsum implementation, it did not pad the input array to the next power-of-two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135.

The cumsum/cummax implementation used slightly more efficient code to slice and
interleave arrays, which this change adds to associative_scan as well. Since we
are now using lax primitives that make it easy to select an axis, add support
for user-chosen scan axes as well.

We can also simplify the implementation of associative_scan: one of the
recursive base cases seems unnecessary, and we can simplify the code by removing
it.

Benchmarks from #4135 on my workstation:
Before:
bench_cumsum: 0.900s
bench_associative_scan: 0.597s
bench_scan: 0.359s
bench_np: 1.619s

After:
bench_cumsum: 0.435s
bench_associative_scan: 0.435s
bench_scan: 0.362s
bench_np: 1.669s

Before, with taskset -c 0:
bench_cumsum: 1.989s
bench_associative_scan: 1.556s
bench_scan: 0.428s
bench_np: 1.670s

After, with taskset -c 0:
bench_cumsum: 1.271s
bench_associative_scan: 1.275s
bench_scan: 0.438s
bench_np: 1.673s
This commit is contained in:
Peter Hawkins 2020-10-15 20:26:29 -04:00
parent 7f4e115a6a
commit d3db7bd4be
6 changed files with 198 additions and 240 deletions

View File

@ -57,6 +57,10 @@ Operators
conv_transpose
cos
cosh
cummax
cummin
cumprod
cumsum
digamma
div
dot

View File

@ -29,6 +29,7 @@ import jax.linear_util as lu
from jax.interpreters import xla
from jax.custom_derivatives import custom_jvp_call_jaxpr_p
from jax.lax import lax
from jax.lax import lax_control_flow
from jax.lax import lax_fft
def jet(fun, primals, series):
@ -238,19 +239,20 @@ deflinear(lax_fft.fft_p)
deflinear(xla.device_put_p)
def _cumulative_jet_rule(primals_in, series_in, *, axis: int,
prefix_scan: Callable):
combine_fn: Callable):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return jet(partial(prefix_scan, axis=axis), primals_in, series_in)
return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis),
primals_in, series_in)
deflinear(lax.cumsum_p)
jet_rules[lax.cumprod_p] = partial(_cumulative_jet_rule,
prefix_scan=lax._cumprod_prefix_scan)
jet_rules[lax.cummax_p] = partial(_cumulative_jet_rule,
prefix_scan=lax._cummax_prefix_scan)
jet_rules[lax.cummin_p] = partial(_cumulative_jet_rule,
prefix_scan=lax._cummin_prefix_scan)
deflinear(lax_control_flow.cumsum_p)
jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule,
combine_fn=lax.mul)
jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule,
combine_fn=lax.max)
jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
combine_fn=lax.min)
def def_deriv(prim, deriv):

View File

@ -94,14 +94,6 @@ from .lax import (
cosh_p,
create_token,
create_token_p,
cummax,
cummax_p,
cummin,
cummin_p,
cumprod,
cumprod_p,
cumsum,
cumsum_p,
digamma,
digamma_p,
div,
@ -299,8 +291,17 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
_upcast_fp16_for_computation, _broadcasting_shape_rule,
_eye, _tri, _delta, _ones, _zeros)
from .lax_control_flow import (
associative_scan,
cond,
cond_p,
cummax,
cummax_p,
cummin,
cummin_p,
cumprod,
cumprod_p,
cumsum,
cumsum_p,
custom_linear_solve,
custom_root,
fori_loop,
@ -312,7 +313,6 @@ from .lax_control_flow import (
switch,
while_loop,
while_p,
associative_scan,
)
from .lax_fft import (
fft,

View File

@ -1330,22 +1330,6 @@ def _select_and_gather_add(tangents: Array, operand: Array,
base_dilation=tuple(base_dilation),
window_dilation=tuple(window_dilation))
def cumsum(operand: Array, axis: int) -> Array:
  """Return the running sum of ``operand`` along dimension ``axis``.

  ``axis`` is coerced with ``int()`` and passed to ``cumsum_p`` as its
  ``axis`` keyword parameter.
  """
  axis = int(axis)
  return cumsum_p.bind(operand, axis=axis)
def cumprod(operand: Array, axis: int) -> Array:
  """Return the running product of ``operand`` along dimension ``axis``.

  ``axis`` is coerced with ``int()`` and passed to ``cumprod_p`` as its
  ``axis`` keyword parameter.
  """
  axis = int(axis)
  return cumprod_p.bind(operand, axis=axis)
def cummax(operand: Array, axis: int) -> Array:
  """Return the running maximum of ``operand`` along dimension ``axis``.

  ``axis`` is coerced with ``int()`` and passed to ``cummax_p`` as its
  ``axis`` keyword parameter.
  """
  axis = int(axis)
  return cummax_p.bind(operand, axis=axis)
def cummin(operand: Array, axis: int) -> Array:
  """Return the running minimum of ``operand`` along dimension ``axis``.

  ``axis`` is coerced with ``int()`` and passed to ``cummin_p`` as its
  ``axis`` keyword parameter.
  """
  axis = int(axis)
  return cummin_p.bind(operand, axis=axis)
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
"""Wraps XLA's `Sort
@ -5391,145 +5375,6 @@ xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
_select_and_gather_add_translation,
max_bits=32)
# Parallel prefix-scan. See:
# https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
# and
# Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", Technical Report
# CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.
#
# Unlike the Blelloch algorithm, we use an out-of-place algorithm that uses 2n
# space. This is somewhat wasteful if we are interested only in the output of
# the forward pass, but more memory-efficient if we intend to differentiate
# through the implementation of the scan.
def _prescan_power_of_two(x, axis: int, op: Callable, unit):
  """Scan helper for inputs whose size along ``axis`` is a power of two.

  Runs an "upsweep" that repeatedly combines adjacent pairs with ``op`` until
  one element remains along ``axis``, then a "downsweep" that rebuilds a
  full-length array, starting from an array filled with ``unit``.

  Args:
    x: array to scan; ``x.shape[axis]`` must be a non-zero power of two.
    axis: dimension to scan along.
    op: binary combining function applied to equally-shaped arrays.
    unit: fill value used to seed the downsweep (via ``full_like``).

  Returns:
    A pair ``(x, total)``: the array produced by the downsweep and the result
    of the upsweep (size 1 along ``axis``).
    NOTE(review): the caller drops element 0 of the first result and appends
    ``total``, which suggests the first result is the exclusive scan -- confirm.
  """
  n = x.shape[axis]
  assert n != 0 and n & (n - 1) == 0, "n must be a power of 2"

  # Upsweep
  xs = []
  # n.bit_length() - 1 == log2(n) rounds; each round halves the size along
  # `axis`. The loop index itself is not used.
  for d in range(0, n.bit_length() - 1):
    x1 = slice_in_dim(x, 0, None, stride=2, axis=axis)  # even positions
    xs.append(x1)  # kept for the matching downsweep level
    x2 = slice_in_dim(x, 1, None, stride=2, axis=axis)  # odd positions
    x = op(x1, x2)
  total = x

  # Downsweep
  x = full_like(total, unit)
  # Padding configs are (low, high, interior) per dimension. Both insert one
  # zero between neighbours along `axis`; pad_left additionally prepends one
  # zero and pad_right appends one, so the two layouts are offset by one.
  pad_left = [(0, 0, 0)] * len(x.shape)
  pad_left[axis] = (1, 0, 1)
  pad_right = [(0, 0, 0)] * len(x.shape)
  pad_right[axis] = (0, 1, 1)
  # Walk the saved levels from coarsest to finest, doubling the size each time.
  for w in reversed(xs):
    x1 = pad(x, _const(x, 0), pad_right)  # previous level at even positions
    x2 = pad(x, _const(x, 0), pad_left)   # previous level at odd positions
    w = pad(w, _const(x, 0), pad_left)    # saved even elements, odd positions
    # The addition interleaves the two halves.
    # NOTE(review): this relies on op(0, 0) == 0 at the zero-padded positions.
    x = x1 + op(x2, w)
  return x, total
def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any):
  """Scan ``x`` along ``axis`` with ``op`` by padding to a power-of-two size.

  Args:
    x: array to scan.
    axis: dimension to scan along.
    op: binary combining function.
    unit: value used to pad ``x``. For integer dtypes, ``+inf``/``-inf`` are
      replaced by the dtype's maximum/minimum.

  Returns:
    ``x`` itself when its size along ``axis`` is zero; otherwise elements
    ``1..n-1`` of the power-of-two prescan concatenated with its total along
    ``axis``.
  """
  is_integer = np.issubdtype(x.dtype, np.integer)
  if is_integer and np.isposinf(unit):
    unit = np.iinfo(x.dtype).max
  elif is_integer and np.isneginf(unit):
    unit = np.iinfo(x.dtype).min

  length = x.shape[axis]
  if length == 0:
    return x

  # Smallest power of two that is >= length.
  padded_length = 1 << (length - 1).bit_length()
  pad_config = [(0, 0, 0) for _ in x.shape]
  pad_config[axis] = (0, padded_length - length, 0)
  padded = pad(x, _const(x, unit), pad_config)

  prescan, total = _prescan_power_of_two(padded, axis, op, unit)
  head = slice_in_dim(prescan, 1, length, axis=axis)
  return concatenate((head, total), dimension=axis)
# Specializations of _parallel_prefix_scan. Each one fixes the combining `op`
# and the `unit` value used for padding; for integer dtypes the +/-inf units
# are replaced by the dtype's max/min inside _parallel_prefix_scan.
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=-np.inf)
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=np.inf)
def _cumred_shape_rule(x, *, axis: int):
if axis < 0 or axis >= x.ndim:
raise ValueError(
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
return x.shape
def _cumsum_transpose_rule(t, *, axis: int):
  """Transpose rule for cumsum: a cumsum run in reverse order along ``axis``.

  Reverses ``t`` along ``axis``, applies ``cumsum``, reverses the result back,
  and returns it wrapped in a one-element list.
  """
  dims = (axis,)
  flipped = rev(t, dims)
  summed = cumsum(flipped, axis=axis)
  return [rev(summed, dims)]
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
                         prefix_scan: Callable):
  """JVP rule that differentiates through the prefix-scan implementation.

  Irrespective of backend, the parallel prefix scan implementation is always
  used when differentiating because reduce_window is not arbitrarily
  differentiable.
  """
  scan_along_axis = partial(prefix_scan, axis=axis)
  return api.jvp(scan_along_axis, primals, tangents)
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
axis: int):
# On TPU, an implementation using reduce_window is handled specially by the
# compiler and is efficient. On other backends, it is O(n^2).
n = x.shape[axis]
if n == 0:
return x
padding = [(0, 0)] * x.ndim
padding[axis] = (n - 1, 0)
strides = [1] * x.ndim
window_dims = [1] * x.ndim
window_dims[axis] = n
return window_reduce(x, window_dims, strides, padding)
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
operand, = batched_args
bdim, = batch_dims
axis = axis if axis < bdim else axis + 1
return prim.bind(operand, axis=axis), bdim
# cumsum primitive. The fourth argument is the prefix-scan implementation
# wrapped by xla.lower_fun (presumably the default translation rule -- confirm
# against standard_primitive).
cumsum_p = standard_primitive(
    _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumsum"),
    'cumsum', xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
# cumsum is registered as linear, with the reversed cumsum as its transpose.
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
# TPU-specific translation built on the reduce_window formulation.
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
    partial(_cumred_tpu_translation_rule, _reduce_window_sum),
    multiple_results=False)
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
def _cumulative_reduction_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
  """Create and register a cumulative reduction primitive called ``name``.

  Args:
    name: primitive name, also passed to the dtype rule.
    prefix_scan_fn: implementation wrapped with ``xla.lower_fun`` and handed to
      ``standard_primitive``.
    jvp_rule: stored in ``ad.primitive_jvps`` for the new primitive.
    reduce_window_fn: windowed reduction used by the TPU-specific translation.

  Returns:
    The newly created primitive.
  """
  dtype_rule = partial(_reduce_number_dtype_rule, name)
  scan_translation = xla.lower_fun(prefix_scan_fn, multiple_results=False)
  reducer_p = standard_primitive(
      _cumred_shape_rule, dtype_rule, name, scan_translation)

  ad.primitive_jvps[reducer_p] = jvp_rule

  tpu_translation = xla.lower_fun(
      partial(_cumred_tpu_translation_rule, reduce_window_fn),
      multiple_results=False)
  xla.backend_specific_translations['tpu'][reducer_p] = tpu_translation

  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
  return reducer_p
# Non-linear cumulative reductions. Each gets its prefix-scan implementation,
# a JVP rule that differentiates through that same prefix scan, and the
# reduce_window function used by the TPU-specific translation.
cumprod_p = _cumulative_reduction_primitive("cumprod", _cumprod_prefix_scan,
                                            partial(_cumulative_jvp_rule,
                                                    prefix_scan=_cumprod_prefix_scan),
                                            _reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", _cummax_prefix_scan,
                                           partial(_cumulative_jvp_rule,
                                                   prefix_scan=_cummax_prefix_scan),
                                           _reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", _cummin_prefix_scan,
                                           partial(_cumulative_jvp_rule,
                                                   prefix_scan=_cummin_prefix_scan),
                                           _reduce_window_min)
def _sort_abstract_eval(*args, **kwargs):
args = tuple(raise_to_shaped(arg) for arg in args)
if any(arg.shape != args[0].shape for arg in args[1:]):

View File

@ -23,11 +23,12 @@ import inspect
import itertools
import operator
import os
from typing import Callable, Sequence, TypeVar
from typing import Any, Callable, Sequence, TypeVar
import numpy as np
import jax
from jax import api
from jax import core
from jax import dtypes
from jax import source_info_util
@ -58,7 +59,7 @@ zip = safe_zip
_reduce = functools.reduce
T = TypeVar('T')
Array = Any
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
@ -2348,25 +2349,21 @@ ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
def _interleave(a, b):
def _interleave(a, b, axis):
"""Given two Tensors of static shape, interleave them along the first axis."""
# TODO(mattjj)
import jax.numpy as jnp
# [a b c ...] [d e f ...] -> [a d b e c f ...]
half_num_elems = b.shape[0]
assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
a_pad = [(0, 0, 0)] * a.ndim
b_pad = [(0, 0, 0)] * b.ndim
a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
lax.pad(b, lax._const(b, 0), b_pad))
if a.shape[0] > b.shape[0]:
return jnp.concatenate(
[jnp.reshape(jnp.stack([a[: -1], b], axis=1),
(2 * half_num_elems,) + a.shape[1:]),
a[-1:]], axis=0)
else:
return jnp.reshape(jnp.stack([a, b], axis=1),
(2 * half_num_elems,) + a.shape[1:])
def associative_scan(fn: Callable, elems, reverse: bool = False):
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
"""Performs a scan with an associative binary operation, in parallel.
For an introduction to associative scans, see [BLE1990]_.
Args:
fn: A Python callable implementing an associative binary operation with
signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
@ -2374,24 +2371,25 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
``fn(a, fn(b, c)) == fn(fn(a, b), c)``.
The inputs and result are (possibly nested Python tree structures of)
array(s) matching ``elems``. Each array has a leading dimension in place
of the ``num_elems`` dimension. `fn` should be applied elementwise over
the leading dimension (for example, by using :func:`jax.vmap` over the
array(s) matching ``elems``. Each array has a dimension in place
of the ``axis`` dimension. `fn` should be applied elementwise over
the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
elementwise function.)
The result `r` has the same shape (and structure) as the two inputs ``a``
and ``b``.
elems: A (possibly nested structure of) array(s), each with leading
dimension ``num_elems``.
The result ``r`` has the same shape (and structure) as the two inputs
``a`` and ``b``.
elems: A (possibly nested Python tree structure of) array(s), each with
an ``axis`` dimension of size ``num_elems``.
reverse: A boolean stating if the scan should be reversed with respect to
the leading dimension.
the ``axis`` dimension.
axis: an integer identifying the axis over which the scan should occur.
Returns:
A (possibly nested Python tree structure of) array(s) of the same shape
and structure as ``elems``, in which the ``k``'th element is the result of
recursively applying ``fn`` to combine the first ``k`` elements of
``elems``. For example, given ``elems = [a, b, c, ...]``, the result
would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
result of recursively applying ``fn`` to combine the first ``k`` elements
of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
Example 1: partial sums of an array of numbers:
@ -2409,13 +2407,17 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
[ 6, 6, 5, 3]
.. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
University.
"""
elems_flat, tree = tree_flatten(elems)
if reverse:
elems_flat = [lax.rev(elem, [0]) for elem in elems_flat]
elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]
def lowered_fn(a_flat, b_flat):
def combine(a_flat, b_flat):
# Lower `fn` to operate on flattened sequences of elems.
a = tree_unflatten(tree, a_flat)
b = tree_unflatten(tree, b_flat)
@ -2424,14 +2426,13 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
return c_flat
# Check that all inputs have a consistent size `num_elems` along `axis`.
num_elems = int(elems_flat[0].shape[0])
axis = lax._canonicalize_axis(axis, elems_flat[0].ndim)
num_elems = int(elems_flat[0].shape[axis])
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
raise ValueError('Array inputs to associative_scan must have the same '
'first dimension. (saw: {})'
.format([elems.shape for elem in elems_flat]))
if not all(int(elem.shape[0]) == num_elems for elem in elems_flat[1:]):
raise ValueError('Input `Tensor`s must have the same first dimension.'
' (saw: {})'.format([elems.shape for elem in elems_flat]))
if num_elems < 2:
return elems
# Summary of algorithm:
#
@ -2451,49 +2452,146 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
def _scan(elems):
"""Perform scan on `elems`."""
num_elems = elems[0].shape[0]
num_elems = elems[0].shape[axis]
reduced_elems = lowered_fn([elem[0:-1:2] for elem in elems],
[elem[1::2] for elem in elems])
if num_elems < 2:
return elems
if reduced_elems[0].shape[0] == 1:
# Base case has either 2 or 3 elements.
if num_elems == 2:
return [lax.concatenate([elem[0:1], reduced_elem], dimension=0)
for (reduced_elem, elem) in zip(reduced_elems, elems)]
elif num_elems == 3:
reduced_reduced_elems = lowered_fn(
reduced_elems,
[elem[2:3] for elem in elems])
return [
lax.concatenate([elem[0:1], reduced_elem, reduced_reduced_elem],
dimension=0)
for (reduced_reduced_elem, reduced_elem, elem)
in zip(reduced_reduced_elems, reduced_elems, elems)]
# Combine adjacent pairs of elements.
reduced_elems = combine(
[lax.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems],
[lax.slice_in_dim(elem, 1, None, stride=2, axis=axis) for elem in elems])
# Recursively compute scan for partially reduced tensors.
odd_elems = _scan(reduced_elems)
if num_elems % 2 == 0:
results = lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
[elem[2::2] for elem in elems])
even_elems = combine(
[lax.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
[lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
else:
results = lowered_fn(list(odd_elems), [elem[2::2] for elem in elems])
even_elems = combine(
odd_elems,
[lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
# The first element of a scan is the same as the first element
# of the original `elems`.
even_elems = [lax.concatenate([elem[0:1], result], dimension=0)
for (elem, result) in zip(elems, results)]
return tuple(_map(_interleave, even_elems, odd_elems))
even_elems = [
lax.concatenate([lax.slice_in_dim(elem, 0, 1, axis=axis), result],
dimension=axis)
for (elem, result) in zip(elems, even_elems)]
return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))
scans = _scan(elems_flat)
if reverse:
scans = [lax.rev(scanned, [0]) for scanned in scans]
scans = [lax.rev(scanned, [axis]) for scanned in scans]
return tree_unflatten(tree, scans)
# Cumulative reductions.
def cumsum(operand: Array, axis: int) -> Array:
  """Cumulative sum of ``operand`` along ``axis``.

  The axis is converted with ``int()`` before it is bound as the ``axis``
  parameter of ``cumsum_p``.
  """
  static_axis = int(axis)
  return cumsum_p.bind(operand, axis=static_axis)
def cumprod(operand: Array, axis: int) -> Array:
  """Cumulative product of ``operand`` along ``axis``.

  The axis is converted with ``int()`` before it is bound as the ``axis``
  parameter of ``cumprod_p``.
  """
  static_axis = int(axis)
  return cumprod_p.bind(operand, axis=static_axis)
def cummax(operand: Array, axis: int) -> Array:
  """Cumulative maximum of ``operand`` along ``axis``.

  The axis is converted with ``int()`` before it is bound as the ``axis``
  parameter of ``cummax_p``.
  """
  static_axis = int(axis)
  return cummax_p.bind(operand, axis=static_axis)
def cummin(operand: Array, axis: int) -> Array:
  """Cumulative minimum of ``operand`` along ``axis``.

  The axis is converted with ``int()`` before it is bound as the ``axis``
  parameter of ``cummin_p``.
  """
  static_axis = int(axis)
  return cummin_p.bind(operand, axis=static_axis)
def _cumred_shape_rule(x, *, axis: int):
if axis < 0 or axis >= x.ndim:
raise ValueError(
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
return x.shape
def _cumsum_transpose_rule(t, *, axis: int):
  """Transpose rule for cumsum: a cumsum run in reverse order along ``axis``.

  Reverses ``t`` along ``axis``, applies ``cumsum``, reverses the result back,
  and returns it wrapped in a one-element list.
  """
  dims = (axis,)
  flipped = lax.rev(t, dims)
  summed = cumsum(flipped, axis=axis)
  return [lax.rev(summed, dims)]
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
axis: int):
# On TPU, an implementation using reduce_window is handled specially by the
# compiler and is efficient. On other backends, it is O(n^2).
n = x.shape[axis]
if n == 0:
return x
padding = [(0, 0)] * x.ndim
padding[axis] = (n - 1, 0)
strides = [1] * x.ndim
window_dims = [1] * x.ndim
window_dims[axis] = n
return window_reduce(x, window_dims, strides, padding)
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
operand, = batched_args
bdim, = batch_dims
axis = axis if axis < bdim else axis + 1
return prim.bind(operand, axis=axis), bdim
def _cumred_dtype_rule(name, operand, *args, **kw):
if not dtypes.issubdtype(operand.dtype, np.number):
raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
"of number.".format(name, np.dtype(operand.dtype).name))
return dtypes.canonicalize_dtype(operand.dtype)
# cumsum primitive: created with shape and dtype rules only; its default
# translation is assigned separately via xla.translations further down.
cumsum_p = lax.standard_primitive(
    _cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"),
    'cumsum')
# cumsum is registered as linear, with the reversed cumsum as its transpose.
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
# TPU-specific translation built on the reduce_window formulation.
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
    partial(_cumred_tpu_translation_rule, lax._reduce_window_sum),
    multiple_results=False)
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
def _cumulative_reduction_primitive(name, reduce_window_fn):
  """Create a cumulative reduction primitive and register its TPU/batch rules.

  Args:
    name: primitive name, also passed to the dtype rule for error messages.
    reduce_window_fn: windowed reduction used by the TPU-specific translation.

  Returns:
    The newly created primitive. No default translation or JVP rule is
    attached here.
  """
  dtype_rule = partial(_cumred_dtype_rule, name)
  reducer_p = lax.standard_primitive(_cumred_shape_rule, dtype_rule, name)

  tpu_translation = xla.lower_fun(
      partial(_cumred_tpu_translation_rule, reduce_window_fn),
      multiple_results=False)
  xla.backend_specific_translations['tpu'][reducer_p] = tpu_translation

  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
  return reducer_p
# Non-linear cumulative reductions, each paired with the reduce_window
# function used by its TPU-specific translation.
cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod)
cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max)
cummin_p = _cumulative_reduction_primitive("cummin", lax._reduce_window_min)

# Default translations: lower every cumulative primitive through
# associative_scan with the matching binary operation.
# NOTE(review): presumably the primitive's `axis` parameter reaches
# associative_scan as its `axis` keyword via lower_fun -- confirm.
xla.translations[cumsum_p] = xla.lower_fun(
    partial(associative_scan, lax.add), multiple_results=False)
xla.translations[cumprod_p] = xla.lower_fun(
    partial(associative_scan, lax.mul), multiple_results=False)
xla.translations[cummin_p] = xla.lower_fun(
    partial(associative_scan, lax.min), multiple_results=False)
xla.translations[cummax_p] = xla.lower_fun(
    partial(associative_scan, lax.max), multiple_results=False)
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
                         combine_fn: Callable):
  """JVP rule that differentiates through ``associative_scan``.

  Irrespective of backend, the parallel prefix scan implementation is always
  used when differentiating because reduce_window is not arbitrarily
  differentiable.

  Args:
    primals: primal inputs forwarded to ``api.jvp``.
    tangents: tangent inputs forwarded to ``api.jvp``.
    axis: scan axis passed to ``associative_scan``.
    combine_fn: binary operation passed to ``associative_scan``.
  """
  scan_along_axis = partial(associative_scan, combine_fn, axis=axis)
  return api.jvp(scan_along_axis, primals, tangents)
# JVP rules for the non-linear cumulative reductions. cumsum is not listed
# here; it is registered through ad.deflinear instead.
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_open_jaxpr, _initial_style_jaxpr, \

View File

@ -2375,14 +2375,23 @@ class LaxControlFlowTest(jtu.JaxTestCase):
x, n = jnp.arange(3), jnp.arange(4)
api.vmap(api.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash
def testAssociativeScanUnstructured1000(self):
data = np.arange(1000)
expected = np.cumsum(data)
result = lax.associative_scan(operator.add, data)
@parameterized.named_parameters(
    {"testcase_name": f"_{shape}_axis={axis}",
     "shape": shape, "axis": axis}
    for shape in [
      [0], [1], [2], [3], [5], [10], [1000],
      [2, 3], [7, 5], [5, 6, 7]
    ]
    # Cover every valid axis, -ndim .. ndim - 1. The previous upper bound of
    # len(shape) - 1 skipped the last non-negative axis, so 1-D inputs were
    # never scanned with axis=0.
    for axis in range(-len(shape), len(shape)))
def testAssociativeScanUnstructured(self, shape, axis):
  """associative_scan with addition matches np.cumsum along every axis."""
  # The +7 offset makes the first element non-zero.
  data = np.arange(np.prod(shape)).reshape(shape) + 7
  expected = np.cumsum(data, axis=axis)
  result = lax.associative_scan(operator.add, data, axis=axis)
  self.assertAllClose(result, expected, check_dtypes=False)
def testAssociativeScanUnstructured1000Reverse(self):
data = np.arange(1000)
data = np.arange(1000) + 32
expected = np.cumsum(data[::-1])[::-1]
result = lax.associative_scan(operator.add, data, reverse=True)
self.assertAllClose(result, expected, check_dtypes=False)