diff --git a/CHANGELOG.md b/CHANGELOG.md index 93a063a87..8edd7a8b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main). * New features: * New SciPy function {py:func}`jax.scipy.special.sph_harm`. + * Reverse-mode autodiff functions ({func}`jax.grad`, + {func}`jax.value_and_grad`, {func}`jax.vjp`, and + {func}`jax.linear_transpose`) support a parameter that indicates which named + axes should be summed over in the backward pass if they were broadcasted + over in the forward pass. This enables use of these APIs in a + non-per-example way inside maps (initially only + {func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`). ## jaxlib 0.1.69 (unreleased) diff --git a/jax/_src/api.py b/jax/_src/api.py index 430f4ed9c..94561f6d2 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -578,9 +578,9 @@ def xla_computation(fun: Callable, return_shape: Optional boolean, defaults to ``False``. If ``True``, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as - the output of ``fun`` and where the leaves are objects with ``shape`` and - ``dtype`` attributes representing the corresponding types of the output - leaves. + the output of ``fun`` and where the leaves are objects with ``shape``, + ``dtype``, and ``named_shape`` attributes representing the corresponding + types of the output leaves. donate_argnums: Specify which arguments are "donated" to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated @@ -597,7 +597,7 @@ def xla_computation(fun: Callable, ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the wrapped function returns a pair where the first element is the XLA Computation and the second element is a pytree representing the structure, - shapes, and dtypes of the output of ``fun``. + shapes, dtypes, and named shapes of the output of ``fun``. Concrete example arguments are not always necessary. For those arguments not indicated by ``static_argnums``, any object with ``shape`` and ``dtype`` @@ -750,7 +750,8 @@ def xla_computation(fun: Callable, shapes = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d] warn(f"Some donated buffers were not usable: {', '.join(shapes)}") built = c.build(out_tuple) - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] + out_shapes_flat = [ + ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] out_shape = tree_unflatten(out_tree(), out_shapes_flat) for out_aval in out_avals: if not isinstance(out_aval, xla.ShapedArray): @@ -767,7 +768,8 @@ def xla_computation(fun: Callable, def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, - allow_int: bool = False) -> Callable: + allow_int: bool = False, + reduce_axes: Sequence[AxisName] = ()) -> Callable: """Creates a function which evaluates the gradient of ``fun``. Args: @@ -787,6 +789,13 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, allow_int: Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fun`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + gradient will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a + function that computes the total gradient while ``grad(f)`` will create + one that computes the per-example gradient. Returns: A function with the same arguments as ``fun``, that evaluates the gradient @@ -806,7 +815,8 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, """ value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux, holomorphic=holomorphic, - allow_int=allow_int) + allow_int=allow_int, + reduce_axes=reduce_axes) docstr = ("Gradient of {fun} with respect to positional argument(s) " "{argnums}. Takes the same arguments as {fun} but returns the " @@ -829,7 +839,8 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, has_aux: bool = False, holomorphic: bool = False, - allow_int: bool = False) -> Callable[..., Tuple[Any, Any]]: + allow_int: bool = False, reduce_axes: Sequence[AxisName] = () +) -> Callable[..., Tuple[Any, Any]]: """Create a function which evaluates both ``fun`` and the gradient of ``fun``. Args: @@ -847,6 +858,14 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, allow_int: Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fun`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + gradient will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``value_and_grad(f, reduce_axes=('batch',))`` will + create a function that computes the total gradient while + ``value_and_grad(f)`` will create one that computes the per-example + gradient. Returns: A function with the same arguments as ``fun`` that evaluates both ``fun`` @@ -879,9 +898,10 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, f_partial, dyn_args = argnums_partial(f, argnums, args) tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args) if not has_aux: - ans, vjp_py = _vjp(f_partial, *dyn_args) + ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes) else: - ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True) + ans, vjp_py, aux = _vjp( + f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes) _check_scalar(ans) dtype = dtypes.result_type(ans) tree_map(partial(_check_output_dtype_grad, holomorphic), ans) @@ -1891,12 +1911,14 @@ if sys.version_info >= (3, 8): @overload # type: ignore def vjp(fun: Callable[..., T], *primals: Any, - has_aux: Literal[False] = False) -> Tuple[T, Callable]: + has_aux: Literal[False] = False, + reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]: ... @overload def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any, - has_aux: Literal[True]) -> Tuple[T, Callable, U]: + has_aux: Literal[True], + reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]: ... else: @@ -1907,12 +1929,14 @@ else: @overload def vjp( fun: Callable[..., Any], *primals: Any, - has_aux: bool) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]: + has_aux: bool, + reduce_axes: Sequence[AxisName] = () + ) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]: ... def vjp( # type: ignore - fun: Callable, *primals, has_aux: bool = False, + fun: Callable, *primals, has_aux: bool = False, reduce_axes=() ) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]: """Compute a (reverse-mode) vector-Jacobian product of ``fun``. @@ -1929,6 +1953,13 @@ def vjp( # type: ignore has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fun`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + VJP will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will + create a VJP function that sums over the batch while ``vjp(f, *args)`` + will create a per-example VJP. Returns: If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where @@ -1953,19 +1984,22 @@ def vjp( # type: ignore -0.2524413 """ _check_callable(fun) - return _vjp(lu.wrap_init(fun), *primals, has_aux=has_aux) + return _vjp( + lu.wrap_init(fun), *primals, has_aux=has_aux, reduce_axes=reduce_axes) -def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): +def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()): """Variant of vjp() that takes an lu.WrappedFun.""" primals_flat, in_tree = tree_flatten(primals) for arg in primals_flat: _check_arg(arg) if not has_aux: flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primal, out_vjp = ad.vjp(flat_fun, primals_flat) + out_primal, out_vjp = ad.vjp( + flat_fun, primals_flat, reduce_axes=reduce_axes) out_tree = out_tree() else: flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree) - out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True) + out_primal, out_vjp, aux = ad.vjp( + flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes) out_tree, aux_tree = out_aux_trees() out_primal_py = tree_unflatten(out_tree, out_primal) ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal] @@ -1981,7 +2015,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) -def linear_transpose(fun: Callable, *primals) -> Callable: +def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: """Transpose a function that is promised to be linear. For linear functions, this transformation is equivalent to ``vjp``, but @@ -2002,6 +2036,14 @@ def linear_transpose(fun: Callable, *primals) -> Callable: is not required: only the ``shape`` and ``dtype`` attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.) + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fun`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding cotangent. Otherwise, the + transposed function will be per-example over named axes. For example, if + ``'batch'`` is a named batch axis, ``linear_transpose(f, *args, + reduce_axes=('batch',))`` will create a transpose function that sums over + the batch while ``linear_transpose(f, args)`` will create a per-example + transpose. Returns: A callable that calculates the transpose of ``fun``. Valid input into this @@ -2046,7 +2088,7 @@ def linear_transpose(fun: Callable, *primals) -> Callable: dummies = [ad.UndefinedPrimal(a) for a in in_avals] in_cotangents = map( ad.instantiate_zeros, - ad.backward_pass(jaxpr, consts, dummies, out_cotangents)) + ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents)) return tree_unflatten(in_tree, in_cotangents) return transposed_fun @@ -2071,11 +2113,11 @@ def make_jaxpr(fun: Callable, specifies the axis name/size environment that would be set up by applications of :py:func:`jax.pmap`. return_shape: Optional boolean, defaults to ``False``. If ``True``, the - wrapped function returns a pair where the first element is the ``jaxpr`` - and the second element is a pytree with the same structure as - the output of ``fun`` and where the leaves are objects with ``shape`` and - ``dtype`` attributes representing the corresponding types of the output - leaves. + wrapped function returns a pair where the first element is the XLA + computation and the second element is a pytree with the same structure as + the output of ``fun`` and where the leaves are objects with ``shape``, + ``dtype``, and ``named_shape`` attributes representing the corresponding + types of the output leaves. Returns: A wrapped version of ``fun`` that when applied to example arguments returns @@ -2083,7 +2125,7 @@ def make_jaxpr(fun: Callable, argument ``return_shape`` is ``True``, then the returned function instead returns a pair where the first element is the ``ClosedJaxpr`` representation of ``fun`` and the second element is a pytree representing - the structure, shape, and dtypes of the output of ``fun``. + the structure, shape, dtypes, and named shapes of the output of ``fun``. A ``jaxpr`` is JAX's intermediate representation for program traces. The ``jaxpr`` language is based on the simply-typed first-order lambda calculus @@ -2135,7 +2177,8 @@ def make_jaxpr(fun: Callable, jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) if return_shape: - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] + out_shapes_flat = [ + ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat) return closed_jaxpr diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 20d330571..2e0c30d14 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -369,11 +369,12 @@ xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \ # If a (multi)linear function is defined with a custom jvp, then # custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's # already been linearized, we can drop the jvp rule. -def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk, - num_consts): +def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr, + jvp_jaxpr_thunk, num_consts): del jvp_jaxpr_thunk, num_consts - return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts) -ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose + return ad.backward_pass( + fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts) +ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose ### VJPs diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 7f76736a8..878ad76df 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -1020,7 +1020,7 @@ def _ordered_unique(xs): d = collections.OrderedDict((x, None) for x in xs) return list(d.keys()) -def _transpose_cond_jaxpr(jaxpr, num_res): +def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes): res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) primal_avals = _map(raise_to_shaped, primal_avals) @@ -1029,19 +1029,19 @@ def _transpose_cond_jaxpr(jaxpr, num_res): res, cts_out = split_list(args, [num_res]) primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] cts_in = ad.backward_pass( - jaxpr.jaxpr, jaxpr.consts, primals, cts_out) + jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out) _, cts_in = split_list(cts_in, [num_res]) return _map(ad.instantiate_zeros_aval, primal_avals, cts_in) return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals) -def _cond_transpose(cts, *args, branches, linear): +def _cond_transpose(reduce_axes, cts, *args, branches, linear): index, *ops = args in_avals = _map(raise_to_shaped, branches[0].in_avals) num_res = len(ops) - sum(linear) branches_trans = tuple( - _transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches) + _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches) lin_in_avals = [raise_to_shaped(a, weak_type=False) for a, l in zip(in_avals, linear) if l] assert all(core.typematch(out_aval, lin_in_aval) @@ -1131,7 +1131,7 @@ cond_p.def_impl(partial(xla.apply_primitive, cond_p)) cond_p.def_abstract_eval(_cond_abstract_eval) cond_p.def_custom_bind(cond_bind) ad.primitive_jvps[cond_p] = _cond_jvp -ad.primitive_transposes[cond_p] = _cond_transpose +ad.reducing_transposes[cond_p] = _cond_transpose pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval batching.initial_style_batchers[cond_p] = _cond_batching_rule xla.initial_style_translations[cond_p] = _cond_translation_rule @@ -1673,8 +1673,8 @@ def _maybe_device_put(x): else: return x -def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, - linear, unroll): +def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts, + num_carry, jaxpr, linear, unroll): # we've only implemented transposing scans with specific lin/nonlin patterns consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) num_ires = len(consts_lin) - sum(consts_lin) @@ -1702,7 +1702,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) jaxpr_trans = _transpose_scan_jaxpr( - num_ires, num_consts - num_ires, num_eres, jaxpr) + num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes) linear_trans = ([False] * num_ires + [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + [False] * num_eres) @@ -1717,8 +1717,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr, # transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) # -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) -def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr): +def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes): num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 + # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals + # if an axis isn't reduced res1_avals, c_avals, a_avals, res2_avals = split_list( jaxpr.in_avals, [num_res1, num_c, num_a]) num_b = len(jaxpr.out_avals) @@ -1730,7 +1732,8 @@ def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr): res1_cbar_bbar_res2, [num_res1, num_c, num_b]) primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) - cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar) + cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts, + primals, b_bar) _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) c_bar = _map(ad.instantiate_zeros_aval, c_avals, @@ -1879,7 +1882,7 @@ scan_p.def_custom_bind(scan_bind) scan_p.def_impl(partial(xla.apply_primitive, scan_p)) scan_p.def_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp -ad.primitive_transposes[scan_p] = _scan_transpose +ad.reducing_transposes[scan_p] = _scan_transpose pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl) batching.initial_style_batchers[scan_p] = _scan_batching_rule diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 31f7359fb..587f9421a 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -17,13 +17,14 @@ import functools import itertools as it from typing import Any, Callable, Dict +import jax from . import partial_eval as pe from ..config import config from .. import core from .._src.dtypes import dtype, float0 from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal, raise_to_shaped) -from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, +from .._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval, zeros_like_p, Zero) from .._src.util import (unzip2, safe_map, safe_zip, partial, split_list, wrap_name, as_hashable_function) @@ -31,7 +32,7 @@ from ..tree_util import register_pytree_node from .. import linear_util as lu from ..api_util import flatten_fun, flatten_fun_nokwargs from ..tree_util import tree_flatten, tree_unflatten, Partial -from jax._src import source_info_util +from .._src import source_info_util zip = safe_zip map = safe_map @@ -109,7 +110,7 @@ def linearize(traceable, *primals, **kwargs): else: return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux() -def vjp(traceable, primals, has_aux=False): +def vjp(traceable, primals, has_aux=False, reduce_axes=()): if not has_aux: out_primals, pvals, jaxpr, consts = linearize(traceable, *primals) else: @@ -118,7 +119,7 @@ def vjp(traceable, primals, has_aux=False): def unbound_vjp(pvals, jaxpr, consts, *cts): cts = tuple(map(ignore_consts, cts, pvals)) dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars] - arg_cts = backward_pass(jaxpr, consts, dummy_args, cts) + arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts) return map(instantiate_zeros, arg_cts) # Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward @@ -160,7 +161,7 @@ def recast_to_float0(primal, tangent): return tangent # NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will) -def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): +def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents_in): if all(type(ct) is Zero for ct in cotangents_in): return map(lambda v: Zero(v.aval), jaxpr.invars) @@ -173,6 +174,11 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): # FIXME: This triggers a lot of failures! # assert v.aval == ct.aval, (prim, v.aval, ct.aval) return + axes_to_reduce = tuple(axis_name for axis_name in reduce_axes + if axis_name in core.get_aval(ct).named_shape + and axis_name not in v.aval.named_shape) + if axes_to_reduce: + ct = jax.lax.psum(ct, axis_name=axes_to_reduce) ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct if config.jax_enable_checks: ct_aval = core.get_aval(ct_env[v]) @@ -213,7 +219,10 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): cts_in_avals = [v.aval for v in eqn.outvars] call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params) cts_out = get_primitive_transpose(eqn.primitive)( - params, call_jaxpr, invals, cts_in, cts_in_avals) + params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes) + elif eqn.primitive in reducing_transposes: + cts_out = reducing_transposes[eqn.primitive]( + reduce_axes, cts_in, *invals, **eqn.params) else: cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params) @@ -413,6 +422,8 @@ call_transpose_param_updaters: Dict[core.Primitive, Callable] = {} primitive_jvps : Dict[core.Primitive, Callable] = {} primitive_transposes: Dict[core.Primitive, Callable] = {} +# transpose rules that internally perform reductions over the given named axes +reducing_transposes: Dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): @@ -530,9 +541,9 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents): yield out_flat, tree_def -def call_transpose(primitive, params, call_jaxpr, args, ct, _): +def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts - fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr) + fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) new_params = dict(params, name=wrap_name(params['name'], 'transpose')) update_params = call_transpose_param_updaters.get(primitive) @@ -544,7 +555,8 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _): primitive_transposes[core.call_p] = partial(call_transpose, call_p) -def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals): +def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, + cotangent_in_avals, reduce_axes): # backward_pass can only transpose linear computations, but the call_jaxpr embedded in # remat contains primal (non-linear) equations too. Hence, we have to eliminate those # (in this case via partial_eval) before we call into backward_pass again. @@ -558,7 +570,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_ # should all work out, because we're only computing the primal part here. residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):] # Now that we have a purely linear jaxpr, we can transpose it - cotangents_out = backward_pass(tangent_jaxpr.jaxpr, (), primals_in + residuals, cotangents_in) + cotangents_out = backward_pass( + tangent_jaxpr.jaxpr, reduce_axes, (), primals_in + residuals, cotangents_in) # backward_pass will return cotangents computed for all invars, but some of them # are residuals appended by partial eval, so we need to skip those before we return. return cotangents_out[:len(primals_in)] @@ -575,9 +588,9 @@ def nonzero_outputs(*args, **kwargs): yield results, [type(r) is not Zero for r in results] -def map_transpose(primitive, params, call_jaxpr, args, ct, _): +def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts - fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr) + fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes) fun, nz_arg_cts = nonzero_outputs(fun) fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) # Preserve axis for primal arguments, skip tangents (represented as undefined primals). diff --git a/tests/api_test.py b/tests/api_test.py index 7604b27b1..4c4228013 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3324,12 +3324,43 @@ class JaxprTest(jtu.JaxTestCase): def f(x): return x - lax.psum(x, 'i') - x = types.SimpleNamespace( + x = api.ShapeDtypeStruct( shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10}) jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x) named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars] self.assertEqual(named_shapes, [{'i': 10}, {}]) + @parameterized.parameters(True, False) + def test_vjp_reduce_axes_jaxpr(self, gy_batched): + def f(w, x): + return jnp.sin(jnp.dot(x, w)) + + w = api.ShapeDtypeStruct( + shape=(3, 4), dtype=jnp.float32, named_shape={}) + x = api.ShapeDtypeStruct( + shape=(3,), dtype=jnp.float32, named_shape={'batch': 2}) + gy = api.ShapeDtypeStruct( + shape=(4,), dtype=jnp.float32, + named_shape={'batch': 2} if gy_batched else {}) + + # per-example + jaxpr, shapes = api.make_jaxpr( + lambda w, x, gy: api.vjp(f, w, x)[1](gy), axis_env=[('batch', 2)], + return_shape=True)(w, x, gy) + expected = (api.ShapeDtypeStruct( + shape=(3, 4), dtype=jnp.float32, named_shape={'batch': 2}), x) + self.assertEqual(shapes, expected) + self.assertNotIn('psum', str(jaxpr)) + + # reduced + jaxpr, shapes = api.make_jaxpr( + lambda w, x, gy: api.vjp(f, w, x, reduce_axes=('batch',))[1](gy), + axis_env=[('batch', 2)], + return_shape=True)(w, x, gy) + expected = (w, x) + self.assertEqual(shapes, expected) + self.assertIn('psum', str(jaxpr)) + class CustomJVPTest(jtu.JaxTestCase): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 4c2a75c2f..8bb073ae6 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -22,7 +22,7 @@ import re import unittest from itertools import product, permutations from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set, - Any, Hashable, Iterable, Iterator, Union) + Any, Hashable, Iterable, Iterator, Union, Optional) from unittest import SkipTest, skip, skipIf import numpy as np @@ -1251,5 +1251,85 @@ class XMapErrorTest(jtu.JaxTestCase): fm(x) +class NamedAutodiffTests(jtu.JaxTestCase): + + def testVjpReduceAxes(self): + def f(w, x): + return jnp.sin(jnp.dot(x, w)) + + def vjp_f(w, x, gy): + _, pullback = jax.vjp(f, w, x) + return pullback(gy) + + def vjp_f_reduced(w, x, gy): + _, pullback = jax.vjp(f, w, x, reduce_axes=('batch',)) + return pullback(gy) + + w = np.arange(12, dtype=np.float32).reshape(3, 4) + x = np.arange(6, dtype=np.float32).reshape(2, 3) + gy = np.arange(8, dtype=np.float32).reshape(2, 4) + + # per-example + error = (r"One of xmap results has an out_axes specification of {}, but is " + r"actually mapped along more axes defined by this xmap call: " + r"batch") + with self.assertRaisesRegex(TypeError, error): + xmap(vjp_f, + in_axes=({}, {0: 'batch'}, {0: 'batch'}), + out_axes=({}, {0: 'batch'}))(w, x, gy) + out = xmap(vjp_f, + in_axes=({}, {0: 'batch'}, {0: 'batch'}), + out_axes=({0: 'batch'}, {0: 'batch'}))(w, x, gy) + expected = vmap(vjp_f, in_axes=(None, 0, 0), out_axes=(0, 0))(w, x, gy) + self.assertAllClose(out, expected, check_dtypes=True) + + # reduced + out = xmap(vjp_f_reduced, + in_axes=({}, {0: 'batch'}, {0: 'batch'}), + out_axes=({}, {0: 'batch'}))(w, x, gy) + # the reduced VJP is also the VJP when using a positional batch axis + expected = vjp_f(w, x, gy) + self.assertAllClose(out, expected, check_dtypes=True) + + def testVjpReduceAxesCollective(self): + + # lax.psum has the wrong transpose, so test with a corrected version for now + @functools.partial(jax.custom_vjp, nondiff_argnums=(1,)) + def psum_idrev(x, axis_name: Optional[AxisNames] = None): + if axis_name is None: + return x + return jax.lax.psum(x, axis_name) + + def psum_idrev_fwd(x, axis_name): + return psum_idrev(x, axis_name), None + + def psum_idrev_bwd(axis_name, res, g): + del axis_name, res + return (g,) + + psum_idrev.defvjp(psum_idrev_fwd, psum_idrev_bwd) + + def f_named(w, x): + return psum_idrev(jnp.sin(jnp.dot(x, w)).sum(), 'batch') + + def f_positional(w, x): + return jnp.sin(jnp.dot(x, w)).sum() + + w = np.arange(12, dtype=np.float32).reshape(3, 4) + x = np.arange(6, dtype=np.float32).reshape(2, 3) + + # forward + out = xmap(f_named, in_axes=({}, {0: 'batch'}), out_axes={})(w, x) + expected = f_positional(w, x) + self.assertAllClose(out, expected, check_dtypes=True) + + # gradient + out = xmap(jax.grad(f_named, (0, 1), reduce_axes=('batch',)), + in_axes=({}, {0: 'batch'}), + out_axes=({}, {0: 'batch'}))(w, x) + expected = jax.grad(f_positional, (0, 1))(w, x) + self.assertAllClose(out, expected, check_dtypes=True) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())