diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01549df98..dd8a2feb6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 * New features
   * New profiling APIs: {func}`jax.profiler.start_trace`,
     {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
+  * {func}`jax.lax.reduce` is now differentiable.
 * Breaking changes:
   * The minimum jaxlib version is now 0.1.64.
   * Some profiler APIs names have been changed. There are still aliases, so this
@@ -41,6 +42,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
     `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
     `jax.debug_infs`, `jax.log_compiles`.
   * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`
+
 * Bug fixes:
   * [#6136](https://github.com/google/jax/pull/6136) generalized
     `jax.flatten_util.ravel_pytree` to handle integer dtypes.
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 6badd94d1..e0664d214 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -4993,6 +4993,52 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
   out_nodes = xops.Tuple(subc, out_nodes)
   return subc.build(out_nodes)
 
+def _reduce_jvp(reducer, init_values, primals, tangents, axes):
+  input_shape = np.array(primals[0].shape)
+
+  n = np.prod(input_shape[list(axes)])
+  non_axes = np.delete(np.arange(len(input_shape)), axes)
+
+  # Move the reduced axes to the front, and flatten them to 1D.
+  permutation = axes + tuple(non_axes)
+  new_shape = (n,) + tuple(input_shape[non_axes])
+  primals = tuple(reshape(x, new_shape, permutation) for x in primals)
+  tangents = tuple(reshape(t, new_shape, permutation) for t in tangents)
+
+  for d in range(len(non_axes) + 1):
+    reducer = api.vmap(reducer)
+  def _reduce_tree(*xs, axis=0):
+    """Reduce by repeatedly splitting the array and multiplying."""
+    while xs[0].shape[axis] > 1:
+      n = xs[0].shape[axis]
+      n1 = (n + 1) // 2
+      n2 = n - n1
+      xs1 = [slice_in_dim(x, 0, n1) for x in xs]
+      xs2 = [slice_in_dim(x, n1, None) for x in xs]
+      if n2 != n1:
+        paddings = [(0, 0, 0)] * len(xs[0].shape)
+        paddings[axis] = (0, 1, 0)
+        xs2 = [pad(x2, i, paddings) for x2, i in zip(xs2, init_values)]
+      xs = reducer(*(xs1 + xs2))
+    if xs[0].shape[axis] == 0:
+      return [full(input_shape[non_axes], i) for i in init_values]
+    return tuple(squeeze(x, (axis,)) for x in xs)
+
+  return api.jvp(_reduce_tree, primals, tangents)
+
+def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr,
+                     consts, dimensions):
+  primal_xs, init_values = split_list(primals, [len(primals) // 2])
+  tangent_xs, tangent_init = split_list(tangents, [len(tangents) // 2])
+  # This test may be too strict, if a value is actually zero but we cannot prove
+  # it is symbolically zero.
+  if any(type(t) is not ad_util.Zero for t in tangent_init):
+    raise NotImplementedError(
+      "Gradient of general lax.reduce with non-zero tangents for "
+      "initial values to reduction not implemented")
+  reducer = core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts))
+  return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)
+
 def _masking_defreducer(prim, identity):
   masking.masking_rules[prim] = partial(_reducer_masking_rule, prim, identity)
 
@@ -5030,7 +5076,7 @@ reduce_p.def_abstract_eval(
             _reduce_named_shape_rule))
 xla.translations[reduce_p] = _reduce_translation_rule
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
-
+ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
 
 def _reduce_number_dtype_rule(name, operand, *args, **kw):
   if not dtypes.issubdtype(operand.dtype, np.number):
@@ -5081,38 +5127,10 @@ def _reduce_prod_translation_rule(c, operand, *, axes):
                      xla.primitive_subcomputation(mul_p, scalar, scalar), axes)
 
 def _reduce_prod_jvp_rule(primals, tangents, *, axes):
-  operand, = primals
-  tangent, = tangents
-  input_shape = np.array(operand.shape)
-
-  n = np.prod(input_shape[list(axes)])
-  non_axes = np.delete(np.arange(len(input_shape)), axes)
-
-  # Move the reduced axes to the front, and flatten them to 1D.
-  permutation = axes + tuple(non_axes)
-  new_shape = (n,) + tuple(input_shape[non_axes])
-  operand = reshape(operand, new_shape, permutation)
-  tangent = reshape(tangent, new_shape, permutation)
-
-  def _reduce_prod_tree(x, axis=0):
-    """Reduce by repeatedly splitting the array and multiplying."""
-    while x.shape[axis] > 1:
-      n = x.shape[axis]
-      n1 = (n + 1) // 2
-      n2 = n - n1
-      x1 = slice_in_dim(x, 0, n1)
-      x2 = slice_in_dim(x, n1, None)
-      if n2 != n1:
-        paddings = [(0, 0, 0)] * len(x.shape)
-        paddings[axis] = (0, 1, 0)
-        x2 = pad(x2, _const(x, 1), paddings)
-      x = x1 * x2
-    if x.shape[axis] == 0:
-      return full(input_shape[non_axes], _one(x))
-    return squeeze(x, (axis,))
-
-  return api.jvp(_reduce_prod_tree, (operand,), (tangent,))
-
+  reducer = lambda x, y: [mul(x, y)]
+  primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
+                                          primals, tangents, axes)
+  return primals_out[0], tangents_out[0]
 
 reduce_prod_p = standard_primitive(
   _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py
index 92d496b12..30408ab49 100644
--- a/tests/lax_autodiff_test.py
+++ b/tests/lax_autodiff_test.py
@@ -654,6 +654,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
       check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_inshape={}_reducedims={}"
+       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
+       "shape": shape, "dtype": dtype, "dims": dims}
+      for dtype in grad_float_dtypes
+      for shape, dims in [
+          [(3, 4, 5), ()],
+          [(3, 4, 5), (0,)],
+          [(3, 4, 5), (1, 2)],
+          [(3, 4, 5), (0, 2)],
+          [(3, 4, 5), (0, 1, 2)],
+          [(3, 1), (1,)],
+          [(3, 0, 5), (1,)],
+      ]))
+  def testReducePairGrad(self, shape, dtype, dims):
+    rng = jtu.rand_default(self.rng(), scale=1)
+    tol = {np.float32: 1e-2, np.float64: 1e-4}
+    operands = (rng(shape, dtype), rng(shape, dtype))
+    init_vals = (np.array(0, dtype), np.array(1, dtype))
+    def op(xs, ys):
+      return (xs[0] + ys[0], xs[1] * ys[1])
+    reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
+    check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                          "_basedilation={}_windowdilation={}")