Add a JVP rule for the general case of lax.reduce.

Peter Hawkins 2021-03-23 10:31:02 -04:00
parent 634397dc59
commit 3fc1fdb148
3 changed files with 77 additions and 33 deletions


@@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:
  * New profiling APIs: {func}`jax.profiler.start_trace`,
    {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
  * {func}`jax.lax.reduce` is now differentiable (a usage sketch follows this
    changelog excerpt).
* Breaking changes:
  * The minimum jaxlib version is now 0.1.64.
  * Some profiler API names have been changed. There are still aliases, so this
@@ -41,6 +42,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
    `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
    `jax.debug_infs`, `jax.log_compiles`.
  * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`
* Bug fixes:
  * [#6136](https://github.com/google/jax/pull/6136) generalized
    `jax.flatten_util.ravel_pytree` to handle integer dtypes.
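
As a usage illustration (an editor's sketch, not part of the commit), the new JVP rule lets a general, variadic `lax.reduce` be differentiated in both forward and reverse mode. The pair reduction below mirrors the `testReducePairGrad` test added further down; the name `sum_and_prod` is illustrative only.

    import jax
    import jax.numpy as jnp
    from jax import lax

    def sum_and_prod(x):
      # One pass over axis 0 that carries a (sum, product) pair through a
      # user-supplied reduction computation.
      init = (jnp.zeros((), x.dtype), jnp.ones((), x.dtype))
      s, p = lax.reduce((x, x), init,
                        lambda accs, xs: (accs[0] + xs[0], accs[1] * xs[1]),
                        (0,))
      return s * p

    x = jnp.arange(1.0, 5.0)
    print(jax.grad(sum_and_prod)(x))                          # reverse mode
    print(jax.jvp(sum_and_prod, (x,), (jnp.ones_like(x),)))   # forward mode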


@@ -4993,6 +4993,52 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
  out_nodes = xops.Tuple(subc, out_nodes)
  return subc.build(out_nodes)

def _reduce_jvp(reducer, init_values, primals, tangents, axes):
  input_shape = np.array(primals[0].shape)

  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  primals = tuple(reshape(x, new_shape, permutation) for x in primals)
  tangents = tuple(reshape(t, new_shape, permutation) for t in tangents)

  for d in range(len(non_axes) + 1):
    reducer = api.vmap(reducer)
  def _reduce_tree(*xs, axis=0):
    """Reduce by repeatedly splitting the array and applying the reducer."""
    while xs[0].shape[axis] > 1:
      n = xs[0].shape[axis]
      n1 = (n + 1) // 2
      n2 = n - n1
      xs1 = [slice_in_dim(x, 0, n1) for x in xs]
      xs2 = [slice_in_dim(x, n1, None) for x in xs]
      if n2 != n1:
        paddings = [(0, 0, 0)] * len(xs[0].shape)
        paddings[axis] = (0, 1, 0)
        xs2 = [pad(x2, i, paddings) for x2, i in zip(xs2, init_values)]
      xs = reducer(*(xs1 + xs2))
    if xs[0].shape[axis] == 0:
      return [full(input_shape[non_axes], i) for i in init_values]
    return tuple(squeeze(x, (axis,)) for x in xs)

  return api.jvp(_reduce_tree, primals, tangents)
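
# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of this commit: the helper above
# re-expresses the reduction as O(log n) pairwise applications of the (vmapped)
# reducer, so api.jvp can differentiate it like any other traced computation.
# The stripped-down, single-operand version below shows the same idea with a
# plain binary reducer; `tree_reduce_1d` is a made-up name, not a JAX API.
import jax
import jax.numpy as jnp

def tree_reduce_1d(reducer, identity, x):
  # Repeatedly halve the array and combine the halves elementwise, padding the
  # shorter half with the reduction's identity when the length is odd.
  while x.shape[0] > 1:
    n1 = (x.shape[0] + 1) // 2
    x1, x2 = x[:n1], x[n1:]
    if x2.shape[0] != n1:
      x2 = jnp.concatenate([x2, jnp.full((1,), identity, x.dtype)])
    x = reducer(x1, x2)
  return jnp.full((), identity, x.dtype) if x.shape[0] == 0 else x[0]

f = lambda x: tree_reduce_1d(jnp.multiply, 1, x)
x = jnp.arange(1.0, 6.0)
print(f(x), jax.grad(f)(x))  # 120.0 and, per entry, the product of the others
# End of editor's sketch.
# ---------------------------------------------------------------------------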

def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr,
                     consts, dimensions):
  primal_xs, init_values = split_list(primals, [len(primals) // 2])
  tangent_xs, tangent_init = split_list(tangents, [len(tangents) // 2])
  # This test may be too strict: a tangent may be numerically zero even if we
  # cannot prove it is symbolically zero.
  if any(type(t) is not ad_util.Zero for t in tangent_init):
    raise NotImplementedError(
      "Gradient of general lax.reduce with non-zero tangents for "
      "initial values to reduction not implemented")
  reducer = core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts))
  return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)
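
# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of this commit: the Zero check above
# means only the reduced operands may carry tangents, so differentiating with
# respect to an init value should raise the NotImplementedError. Using lax.add
# with a non-zero init value keeps this on the general reduce_p path rather
# than the reduce_sum special case (an assumption of this sketch).
import jax
import jax.numpy as jnp
from jax import lax

def reduce_with_init(x, init):
  return lax.reduce(x, init, lax.add, (0,))

x = jnp.array([1., 2., 3.])
init = jnp.ones((), x.dtype)

print(jax.grad(reduce_with_init)(x, init))         # w.r.t. the operand: supported
try:
  jax.grad(reduce_with_init, argnums=1)(x, init)   # w.r.t. the init value: rejected
except NotImplementedError as e:
  print("expected:", e)
# End of editor's sketch.
# ---------------------------------------------------------------------------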

def _masking_defreducer(prim, identity):
  masking.masking_rules[prim] = partial(_reducer_masking_rule, prim, identity)
@@ -5030,7 +5076,7 @@ reduce_p.def_abstract_eval(
            _reduce_named_shape_rule))
xla.translations[reduce_p] = _reduce_translation_rule
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
ad.primitive_jvps[reduce_p] = _reduce_jvp_rule

def _reduce_number_dtype_rule(name, operand, *args, **kw):
  if not dtypes.issubdtype(operand.dtype, np.number):
@@ -5081,38 +5127,10 @@ def _reduce_prod_translation_rule(c, operand, *, axes):
                     xla.primitive_subcomputation(mul_p, scalar, scalar), axes)

def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  operand, = primals
  tangent, = tangents
  input_shape = np.array(operand.shape)

  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  operand = reshape(operand, new_shape, permutation)
  tangent = reshape(tangent, new_shape, permutation)

  def _reduce_prod_tree(x, axis=0):
    """Reduce by repeatedly splitting the array and multiplying."""
    while x.shape[axis] > 1:
      n = x.shape[axis]
      n1 = (n + 1) // 2
      n2 = n - n1
      x1 = slice_in_dim(x, 0, n1)
      x2 = slice_in_dim(x, n1, None)
      if n2 != n1:
        paddings = [(0, 0, 0)] * len(x.shape)
        paddings[axis] = (0, 1, 0)
        x2 = pad(x2, _const(x, 1), paddings)
      x = x1 * x2
    if x.shape[axis] == 0:
      return full(input_shape[non_axes], _one(x))
    return squeeze(x, (axis,))

  return api.jvp(_reduce_prod_tree, (operand,), (tangent,))
  reducer = lambda x, y: [mul(x, y)]
  primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
                                          primals, tangents, axes)
  return primals_out[0], tangents_out[0]
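
# ---------------------------------------------------------------------------
# Editor's illustrative sketch, not part of this commit: the rewritten rule
# above reuses _reduce_jvp with a multiply reducer and identity 1. As with the
# old inline version, the pairwise tree never divides by an input, so products
# that contain zeros still get exact gradients. (That jnp.prod reaches
# reduce_prod here is an assumption of this sketch.)
import jax
import jax.numpy as jnp

x = jnp.array([2., 0., 3.])
print(jax.grad(jnp.prod)(x))  # [0., 6., 0.]: each entry is the product of the others
# End of editor's sketch.
# ---------------------------------------------------------------------------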

reduce_prod_p = standard_primitive(
  _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),


@@ -654,6 +654,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):
    if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
      check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_reducedims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
       "shape": shape, "dtype": dtype, "dims": dims}
      for dtype in grad_float_dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
          [(3, 0, 5), (1,)],
      ]))
  def testReducePairGrad(self, shape, dtype, dims):
    rng = jtu.rand_default(self.rng(), scale=1)
    tol = {np.float32: 1e-2, np.float64: 1e-4}
    operands = (rng(shape, dtype), rng(shape, dtype))
    init_vals = (np.array(0, dtype), np.array(1, dtype))
    def op(xs, ys):
      return (xs[0] + ys[0], xs[1] * ys[1])
    reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
    check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}")