Add a reverse=... argument to lax.cumsum/cumprod/...

This allows us to lower to a more efficient implementation on TPU.
This commit is contained in:
Peter Hawkins 2020-10-16 10:09:11 -04:00
parent c298700191
commit e863103b0e
5 changed files with 55 additions and 45 deletions

View File

@@ -244,12 +244,13 @@ deflinear(lax.reduce_window_sum_p)
deflinear(lax_fft.fft_p)
deflinear(xla.device_put_p)
def _cumulative_jet_rule(primals_in, series_in, *, axis: int,
def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
combine_fn: Callable):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis),
return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis,
reverse=reverse),
primals_in, series_in)
deflinear(lax_control_flow.cumsum_p)

View File

@@ -2492,52 +2492,53 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
# Cumulative reductions.
def cumsum(operand: Array, axis: int) -> Array:
def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative sum along `axis`."""
return cumsum_p.bind(operand, axis=int(axis))
return cumsum_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def cumprod(operand: Array, axis: int) -> Array:
def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative product along `axis`."""
return cumprod_p.bind(operand, axis=int(axis))
return cumprod_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def cummax(operand: Array, axis: int) -> Array:
def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative maximum along `axis`."""
return cummax_p.bind(operand, axis=int(axis))
return cummax_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def cummin(operand: Array, axis: int) -> Array:
def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
"""Computes a cumulative minimum along `axis`."""
return cummin_p.bind(operand, axis=int(axis))
return cummin_p.bind(operand, axis=int(axis), reverse=bool(reverse))
def _cumred_shape_rule(x, *, axis: int):
def _cumred_shape_rule(x, *, axis: int, reverse: bool):
if axis < 0 or axis >= x.ndim:
raise ValueError(
"axis {} is out of bounds for array of shape {}".format(axis, x.shape))
return x.shape
def _cumsum_transpose_rule(t, *, axis: int):
return [lax.rev(cumsum(lax.rev(t, (axis,)), axis=axis), (axis,))]
def _cumsum_transpose_rule(t, *, axis: int, reverse: bool):
return [cumsum(t, axis=axis, reverse=not reverse)]
def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
axis: int):
axis: int, reverse: bool):
# On TPU, an implementation using reduce_window is handled specially by the
# compiler and is efficient. On other backends, it is O(n^2).
n = x.shape[axis]
if n == 0:
return x
padding = [(0, 0)] * x.ndim
padding[axis] = (n - 1, 0)
padding[axis] = (0, n - 1) if reverse else (n - 1, 0)
strides = [1] * x.ndim
window_dims = [1] * x.ndim
window_dims[axis] = n
return window_reduce(x, window_dims, strides, padding)
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
reverse: bool):
operand, = batched_args
bdim, = batch_dims
axis = axis if axis < bdim else axis + 1
return prim.bind(operand, axis=axis), bdim
return prim.bind(operand, axis=axis, reverse=reverse), bdim
def _cumred_dtype_rule(name, operand, *args, **kw):
if not dtypes.issubdtype(operand.dtype, np.number):
@@ -2579,12 +2580,13 @@ xla.translations[cummin_p] = xla.lower_fun(
xla.translations[cummax_p] = xla.lower_fun(
partial(associative_scan, lax.max), multiple_results=False)
def _cumulative_jvp_rule(primals, tangents, *, axis: int,
def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
combine_fn: Callable):
# Irrespective of backend, we always use the parallel prefix scan
# implementation when differentiating because reduce_window is not
# arbitrarily differentiable.
return api.jvp(partial(associative_scan, combine_fn, axis=axis),
return api.jvp(partial(associative_scan, combine_fn, axis=axis,
reverse=reverse),
primals, tangents)
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)

View File

@@ -743,10 +743,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
eps)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
{"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
reverse),
"op": op, "shape": shape, "dtype": dtype,
"axis": axis, "rng_factory": rng_factory}
"axis": axis, "reverse": reverse}
for op, types in [
(lax.cumsum, [np.float32, np.float64]),
(lax.cumprod, [np.float32, np.float64]),
@@ -754,12 +755,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
for dtype in types
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
for rng_factory in [
jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small]))
def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
for reverse in [False, True]))
def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
rng = rng_factory(self.rng())
check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)
check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
order=2)
# TODO(b/205052657): enable more tests when supported

View File

@@ -1420,10 +1420,11 @@ class LaxTest(jtu.JaxTestCase):
self.assertEqual(shape, result.shape)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
{"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
reverse),
"op": op, "np_op": np_op, "shape": shape, "dtype": dtype,
"axis": axis, "rng_factory": rng_factory}
"axis": axis, "reverse": reverse}
for op, np_op, types in [
(lax.cumsum, np.cumsum, default_dtypes),
(lax.cumprod, np.cumprod, default_dtypes),
@@ -1433,13 +1434,17 @@ class LaxTest(jtu.JaxTestCase):
for dtype in types
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
for rng_factory in [
jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small]))
def testCumulativeReduce(self, op, np_op, shape, dtype, axis, rng_factory):
for reverse in [False, True]))
def testCumulativeReduce(self, op, np_op, shape, dtype, axis, reverse):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
rng = rng_factory(self.rng())
fun = partial(op, axis=axis)
np_fun = partial(np_op, axis=axis, dtype=dtype)
fun = partial(op, axis=axis, reverse=reverse)
def np_fun(x):
if reverse:
return np.flip(np_op(np.flip(x, axis), axis=axis, dtype=dtype), axis)
else:
return np_op(x, axis=axis, dtype=dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CompileAndCheck(fun, args_maker)
self._CheckAgainstNumpy(np_fun, fun, args_maker)

View File

@@ -541,11 +541,11 @@ class LaxVmapTest(jtu.JaxTestCase):
self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_shape={}_axis={}_bdims={}"
{"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}"
.format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
bdims),
bdims, reverse),
"op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
"axis": axis, "rng_factory": rng_factory}
"axis": axis, "reverse": reverse}
for op, types in [
(lax.cumsum, [np.float32, np.float64]),
(lax.cumprod, [np.float32, np.float64]),
@@ -554,13 +554,13 @@ class LaxVmapTest(jtu.JaxTestCase):
for shape in [[10], [3, 4, 5]]
for axis in range(len(shape))
for bdims in all_bdims(shape)
for rng_factory in [
jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small]))
def testCumulativeReduce(self, op, shape, dtype, axis, bdims, rng_factory):
for reverse in [False, True]))
def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
else jtu.rand_small)
rng = rng_factory(self.rng())
self._CheckBatching(partial(op, axis=axis), 7, bdims, (shape,), (dtype,),
rng)
self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
(shape,), (dtype,), rng)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,