Mirror of https://github.com/ROCm/jax.git
Add a reverse=... argument to lax.cumsum/cumprod/...
This allows us to lower to a more efficient implementation on TPU.
parent c298700191
commit e863103b0e
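For context, a minimal usage sketch of the new argument (assuming a jax build that includes this change). With reverse=True the cumulative reduction runs from the end of the axis back toward the start, which is equivalent to flipping, scanning, and flipping back:

import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 4.])
lax.cumsum(x, axis=0)                 # [ 1.,  3.,  6., 10.]
lax.cumsum(x, axis=0, reverse=True)   # [10.,  9.,  7.,  4.]
jnp.flip(jnp.cumsum(jnp.flip(x)))     # [10.,  9.,  7.,  4.], the equivalent flip form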
@@ -244,12 +244,13 @@ deflinear(lax.reduce_window_sum_p)
 deflinear(lax_fft.fft_p)
 deflinear(xla.device_put_p)
 
-def _cumulative_jet_rule(primals_in, series_in, *, axis: int,
+def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                          combine_fn: Callable):
   # Irrespective of backend, we always use the parallel prefix scan
   # implementation when differentiating because reduce_window is not
   # arbitrarily differentiable.
-  return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis),
+  return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis,
+                     reverse=reverse),
              primals_in, series_in)
 
 deflinear(lax_control_flow.cumsum_p)
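The comment in this rule records the design choice: differentiation always goes through the parallel prefix scan because reduce_window is not arbitrarily differentiable. Since associative_scan already accepts the same reverse flag (see its signature in the next hunk), the jet rule can simply forward it. A hedged sketch of that equivalence, again assuming a build with this change:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(1., 5.)
a = lax.cumsum(x, axis=0, reverse=True)
b = lax.associative_scan(jnp.add, x, reverse=True, axis=0)
# Both should be [10., 9., 7., 4.]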
@@ -2492,52 +2492,53 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
 
 # Cumulative reductions.
 
-def cumsum(operand: Array, axis: int) -> Array:
+def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
   """Computes a cumulative sum along `axis`."""
-  return cumsum_p.bind(operand, axis=int(axis))
+  return cumsum_p.bind(operand, axis=int(axis), reverse=bool(reverse))
 
-def cumprod(operand: Array, axis: int) -> Array:
+def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
   """Computes a cumulative product along `axis`."""
-  return cumprod_p.bind(operand, axis=int(axis))
+  return cumprod_p.bind(operand, axis=int(axis), reverse=bool(reverse))
 
-def cummax(operand: Array, axis: int) -> Array:
+def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
   """Computes a cumulative maximum along `axis`."""
-  return cummax_p.bind(operand, axis=int(axis))
+  return cummax_p.bind(operand, axis=int(axis), reverse=bool(reverse))
 
-def cummin(operand: Array, axis: int) -> Array:
+def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array:
   """Computes a cumulative minimum along `axis`."""
-  return cummin_p.bind(operand, axis=int(axis))
+  return cummin_p.bind(operand, axis=int(axis), reverse=bool(reverse))
 
-def _cumred_shape_rule(x, *, axis: int):
+def _cumred_shape_rule(x, *, axis: int, reverse: bool):
   if axis < 0 or axis >= x.ndim:
     raise ValueError(
         "axis {} is out of bounds for array of shape {}".format(axis, x.shape))
   return x.shape
 
-def _cumsum_transpose_rule(t, *, axis: int):
-  return [lax.rev(cumsum(lax.rev(t, (axis,)), axis=axis), (axis,))]
+def _cumsum_transpose_rule(t, *, axis: int, reverse: bool):
+  return [cumsum(t, axis=axis, reverse=not reverse)]
 
 
 def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
-                                 axis: int):
+                                 axis: int, reverse: bool):
   # On TPU, an implementation using reduce_window is handled specially by the
   # compiler and is efficient. On other backends, it is O(n^2).
   n = x.shape[axis]
   if n == 0:
     return x
   padding = [(0, 0)] * x.ndim
-  padding[axis] = (n - 1, 0)
+  padding[axis] = (0, n - 1) if reverse else (n - 1, 0)
   strides = [1] * x.ndim
   window_dims = [1] * x.ndim
   window_dims[axis] = n
   return window_reduce(x, window_dims, strides, padding)
 
-def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
+def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int,
+                       reverse: bool):
   operand, = batched_args
   bdim, = batch_dims
   axis = axis if axis < bdim else axis + 1
-  return prim.bind(operand, axis=axis), bdim
+  return prim.bind(operand, axis=axis, reverse=reverse), bdim
 
 def _cumred_dtype_rule(name, operand, *args, **kw):
   if not dtypes.issubdtype(operand.dtype, np.number):
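Two details of this hunk deserve a note. First, the new _cumsum_transpose_rule replaces the flip/cumsum/flip composition with a single cumsum of the opposite direction; the identity behind it is that the adjoint of a forward cumsum is a reverse cumsum. A small NumPy check of that identity (not part of the commit):

import numpy as np

n = 5
L = np.tril(np.ones((n, n)))        # forward cumsum as a linear map: y = L @ x
x = np.arange(1., n + 1)
assert np.allclose(L @ x, np.cumsum(x))

t = np.linspace(0., 1., n)          # a cotangent vector
# The adjoint L.T accumulates from the end, i.e. a reverse cumsum.
assert np.allclose(L.T @ t, np.flip(np.cumsum(np.flip(t))))

Second, the TPU translation rule's padding change follows the same symmetry: with a window of length n, low padding of n - 1 makes each output position reduce over the elements up to and including itself (a forward scan), while high padding of n - 1 makes it reduce over the elements from itself to the end (a reverse scan).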
@@ -2579,12 +2580,13 @@ xla.translations[cummin_p] = xla.lower_fun(
 xla.translations[cummax_p] = xla.lower_fun(
     partial(associative_scan, lax.max), multiple_results=False)
 
-def _cumulative_jvp_rule(primals, tangents, *, axis: int,
+def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
                          combine_fn: Callable):
   # Irrespective of backend, we always use the parallel prefix scan
   # implementation when differentiating because reduce_window is not
   # arbitrarily differentiable.
-  return api.jvp(partial(associative_scan, combine_fn, axis=axis),
+  return api.jvp(partial(associative_scan, combine_fn, axis=axis,
+                         reverse=reverse),
                  primals, tangents)
 
 ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
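A minimal illustration (assuming a build with this change) that forward-mode differentiation composes with reverse=True, since the JVP is evaluated through associative_scan rather than the reduce_window lowering:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 4.])
primal_out, tangent_out = jax.jvp(
    lambda a: lax.cumprod(a, axis=0, reverse=True),
    (x,), (jnp.ones_like(x),))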
@@ -743,10 +743,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
                 eps)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_op={}_shape={}_axis={}"
-       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
+      {"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
+       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
+               reverse),
        "op": op, "shape": shape, "dtype": dtype,
-       "axis": axis, "rng_factory": rng_factory}
+       "axis": axis, "reverse": reverse}
       for op, types in [
           (lax.cumsum, [np.float32, np.float64]),
           (lax.cumprod, [np.float32, np.float64]),
@@ -754,12 +755,13 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       for dtype in types
       for shape in [[10], [3, 4, 5]]
       for axis in range(len(shape))
-      for rng_factory in [
-          jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
-          else jtu.rand_small]))
-  def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
+      for reverse in [False, True]))
+  def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
+    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
+                   else jtu.rand_small)
     rng = rng_factory(self.rng())
-    check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)
+    check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
+                order=2)
 
 
   # TODO(b/205052657): enable more tests when supported
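As a concrete, hand-checkable instance of what check_grads verifies here, the gradient of a summed reverse cumsum is a forward cumsum of ones, by the transpose rule above (a sketch, not from the commit):

import jax
import jax.numpy as jnp
from jax import lax

g = jax.grad(lambda a: lax.cumsum(a, axis=0, reverse=True).sum())(
    jnp.array([1., 2., 3., 4.]))
# Expected: [1., 2., 3., 4.] -- x[j] feeds outputs 0..j, so it is counted j + 1 times.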
@@ -1420,10 +1420,11 @@ class LaxTest(jtu.JaxTestCase):
     self.assertEqual(shape, result.shape)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_op={}_shape={}_axis={}"
-       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
+      {"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
+       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
+               reverse),
        "op": op, "np_op": np_op, "shape": shape, "dtype": dtype,
-       "axis": axis, "rng_factory": rng_factory}
+       "axis": axis, "reverse": reverse}
       for op, np_op, types in [
           (lax.cumsum, np.cumsum, default_dtypes),
           (lax.cumprod, np.cumprod, default_dtypes),
@@ -1433,13 +1434,17 @@ class LaxTest(jtu.JaxTestCase):
       for dtype in types
       for shape in [[10], [3, 4, 5]]
       for axis in range(len(shape))
-      for rng_factory in [
-          jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
-          else jtu.rand_small]))
-  def testCumulativeReduce(self, op, np_op, shape, dtype, axis, rng_factory):
+      for reverse in [False, True]))
+  def testCumulativeReduce(self, op, np_op, shape, dtype, axis, reverse):
+    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
+                   else jtu.rand_small)
     rng = rng_factory(self.rng())
-    fun = partial(op, axis=axis)
-    np_fun = partial(np_op, axis=axis, dtype=dtype)
+    fun = partial(op, axis=axis, reverse=reverse)
+    def np_fun(x):
+      if reverse:
+        return np.flip(np_op(np.flip(x, axis), axis=axis, dtype=dtype), axis)
+      else:
+        return np_op(x, axis=axis, dtype=dtype)
     args_maker = lambda: [rng(shape, dtype)]
     self._CompileAndCheck(fun, args_maker)
     self._CheckAgainstNumpy(np_fun, fun, args_maker)
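This test both compiles the op and compares it against the flipped NumPy reference defined in np_fun above. A standalone equivalent of that comparison, hedged on the same assumptions and using a hypothetical 2-by-3 input:

import numpy as np
import jax
from jax import lax

x = np.arange(1., 7., dtype=np.float32).reshape(2, 3)
jax_out = jax.jit(lambda a: lax.cumprod(a, axis=1, reverse=True))(x)
np_ref = np.flip(np.cumprod(np.flip(x, 1), axis=1), 1)
np.testing.assert_allclose(jax_out, np_ref)   # [[6., 6., 3.], [120., 30., 6.]]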
@@ -541,11 +541,11 @@ class LaxVmapTest(jtu.JaxTestCase):
     self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
 
   @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_op={}_shape={}_axis={}_bdims={}"
+      {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
-               bdims),
+               bdims, reverse),
        "op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
-       "axis": axis, "rng_factory": rng_factory}
+       "axis": axis, "reverse": reverse}
       for op, types in [
           (lax.cumsum, [np.float32, np.float64]),
           (lax.cumprod, [np.float32, np.float64]),
@@ -554,13 +554,13 @@ class LaxVmapTest(jtu.JaxTestCase):
       for shape in [[10], [3, 4, 5]]
       for axis in range(len(shape))
       for bdims in all_bdims(shape)
-      for rng_factory in [
-          jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
-          else jtu.rand_small]))
-  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, rng_factory):
+      for reverse in [False, True]))
+  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
+    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
+                   else jtu.rand_small)
     rng = rng_factory(self.rng())
-    self._CheckBatching(partial(op, axis=axis), 7, bdims, (shape,), (dtype,),
-                        rng)
+    self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
+                        (shape,), (dtype,), rng)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,
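A minimal sketch of the batching behaviour exercised here, under the same assumptions: vmap maps the reversed scan over a leading batch dimension via the updated batch rule:

import jax
import jax.numpy as jnp
from jax import lax

batch = jnp.arange(12.).reshape(3, 4)   # three independent length-4 sequences
out = jax.vmap(lambda a: lax.cummax(a, axis=0, reverse=True))(batch)
# Each row is scanned independently, right to left.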