mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Optimize lax.associative_scan, reimplement cumsum, etc. on top of associative_scan.
Add support for an axis= parameter to associative_scan. We previously had two associative scan implementations, namely lax.associative_scan, and the implementations of cumsum, cumprod, etc. lax.associative_scan was more efficient in some ways because unlike the cumsum implementation it did not pad the input array to the nearest power of two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135. The cumsum/cummax implementation used slightly more efficient code to slice and interleave arrays, which this change adds to associative_scan as well. Since we are now using lax primitives that make it easy to select an axis, add support for user-chosen scan axes as well. We can also simplify the implementation of associative_scan: one of the recursive base cases seems unnecessary, and we can simplify the code by removing it. Benchmarks from #4135 on my workstation: Before: bench_cumsum: 0.900s bench_associative_scan: 0.597s bench_scan: 0.359s bench_np: 1.619s After: bench_cumsum: 0.435s bench_associative_scan: 0.435s bench_scan: 0.362s bench_np: 1.669s Before, with taskset -c 0: bench_cumsum: 1.989s bench_associative_scan: 1.556s bench_scan: 0.428s bench_np: 1.670s After, with taskset -c 0: bench_cumsum: 1.271s bench_associative_scan: 1.275s bench_scan: 0.438s bench_np: 1.673s
This commit is contained in:
parent
7f4e115a6a
commit
d3db7bd4be
@ -57,6 +57,10 @@ Operators
|
||||
conv_transpose
|
||||
cos
|
||||
cosh
|
||||
cummax
|
||||
cummin
|
||||
cumprod
|
||||
cumsum
|
||||
digamma
|
||||
div
|
||||
dot
|
||||
|
@ -29,6 +29,7 @@ import jax.linear_util as lu
|
||||
from jax.interpreters import xla
|
||||
from jax.custom_derivatives import custom_jvp_call_jaxpr_p
|
||||
from jax.lax import lax
|
||||
from jax.lax import lax_control_flow
|
||||
from jax.lax import lax_fft
|
||||
|
||||
def jet(fun, primals, series):
|
||||
@ -238,19 +239,20 @@ deflinear(lax_fft.fft_p)
|
||||
deflinear(xla.device_put_p)
|
||||
|
||||
def _cumulative_jet_rule(primals_in, series_in, *, axis: int,
|
||||
prefix_scan: Callable):
|
||||
combine_fn: Callable):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return jet(partial(prefix_scan, axis=axis), primals_in, series_in)
|
||||
return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis),
|
||||
primals_in, series_in)
|
||||
|
||||
deflinear(lax.cumsum_p)
|
||||
jet_rules[lax.cumprod_p] = partial(_cumulative_jet_rule,
|
||||
prefix_scan=lax._cumprod_prefix_scan)
|
||||
jet_rules[lax.cummax_p] = partial(_cumulative_jet_rule,
|
||||
prefix_scan=lax._cummax_prefix_scan)
|
||||
jet_rules[lax.cummin_p] = partial(_cumulative_jet_rule,
|
||||
prefix_scan=lax._cummin_prefix_scan)
|
||||
deflinear(lax_control_flow.cumsum_p)
|
||||
jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule,
|
||||
combine_fn=lax.mul)
|
||||
jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule,
|
||||
combine_fn=lax.max)
|
||||
jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
|
||||
combine_fn=lax.min)
|
||||
|
||||
|
||||
def def_deriv(prim, deriv):
|
||||
|
@ -94,14 +94,6 @@ from .lax import (
|
||||
cosh_p,
|
||||
create_token,
|
||||
create_token_p,
|
||||
cummax,
|
||||
cummax_p,
|
||||
cummin,
|
||||
cummin_p,
|
||||
cumprod,
|
||||
cumprod_p,
|
||||
cumsum,
|
||||
cumsum_p,
|
||||
digamma,
|
||||
digamma_p,
|
||||
div,
|
||||
@ -299,8 +291,17 @@ from .lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
|
||||
_upcast_fp16_for_computation, _broadcasting_shape_rule,
|
||||
_eye, _tri, _delta, _ones, _zeros)
|
||||
from .lax_control_flow import (
|
||||
associative_scan,
|
||||
cond,
|
||||
cond_p,
|
||||
cummax,
|
||||
cummax_p,
|
||||
cummin,
|
||||
cummin_p,
|
||||
cumprod,
|
||||
cumprod_p,
|
||||
cumsum,
|
||||
cumsum_p,
|
||||
custom_linear_solve,
|
||||
custom_root,
|
||||
fori_loop,
|
||||
@ -312,7 +313,6 @@ from .lax_control_flow import (
|
||||
switch,
|
||||
while_loop,
|
||||
while_p,
|
||||
associative_scan,
|
||||
)
|
||||
from .lax_fft import (
|
||||
fft,
|
||||
|
155
jax/lax/lax.py
155
jax/lax/lax.py
@ -1330,22 +1330,6 @@ def _select_and_gather_add(tangents: Array, operand: Array,
|
||||
base_dilation=tuple(base_dilation),
|
||||
window_dilation=tuple(window_dilation))
|
||||
|
||||
def cumsum(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative sum along `axis`."""
|
||||
return cumsum_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cumprod(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative product along `axis`."""
|
||||
return cumprod_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cummax(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative maximum along `axis`."""
|
||||
return cummax_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cummin(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative minimum along `axis`."""
|
||||
return cummin_p.bind(operand, axis=int(axis))
|
||||
|
||||
def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
|
||||
is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
|
||||
"""Wraps XLA's `Sort
|
||||
@ -5391,145 +5375,6 @@ xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
|
||||
_select_and_gather_add_translation,
|
||||
max_bits=32)
|
||||
|
||||
|
||||
# Parallel prefix-scan. See:
|
||||
# https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
|
||||
# and
|
||||
# Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", Technical Report
|
||||
# CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.
|
||||
#
|
||||
# Unlike the Blelloch algorithm, we use an out-of-place algorithm that uses 2n
|
||||
# space. This is somewhat wasteful if we are interested only in the output of
|
||||
# the forward pass, but more memory-efficient if we intend to differentiate
|
||||
# through the implementation of the scan.
|
||||
def _prescan_power_of_two(x, axis: int, op: Callable, unit):
|
||||
n = x.shape[axis]
|
||||
assert n != 0 and n & (n - 1) == 0, "n must be a power of 2"
|
||||
|
||||
# Upsweep
|
||||
xs = []
|
||||
for d in range(0, n.bit_length() - 1):
|
||||
x1 = slice_in_dim(x, 0, None, stride=2, axis=axis)
|
||||
xs.append(x1)
|
||||
x2 = slice_in_dim(x, 1, None, stride=2, axis=axis)
|
||||
x = op(x1, x2)
|
||||
total = x
|
||||
|
||||
# Downsweep
|
||||
x = full_like(total, unit)
|
||||
pad_left = [(0, 0, 0)] * len(x.shape)
|
||||
pad_left[axis] = (1, 0, 1)
|
||||
pad_right = [(0, 0, 0)] * len(x.shape)
|
||||
pad_right[axis] = (0, 1, 1)
|
||||
for w in reversed(xs):
|
||||
x1 = pad(x, _const(x, 0), pad_right)
|
||||
x2 = pad(x, _const(x, 0), pad_left)
|
||||
w = pad(w, _const(x, 0), pad_left)
|
||||
x = x1 + op(x2, w)
|
||||
|
||||
return x, total
|
||||
|
||||
|
||||
def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any):
|
||||
if np.issubdtype(x.dtype, np.integer):
|
||||
if np.isposinf(unit):
|
||||
unit = np.iinfo(x.dtype).max
|
||||
elif np.isneginf(unit):
|
||||
unit = np.iinfo(x.dtype).min
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
# Pads to the next largest power of two
|
||||
nbits = n.bit_length()
|
||||
if n == (1 << (nbits - 1)):
|
||||
nbits -= 1
|
||||
padding = [(0, 0, 0)] * len(x.shape)
|
||||
padding[axis] = (0, (1 << nbits) - n, 0)
|
||||
x = pad(x, _const(x, unit), padding)
|
||||
x, total = _prescan_power_of_two(x, axis, op, unit)
|
||||
return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
|
||||
|
||||
_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
|
||||
_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
|
||||
_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=-np.inf)
|
||||
_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=np.inf)
|
||||
|
||||
def _cumred_shape_rule(x, *, axis: int):
|
||||
if axis < 0 or axis >= x.ndim:
|
||||
raise ValueError(
|
||||
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
|
||||
return x.shape
|
||||
|
||||
def _cumsum_transpose_rule(t, *, axis: int):
|
||||
return [rev(cumsum(rev(t, (axis,)), axis=axis), (axis,))]
|
||||
|
||||
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
|
||||
prefix_scan: Callable):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(prefix_scan, axis=axis), primals, tangents)
|
||||
|
||||
|
||||
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
|
||||
axis: int):
|
||||
# On TPU, an implementation using reduce_window is handled specially by the
|
||||
# compiler and is efficient. On other backends, it is O(n^2).
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
padding = [(0, 0)] * x.ndim
|
||||
padding[axis] = (n - 1, 0)
|
||||
strides = [1] * x.ndim
|
||||
window_dims = [1] * x.ndim
|
||||
window_dims[axis] = n
|
||||
return window_reduce(x, window_dims, strides, padding)
|
||||
|
||||
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
axis = axis if axis < bdim else axis + 1
|
||||
return prim.bind(operand, axis=axis), bdim
|
||||
|
||||
|
||||
cumsum_p = standard_primitive(
|
||||
_cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumsum"),
|
||||
'cumsum', xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
|
||||
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
|
||||
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, _reduce_window_sum),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
|
||||
|
||||
|
||||
def _cumulative_reduction_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
|
||||
reducer_p = standard_primitive(
|
||||
_cumred_shape_rule, partial(_reduce_number_dtype_rule, name),
|
||||
name, xla.lower_fun(prefix_scan_fn, multiple_results=False))
|
||||
ad.primitive_jvps[reducer_p] = jvp_rule
|
||||
xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, reduce_window_fn),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
|
||||
return reducer_p
|
||||
|
||||
|
||||
cumprod_p = _cumulative_reduction_primitive("cumprod", _cumprod_prefix_scan,
|
||||
partial(_cumulative_jvp_rule,
|
||||
prefix_scan=_cumprod_prefix_scan),
|
||||
_reduce_window_prod)
|
||||
|
||||
cummax_p = _cumulative_reduction_primitive("cummax", _cummax_prefix_scan,
|
||||
partial(_cumulative_jvp_rule,
|
||||
prefix_scan=_cummax_prefix_scan),
|
||||
_reduce_window_max)
|
||||
|
||||
cummin_p = _cumulative_reduction_primitive("cummin", _cummin_prefix_scan,
|
||||
partial(_cumulative_jvp_rule,
|
||||
prefix_scan=_cummin_prefix_scan),
|
||||
_reduce_window_min)
|
||||
|
||||
|
||||
def _sort_abstract_eval(*args, **kwargs):
|
||||
args = tuple(raise_to_shaped(arg) for arg in args)
|
||||
if any(arg.shape != args[0].shape for arg in args[1:]):
|
||||
|
@ -23,11 +23,12 @@ import inspect
|
||||
import itertools
|
||||
import operator
|
||||
import os
|
||||
from typing import Callable, Sequence, TypeVar
|
||||
from typing import Any, Callable, Sequence, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import api
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import source_info_util
|
||||
@ -58,7 +59,7 @@ zip = safe_zip
|
||||
_reduce = functools.reduce
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
Array = Any
|
||||
|
||||
@cache()
|
||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
@ -2348,25 +2349,21 @@ ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
|
||||
batching.primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
|
||||
|
||||
|
||||
def _interleave(a, b):
|
||||
def _interleave(a, b, axis):
|
||||
"""Given two Tensors of static shape, interleave them along the first axis."""
|
||||
# TODO(mattjj)
|
||||
import jax.numpy as jnp
|
||||
# [a b c ...] [d e f ...] -> [a d b e c f ...]
|
||||
half_num_elems = b.shape[0]
|
||||
assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1
|
||||
a_pad = [(0, 0, 0)] * a.ndim
|
||||
b_pad = [(0, 0, 0)] * b.ndim
|
||||
a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
|
||||
b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
|
||||
return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
|
||||
lax.pad(b, lax._const(b, 0), b_pad))
|
||||
|
||||
if a.shape[0] > b.shape[0]:
|
||||
return jnp.concatenate(
|
||||
[jnp.reshape(jnp.stack([a[: -1], b], axis=1),
|
||||
(2 * half_num_elems,) + a.shape[1:]),
|
||||
a[-1:]], axis=0)
|
||||
else:
|
||||
return jnp.reshape(jnp.stack([a, b], axis=1),
|
||||
(2 * half_num_elems,) + a.shape[1:])
|
||||
|
||||
def associative_scan(fn: Callable, elems, reverse: bool = False):
|
||||
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
"""Performs a scan with an associative binary operation, in parallel.
|
||||
|
||||
For an introduction to associative scans, see [BLE1990]_.
|
||||
|
||||
Args:
|
||||
fn: A Python callable implementing an associative binary operation with
|
||||
signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
|
||||
@ -2374,24 +2371,25 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
|
||||
``fn(a, fn(b, c)) == fn(fn(a, b), c)``.
|
||||
|
||||
The inputs and result are (possibly nested Python tree structures of)
|
||||
array(s) matching ``elems``. Each array has a leading dimension in place
|
||||
of the ``num_elems`` dimension. `fn` should be applied elementwise over
|
||||
the leading dimension (for example, by using :func:`jax.vmap` over the
|
||||
array(s) matching ``elems``. Each array has a dimension in place
|
||||
of the ``axis`` dimension. `fn` should be applied elementwise over
|
||||
the ``axis`` dimension (for example, by using :func:`jax.vmap` over the
|
||||
elementwise function.)
|
||||
|
||||
The result `r` has the same shape (and structure) as the two inputs ``a``
|
||||
and ``b``.
|
||||
elems: A (possibly nested structure of) array(s), each with leading
|
||||
dimension ``num_elems``.
|
||||
The result ``r`` has the same shape (and structure) as the two inputs
|
||||
``a`` and ``b``.
|
||||
elems: A (possibly nested Python tree structure of) array(s), each with
|
||||
an ``axis`` dimension of size ``num_elems``.
|
||||
reverse: A boolean stating if the scan should be reversed with respect to
|
||||
the leading dimension.
|
||||
the ``axis`` dimension.
|
||||
axis: an integer identifying the axis over which the scan should occur.
|
||||
|
||||
Returns:
|
||||
A (possibly nested Python tree structure of) array(s) of the same shape
|
||||
and structure as ``elems``, in which the ``k``'th element is the result of
|
||||
recursively applying ``fn`` to combine the first ``k`` elements of
|
||||
``elems``. For example, given ``elems = [a, b, c, ...]``, the result
|
||||
would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
|
||||
and structure as ``elems``, in which the ``k``'th element of ``axis`` is the
|
||||
result of recursively applying ``fn`` to combine the first ``k`` elements
|
||||
of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``,
|
||||
the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
|
||||
|
||||
Example 1: partial sums of an array of numbers:
|
||||
|
||||
@ -2409,13 +2407,17 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
|
||||
|
||||
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
|
||||
[ 6, 6, 5, 3]
|
||||
|
||||
.. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
|
||||
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
|
||||
University.
|
||||
"""
|
||||
elems_flat, tree = tree_flatten(elems)
|
||||
|
||||
if reverse:
|
||||
elems_flat = [lax.rev(elem, [0]) for elem in elems_flat]
|
||||
elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat]
|
||||
|
||||
def lowered_fn(a_flat, b_flat):
|
||||
def combine(a_flat, b_flat):
|
||||
# Lower `fn` to operate on flattened sequences of elems.
|
||||
a = tree_unflatten(tree, a_flat)
|
||||
b = tree_unflatten(tree, b_flat)
|
||||
@ -2424,14 +2426,13 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
|
||||
return c_flat
|
||||
|
||||
# Check that all inputs have a consistent leading dimension `num_elems`.
|
||||
num_elems = int(elems_flat[0].shape[0])
|
||||
axis = lax._canonicalize_axis(axis, elems_flat[0].ndim)
|
||||
num_elems = int(elems_flat[0].shape[axis])
|
||||
if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):
|
||||
raise ValueError('Array inputs to associative_scan must have the same '
|
||||
'first dimension. (saw: {})'
|
||||
.format([elems.shape for elem in elems_flat]))
|
||||
|
||||
if not all(int(elem.shape[0]) == num_elems for elem in elems_flat[1:]):
|
||||
raise ValueError('Input `Tensor`s must have the same first dimension.'
|
||||
' (saw: {})'.format([elems.shape for elem in elems_flat]))
|
||||
|
||||
if num_elems < 2:
|
||||
return elems
|
||||
|
||||
# Summary of algorithm:
|
||||
#
|
||||
@ -2451,49 +2452,146 @@ def associative_scan(fn: Callable, elems, reverse: bool = False):
|
||||
def _scan(elems):
|
||||
"""Perform scan on `elems`."""
|
||||
|
||||
num_elems = elems[0].shape[0]
|
||||
num_elems = elems[0].shape[axis]
|
||||
|
||||
reduced_elems = lowered_fn([elem[0:-1:2] for elem in elems],
|
||||
[elem[1::2] for elem in elems])
|
||||
if num_elems < 2:
|
||||
return elems
|
||||
|
||||
if reduced_elems[0].shape[0] == 1:
|
||||
# Base case has either 2 or 3 elements.
|
||||
if num_elems == 2:
|
||||
return [lax.concatenate([elem[0:1], reduced_elem], dimension=0)
|
||||
for (reduced_elem, elem) in zip(reduced_elems, elems)]
|
||||
elif num_elems == 3:
|
||||
reduced_reduced_elems = lowered_fn(
|
||||
reduced_elems,
|
||||
[elem[2:3] for elem in elems])
|
||||
return [
|
||||
lax.concatenate([elem[0:1], reduced_elem, reduced_reduced_elem],
|
||||
dimension=0)
|
||||
for (reduced_reduced_elem, reduced_elem, elem)
|
||||
in zip(reduced_reduced_elems, reduced_elems, elems)]
|
||||
# Combine adjacent pairs of elements.
|
||||
reduced_elems = combine(
|
||||
[lax.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems],
|
||||
[lax.slice_in_dim(elem, 1, None, stride=2, axis=axis) for elem in elems])
|
||||
|
||||
# Recursively compute scan for partially reduced tensors.
|
||||
odd_elems = _scan(reduced_elems)
|
||||
|
||||
if num_elems % 2 == 0:
|
||||
results = lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
|
||||
[elem[2::2] for elem in elems])
|
||||
even_elems = combine(
|
||||
[lax.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems],
|
||||
[lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
|
||||
else:
|
||||
results = lowered_fn(list(odd_elems), [elem[2::2] for elem in elems])
|
||||
even_elems = combine(
|
||||
odd_elems,
|
||||
[lax.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems])
|
||||
|
||||
# The first element of a scan is the same as the first element
|
||||
# of the original `elems`.
|
||||
even_elems = [lax.concatenate([elem[0:1], result], dimension=0)
|
||||
for (elem, result) in zip(elems, results)]
|
||||
return tuple(_map(_interleave, even_elems, odd_elems))
|
||||
even_elems = [
|
||||
lax.concatenate([lax.slice_in_dim(elem, 0, 1, axis=axis), result],
|
||||
dimension=axis)
|
||||
for (elem, result) in zip(elems, even_elems)]
|
||||
return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems))
|
||||
|
||||
scans = _scan(elems_flat)
|
||||
|
||||
if reverse:
|
||||
scans = [lax.rev(scanned, [0]) for scanned in scans]
|
||||
scans = [lax.rev(scanned, [axis]) for scanned in scans]
|
||||
|
||||
return tree_unflatten(tree, scans)
|
||||
|
||||
|
||||
# Cumulative reductions.
|
||||
|
||||
def cumsum(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative sum along `axis`."""
|
||||
return cumsum_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cumprod(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative product along `axis`."""
|
||||
return cumprod_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cummax(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative maximum along `axis`."""
|
||||
return cummax_p.bind(operand, axis=int(axis))
|
||||
|
||||
def cummin(operand: Array, axis: int) -> Array:
|
||||
"""Computes a cumulative minimum along `axis`."""
|
||||
return cummin_p.bind(operand, axis=int(axis))
|
||||
|
||||
def _cumred_shape_rule(x, *, axis: int):
|
||||
if axis < 0 or axis >= x.ndim:
|
||||
raise ValueError(
|
||||
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
|
||||
return x.shape
|
||||
|
||||
def _cumsum_transpose_rule(t, *, axis: int):
|
||||
return [lax.rev(cumsum(lax.rev(t, (axis,)), axis=axis), (axis,))]
|
||||
|
||||
|
||||
|
||||
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
|
||||
axis: int):
|
||||
# On TPU, an implementation using reduce_window is handled specially by the
|
||||
# compiler and is efficient. On other backends, it is O(n^2).
|
||||
n = x.shape[axis]
|
||||
if n == 0:
|
||||
return x
|
||||
padding = [(0, 0)] * x.ndim
|
||||
padding[axis] = (n - 1, 0)
|
||||
strides = [1] * x.ndim
|
||||
window_dims = [1] * x.ndim
|
||||
window_dims[axis] = n
|
||||
return window_reduce(x, window_dims, strides, padding)
|
||||
|
||||
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
axis = axis if axis < bdim else axis + 1
|
||||
return prim.bind(operand, axis=axis), bdim
|
||||
|
||||
def _cumred_dtype_rule(name, operand, *args, **kw):
|
||||
if not dtypes.issubdtype(operand.dtype, np.number):
|
||||
raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes "
|
||||
"of number.".format(name, np.dtype(operand.dtype).name))
|
||||
return dtypes.canonicalize_dtype(operand.dtype)
|
||||
|
||||
cumsum_p = lax.standard_primitive(
|
||||
_cumred_shape_rule, partial(_cumred_dtype_rule, "cumsum"),
|
||||
'cumsum')
|
||||
ad.deflinear(cumsum_p, _cumsum_transpose_rule)
|
||||
xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, lax._reduce_window_sum),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
|
||||
|
||||
|
||||
def _cumulative_reduction_primitive(name, reduce_window_fn):
|
||||
reducer_p = lax.standard_primitive(
|
||||
_cumred_shape_rule, partial(_cumred_dtype_rule, name),
|
||||
name)
|
||||
xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
|
||||
partial(_cumred_tpu_translation_rule, reduce_window_fn),
|
||||
multiple_results=False)
|
||||
batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
|
||||
return reducer_p
|
||||
|
||||
|
||||
cumprod_p = _cumulative_reduction_primitive("cumprod", lax._reduce_window_prod)
|
||||
cummax_p = _cumulative_reduction_primitive("cummax", lax._reduce_window_max)
|
||||
cummin_p = _cumulative_reduction_primitive("cummin", lax._reduce_window_min)
|
||||
|
||||
xla.translations[cumsum_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.add), multiple_results=False)
|
||||
xla.translations[cumprod_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.mul), multiple_results=False)
|
||||
xla.translations[cummin_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.min), multiple_results=False)
|
||||
xla.translations[cummax_p] = xla.lower_fun(
|
||||
partial(associative_scan, lax.max), multiple_results=False)
|
||||
|
||||
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
|
||||
combine_fn: Callable):
|
||||
# Irrespective of backend, we always use the parallel prefix scan
|
||||
# implementation when differentiating because reduce_window is not
|
||||
# arbitrarily differentiable.
|
||||
return api.jvp(partial(associative_scan, combine_fn, axis=axis),
|
||||
primals, tangents)
|
||||
|
||||
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
|
||||
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
|
||||
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global _initial_style_open_jaxpr, _initial_style_jaxpr, \
|
||||
|
@ -2375,14 +2375,23 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
x, n = jnp.arange(3), jnp.arange(4)
|
||||
api.vmap(api.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash
|
||||
|
||||
def testAssociativeScanUnstructured1000(self):
|
||||
data = np.arange(1000)
|
||||
expected = np.cumsum(data)
|
||||
result = lax.associative_scan(operator.add, data)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{shape}_axis={axis}",
|
||||
"shape": shape, "axis": axis}
|
||||
for shape in [
|
||||
[0], [1], [2], [3], [5], [10], [1000],
|
||||
[2, 3], [7, 5], [5, 6, 7]
|
||||
]
|
||||
for axis in range(-len(shape), len(shape) - 1))
|
||||
def testAssociativeScanUnstructured(self, shape, axis):
|
||||
data = np.arange(np.prod(shape)).reshape(shape) + 7
|
||||
expected = np.cumsum(data, axis=axis)
|
||||
result = lax.associative_scan(operator.add, data, axis=axis)
|
||||
self.assertAllClose(result, expected, check_dtypes=False)
|
||||
|
||||
def testAssociativeScanUnstructured1000Reverse(self):
|
||||
data = np.arange(1000)
|
||||
data = np.arange(1000) + 32
|
||||
expected = np.cumsum(data[::-1])[::-1]
|
||||
result = lax.associative_scan(operator.add, data, reverse=True)
|
||||
self.assertAllClose(result, expected, check_dtypes=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user