Mirror of https://github.com/ROCm/jax.git
Add backend-specific lowering for cumsum/cumprod on TPU. (#2614)

* Add backend-specific lowering for cumsum/cumprod on TPU. Make cumsum/cumprod primitives so they can have backend-specific lowerings.
* Disable cumulative reduction gradient test on TPU.
This commit is contained in:
parent f5f35c5c3b
commit 329321b0f1
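For orientation, a brief usage sketch of the primitives this commit adds (assuming only the public `jax.lax` cumsum/cumprod API introduced here):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 4.])
print(lax.cumsum(x, axis=0))   # [ 1.  3.  6. 10.]
print(lax.cumprod(x, axis=0))  # [ 1.  2.  6. 24.]

# Differentiation always routes through the parallel prefix-scan
# implementation, irrespective of backend (see the comment in
# _cumred_tpu_translation_rule in the lax.py hunk below).
print(jax.grad(lambda a: lax.cumsum(a, axis=0).sum())(x))  # [4. 3. 2. 1.]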
@@ -646,7 +646,13 @@ def add_jaxvals_translation_rule(c, x, y):
   return c.Add(x, y)
 translations[ad_util.add_jaxvals_p] = add_jaxvals_translation_rule
 
-def lower_fun(fun):
+@lu.transformation
+def _tuple_output(*args, **kwargs):
+  ans = yield args, kwargs
+  yield (ans,)
+
+
+def lower_fun(fun, multiple_results=True):
   # This function can only be used to lower functions that take JAX array types
   # as arguments (and e.g. don't accept unit values), because it assumes it can
   # map from XLA types to JAX types. In general that mapping is not possible (as
@@ -658,11 +664,18 @@ def lower_fun(fun):
     # TODO(mattjj): revise this 'calling convention'
     avals = [_array_aval_from_xla_shape(c.GetShape(x)) for x in xla_args]
     pvals = [pe.PartialVal((a, core.unit)) for a in avals]
-    jaxpr, _, consts = pe.trace_to_jaxpr(
-        lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True)
+    wrapped_fun = lu.wrap_init(fun, params)
+    if not multiple_results:
+      wrapped_fun = _tuple_output(wrapped_fun)
+    jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
+                                         stage_out=True)
     consts = _map(c.Constant, consts)
     outs = jaxpr_subcomp(c, jaxpr, None, AxisEnv(1), consts, '', *xla_args)
-    return c.Tuple(*outs)
+    if multiple_results:
+      return c.Tuple(*outs)
+    else:
+      assert len(outs) == 1, outs
+      return outs[0]
   return f
 
 def _array_aval_from_xla_shape(xla_shape):
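As a plain-Python illustration of the calling-convention change above (a sketch with hypothetical names, not JAX internals): lower_fun assumes the lowered function produces a tuple of outputs, and the new multiple_results=False path adapts a single-output function by tupling its result and unwrapping it again.

def lower_fun_sketch(fun, multiple_results=True):
  # Hypothetical stand-in for xla.lower_fun's output handling.
  def lowered(*args):
    outs = fun(*args) if multiple_results else (fun(*args),)
    if multiple_results:
      return tuple(outs)
    assert len(outs) == 1, outs
    return outs[0]
  return lowered

add_one = lower_fun_sketch(lambda x: x + 1, multiple_results=False)
assert add_one(41) == 42  # a single value, not a 1-tuple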
jax/lax/lax.py | 114
@@ -1016,6 +1016,14 @@ def _select_and_gather_add(tangents, operand, select_prim, window_dimensions,
       window_dimensions=tuple(window_dimensions),
       window_strides=tuple(window_strides), padding=padding)
 
+def cumsum(operand, axis: int):
+  """Computes a cumulative sum along `axis`."""
+  return cumsum_p.bind(operand, axis=int(axis))
+
+def cumprod(operand, axis: int):
+  """Computes a cumulative product along `axis`."""
+  return cumprod_p.bind(operand, axis=int(axis))
+
 def sort(operand, dimension=-1):
   """Wraps XLA's `Sort
   <https://www.tensorflow.org/xla/operation_semantics#sort>`_
@@ -4041,6 +4049,112 @@ xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
     max_bits=32)
 
 
+# Parallel prefix-scan. See:
+# https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
+# and
+# Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", Technical Report
+# CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.
+#
+# Unlike the Blelloch algorithm, we use an out-of-place algorithm that uses 2n
+# space. This is somewhat wasteful if we are interested only in the output of
+# the forward pass, but more memory-efficient if we intend to differentiate
+# through the implementation of the scan.
+def _prescan_power_of_two(x, axis: int, op: Callable, unit):
+  n = x.shape[axis]
+  assert n != 0 and n & (n - 1) == 0, "n must be a power of 2"
+
+  # Upsweep
+  xs = []
+  for d in range(0, n.bit_length() - 1):
+    x1 = slice_in_dim(x, 0, None, stride=2, axis=axis)
+    xs.append(x1)
+    x2 = slice_in_dim(x, 1, None, stride=2, axis=axis)
+    x = op(x1, x2)
+  total = x
+
+  # Downsweep
+  x = full_like(total, unit)
+  pad_left = [(0, 0, 0)] * len(x.shape)
+  pad_left[axis] = (1, 0, 1)
+  pad_right = [(0, 0, 0)] * len(x.shape)
+  pad_right[axis] = (0, 1, 1)
+  for w in reversed(xs):
+    x1 = pad(x, _const(x, 0), pad_right)
+    x2 = pad(x, _const(x, 0), pad_left)
+    w = pad(w, _const(x, 0), pad_left)
+    x = x1 + op(x2, w)
+
+  return x, total
+
+
+def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
+  n = x.shape[axis]
+  if n == 0:
+    return x
+  # Pads to the next largest power of two
+  nbits = n.bit_length()
+  if n == (1 << (nbits - 1)):
+    nbits -= 1
+  padding = [(0, 0, 0)] * len(x.shape)
+  padding[axis] = (0, (1 << nbits) - n, 0)
+  x = pad(x, _const(x, unit), padding)
+  x, total = _prescan_power_of_two(x, axis, op, unit)
+  return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
+
+def _cumred_shape_rule(x, axis):
+  if axis < 0 or axis >= x.ndim:
+    raise ValueError(
+        "axis {} is out of bounds for array of shape {}".format(axis, x.shape))
+  return x.shape
+
+
+def _cumred_jvp_rule(impl: Callable, primals, tangents, axis: int):
+  return api.jvp(partial(impl, axis=axis), primals, tangents)
+
+
+def _cumred_tpu_translation_rule(window_reduce: Callable, unit, x, axis: int):
+  # On TPU, an implementation using reduce_window is handled specially by the
+  # compiler. However, irrespective of backend, we always use the parallel
+  # prefix scan implementation when differentiating because reduce_window is not
+  # arbitrarily differentiable.
+  n = x.shape[axis]
+  padding = [(0, 0, 0)] * x.ndim
+  padding[axis] = (n - 1, 0, 0)
+  x = pad(x, _const(x, unit), padding)
+  strides = [1] * x.ndim
+  window_dims = [1] * x.ndim
+  window_dims[axis] = n
+  return window_reduce(x, window_dims, strides, xla_client.PaddingType.VALID)
+
+def _cumred_batch_rule(prim, batched_args, batch_dims, axis: int):
+  operand, = batched_args
+  bdim, = batch_dims
+  axis = axis if axis < bdim else axis + 1
+  return prim.bind(operand, axis=axis), bdim
+
+
+_cumsum_impl = partial(_parallel_prefix_scan, op=add, unit=0)
+
+cumsum_p = standard_primitive(
+  _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumsum"),
+  'cumsum', xla.lower_fun(_cumsum_impl, multiple_results=False))
+ad.primitive_jvps[cumsum_p] = partial(_cumred_jvp_rule, _cumsum_impl)
+xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
+    partial(_cumred_tpu_translation_rule, _reduce_window_sum, 0),
+    multiple_results=False)
+batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
+
+_cumprod_impl = partial(_parallel_prefix_scan, op=mul, unit=1)
+
+cumprod_p = standard_primitive(
+  _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumprod"),
+  'cumprod', xla.lower_fun(_cumprod_impl, multiple_results=False))
+ad.primitive_jvps[cumprod_p] = partial(_cumred_jvp_rule, _cumprod_impl)
+xla.backend_specific_translations['tpu'][cumprod_p] = xla.lower_fun(
+    partial(_cumred_tpu_translation_rule, _reduce_window_prod, 1),
+    multiple_results=False)
+batching.primitive_batchers[cumprod_p] = partial(_cumred_batch_rule, cumprod_p)
+
 sort_shape = lambda operand, dimension: operand.shape
 
 def _sort_jvp_rule(g, operand, dimension):
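To make _cumred_tpu_translation_rule concrete, here is a standalone sketch of the same padding trick using the public jax.lax.reduce_window (the internal helpers above, e.g. _reduce_window_sum, are assumed to wrap an equivalent call): left-padding with n - 1 identity elements and sliding a length-n window with VALID padding yields the cumulative reduction.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(1., 6.)              # [1. 2. 3. 4. 5.]
n = x.shape[0]
padded = jnp.pad(x, (n - 1, 0))     # n - 1 copies of the identity (0) on the left
cumsum_via_windows = lax.reduce_window(
    padded, 0., lax.add,
    window_dimensions=(n,), window_strides=(1,), padding='VALID')
print(cumsum_via_windows)           # [ 1.  3.  6. 10. 15.]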
@@ -1571,59 +1571,7 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=False):
   return td
 
 
-# Parallel prefix-scan. See:
-# https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
-# and
-# Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", Technical Report
-# CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.
-#
-# Unlike the Blelloch algorithm, we use an out-of-place algorithm that uses 2n
-# space. This is somewhat wasteful if we are interested only in the output of
-# the forward pass, but more memory-efficient if we intend to differentiate
-# through the implementation of the scan.
-def _prescan_power_of_two(x, axis: int, op: Callable, unit):
-  n = x.shape[axis]
-  assert n != 0 and n & (n - 1) == 0, "n must be a power of 2"
-
-  # Upsweep
-  xs = []
-  for d in range(0, n.bit_length() - 1):
-    x1 = lax.slice_in_dim(x, 0, None, stride=2, axis=axis)
-    xs.append(x1)
-    x2 = lax.slice_in_dim(x, 1, None, stride=2, axis=axis)
-    x = op(x1, x2)
-  total = x
-
-  # Downsweep
-  x = full_like(total, unit)
-  pad_left = [(0, 0, 0)] * len(x.shape)
-  pad_left[axis] = (1, 0, 1)
-  pad_right = [(0, 0, 0)] * len(x.shape)
-  pad_right[axis] = (0, 1, 1)
-  for w in reversed(xs):
-    x1 = lax.pad(x, x.dtype.type(0), pad_right)
-    x2 = lax.pad(x, x.dtype.type(0), pad_left)
-    w = lax.pad(w, x.dtype.type(0), pad_left)
-    x = x1 + op(x2, w)
-
-  return x, total
-
-def _parallel_prefix_scan(x, axis: int, op: Callable, unit):
-  n = x.shape[axis]
-
-  # Pads to the next largest power of two
-  nbits = n.bit_length()
-  if n == (1 << (nbits - 1)):
-    nbits -= 1
-  padding = [(0, 0, 0)] * len(x.shape)
-  padding[axis] = (0, (1 << nbits) - n, 0)
-  x = lax.pad(x, x.dtype.type(unit), padding)
-  x, product = _prescan_power_of_two(x, axis, op, unit)
-  return concatenate((lax.slice_in_dim(x, 1, n, axis=axis), product), axis=axis)
-
-
-def _make_cumulative_reduction(onp_reduction, op, unit,
-                               squash_nan=False):
+def _make_cumulative_reduction(onp_reduction, reduction, squash_nan=False):
   # We want to allow XLA to fuse the pad and reduce-window operators to
   # avoid materializing the padded output.
   # Consider removing `jit` once again if reduce-window is generalized to
@@ -1652,9 +1600,7 @@ def _make_cumulative_reduction(onp_reduction, op, unit,
     if dtype:
       a = lax.convert_element_type(a, dtype)
 
-    if a_shape[axis] == 0:
-      return a
-    return _parallel_prefix_scan(a, axis, op, unit)
+    return reduction(a, axis)
 
   @_wraps(onp_reduction)
   def cumulative_reduction(a, axis=None, dtype=None):
@@ -1663,15 +1609,13 @@ def _make_cumulative_reduction(onp_reduction, op, unit,
   return cumulative_reduction
 
 
-cumsum = _make_cumulative_reduction(
-  onp.cumsum, add, 0, squash_nan=False)
-cumprod = _make_cumulative_reduction(
-  onp.cumprod, multiply, 1, squash_nan=False)
+cumsum = _make_cumulative_reduction(onp.cumsum, lax.cumsum, squash_nan=False)
+cumprod = _make_cumulative_reduction(onp.cumprod, lax.cumprod, squash_nan=False)
 cumproduct = cumprod
-nancumsum = _make_cumulative_reduction(
-  onp.nancumsum, add, 0, squash_nan=True)
-nancumprod = _make_cumulative_reduction(
-  onp.nancumprod, multiply, 1, squash_nan=True)
+nancumsum = _make_cumulative_reduction(onp.nancumsum, lax.cumsum,
+                                       squash_nan=True)
+nancumprod = _make_cumulative_reduction(onp.nancumprod, lax.cumprod,
+                                        squash_nan=True)
 
 
 ### Array-creation functions
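A brief usage sketch of the jax.numpy wrappers after this change: cumsum/cumprod now defer to the new lax primitives, while the nan* variants first replace NaNs with the reduction's identity (matching NumPy's nancumsum/nancumprod).

import jax.numpy as jnp

x = jnp.array([[1., 2., jnp.nan],
               [4., 5., 6.]])
print(jnp.cumsum(x, axis=1))     # [[ 1.  3. nan]
                                 #  [ 4.  9. 15.]]
print(jnp.nancumsum(x, axis=1))  # [[ 1.  3.  3.]
                                 #  [ 4.  9. 15.]]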
@@ -1125,9 +1125,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True,
                             tol=tol)
     self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
-    grad_dtypes = [onp.float32, onp.float64, onp.complex64, onp.complex128]
-    if dtype in grad_dtypes and out_dtype in grad_dtypes:
-      check_grads(jnp_fun, args_maker(), order=2, rtol=1e-2)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
@@ -1269,6 +1269,29 @@ class LaxTest(jtu.JaxTestCase):
       self._CompileAndCheck(fun, args_maker, check_dtypes=True)
     # pylint: enable=cell-var-from-loop
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_op={}_shape={}_axis={}"
+       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
+       "op": op, "onp_op": onp_op, "shape": shape, "dtype": dtype,
+       "axis": axis, "rng_factory": rng_factory}
+      for op, onp_op, types in [
+          (lax.cumsum, onp.cumsum, default_dtypes),
+          (lax.cumprod, onp.cumprod, default_dtypes),
+      ]
+      for dtype in types
+      for shape in [[10], [3, 4, 5]]
+      for axis in range(len(shape))
+      for rng_factory in [
+          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
+          else jtu.rand_small]))
+  def testCumulativeReduce(self, op, onp_op, shape, dtype, axis, rng_factory):
+    rng = rng_factory()
+    fun = partial(op, axis=axis)
+    onp_fun = partial(onp_op, axis=axis)
+    args_maker = lambda: [rng(shape, dtype)]
+    self._CompileAndCheck(fun, args_maker, check_dtypes=True)
+    self._CheckAgainstNumpy(fun, onp_fun, args_maker)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_axis={}".format(
           jtu.format_shape_dtype_string(shape, dtype), axis),
@@ -2385,6 +2408,27 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                 eps)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_op={}_shape={}_axis={}"
+       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
+       "op": op, "shape": shape, "dtype": dtype,
+       "axis": axis, "rng_factory": rng_factory}
+      for op, types in [
+          (lax.cumsum, [onp.float32, onp.float64]),
+          (lax.cumprod, [onp.float32, onp.float64]),
+      ]
+      for dtype in types
+      for shape in [[10], [3, 4, 5]]
+      for axis in range(len(shape))
+      for rng_factory in [
+          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
+          else jtu.rand_small]))
+  @jtu.skip_on_devices("tpu")  # TODO(b/153183305): wrong outputs
+  def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
+    rng = rng_factory()
+    check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)
+
+
   # TODO(b/205052657): enable more tests when supported
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_axis={}".format(
@@ -3025,6 +3069,28 @@ class LaxVmapTest(jtu.JaxTestCase):
     for bdims in all_bdims(shape):
       self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_op={}_shape={}_axis={}_bdims={}"
+       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
+               bdims),
+       "op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
+       "axis": axis, "rng_factory": rng_factory}
+      for op, types in [
+          (lax.cumsum, [onp.float32, onp.float64]),
+          (lax.cumprod, [onp.float32, onp.float64]),
+      ]
+      for dtype in types
+      for shape in [[10], [3, 4, 5]]
+      for axis in range(len(shape))
+      for bdims in all_bdims(shape)
+      for rng_factory in [
+          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
+          else jtu.rand_small]))
+  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, rng_factory):
+    rng = rng_factory()
+    self._CheckBatching(partial(op, axis=axis), 7, bdims, (shape,), (dtype,),
+                        rng)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_dtype={}_padding={}".format(onp.dtype(dtype).name,
                                                       padding),
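The gradient test above is skipped on TPU for now; elsewhere the same check can be reproduced standalone with the public jax.test_util.check_grads helper (a sketch, not the test file itself):

from functools import partial

import jax.numpy as jnp
from jax import lax
from jax.test_util import check_grads

x = jnp.linspace(0.1, 1.0, 10)
check_grads(partial(lax.cumsum, axis=0), (x,), order=2)   # forward and reverse mode
check_grads(partial(lax.cumprod, axis=0), (x,), order=2)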