AWN-enabled reduction over named axes in reverse-mode AD

Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching tracer terms, this "elementwise" behavior means that if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, are also mapped. But a user might want them to be unmapped (if,
for instance, they're interested in a total gradient rather than a per-example
gradient). They could always reduce (`psum`) the cotangents afterwards, as
sketched below, but computing mapped cotangents in the first place is likely an
unacceptable waste of memory and one that can't necessarily be optimized away.
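
For concreteness, here is a minimal sketch of that "reduce afterwards" pattern in
the standard `vmap` spelling (the loss function and shapes here are illustrative,
not from this change): the full per-example gradient is materialized before being
summed.

```python
import jax
import jax.numpy as jnp

def loss(w, x):  # illustrative per-example scalar loss
  return jnp.sin(jnp.dot(x, w)).sum()

w = jnp.ones((3, 4))
xs = jnp.ones((8, 3))  # a batch of 8 examples

# Per-example gradients w.r.t. w have shape (8, 3, 4) and are materialized...
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, xs)
# ...even when only the total (summed) gradient is wanted.
total_grad = per_example_grads.sum(axis=0)
```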

If we want to fuse these reductions into reverse-mode autodiff itself, we need
the `backward_pass` logic and/or transpose rules to know whether primal values
are mapped or unmapped. This is made possible by avals-with-names, which
encodes that information in the avals of the primal jaxpr.
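
As a rough illustration of where that information lives (adapted from the
`make_jaxpr` test updated in this change; the axis name `'i'` and the sizes are
arbitrary), the named axis appears in the `named_shape` of the avals recorded
in the jaxpr:

```python
import jax
import jax.numpy as jnp

x = jax.ShapeDtypeStruct(shape=(2, 3), dtype=jnp.float32, named_shape={'i': 10})

jaxpr = jax.make_jaxpr(lambda x: x - jax.lax.psum(x, 'i'),
                       axis_env=[('i', 10)])(x)
# The input aval carries the named axis alongside its positional shape.
print(jaxpr.jaxpr.invars[0].aval.named_shape)  # {'i': 10}
```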

Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave like collectives such as `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.
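
As a usage sketch (the loss function, shapes, and the axis name `'batch'` are
illustrative), `grad` inside `xmap` computes a per-example gradient by default
and a total gradient when the named axis is listed in `reduce_axes`:

```python
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def loss(w, x):
  return jnp.sin(jnp.dot(x, w)).sum()

w = jnp.ones((3, 4))
xs = jnp.ones((8, 3))

# Default: the gradient w.r.t. w is per-example, i.e. mapped over 'batch'.
per_example = xmap(jax.grad(loss),
                   in_axes=({}, {0: 'batch'}),
                   out_axes={0: 'batch'})(w, xs)

# With reduce_axes=('batch',), the backward pass psums over 'batch',
# yielding the unmapped total gradient.
total = xmap(jax.grad(loss, reduce_axes=('batch',)),
             in_axes=({}, {0: 'batch'}),
             out_axes={})(w, xs)
```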

Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
  - reductions aren't fused into any first-order primitives (e.g. a `pdot`
    should have a named contracting axis added rather than being followed by a
    `psum`; this can be implemented by putting these primitives into
    `reducing_transposes`; see the sketch after this list)
  - reductions are performed eagerly, even over axes that are mapped to
    hardware resources (the optimal thing to do would be to reduce eagerly
    over any vectorized axis component while delaying the reduction over any
    hardware-mapped component until the end of the overall backward pass; this
    would require a way to represent these partially-reduced values)
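
For reference, a schematic of what a `reducing_transposes` rule looks like,
based on the dispatch added to `backward_pass` in the diff below: the rule
receives the `reduce_axes` tuple as its first argument and performs the
named-axis `psum` itself, so the reduction can be fused with the primitive.
The primitive `my_scale_p` and its `scale` parameter are hypothetical.

```python
from jax import core, lax

# Hypothetical reducing transpose rule for a linear primitive y = scale * x.
# backward_pass invokes it as rule(reduce_axes, ct, *invals, **params).
def _my_scale_reducing_transpose(reduce_axes, ct, x, *, scale):
  del x  # the linear input arrives as an UndefinedPrimal and isn't needed
  x_bar = scale * ct
  # Sum over the requested named axes that the cotangent actually carries
  # (a real rule would also skip axes already present on the primal's aval).
  axes = tuple(a for a in reduce_axes
               if a in core.get_aval(x_bar).named_shape)
  if axes:
    x_bar = lax.psum(x_bar, axis_name=axes)
  return (x_bar,)

# Registration (commented out, since my_scale_p is not a real primitive):
# jax.interpreters.ad.reducing_transposes[my_scale_p] = _my_scale_reducing_transpose
```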

PiperOrigin-RevId: 383685336
James Bradbury 2021-07-08 12:05:56 -07:00 committed by jax authors
parent f0c30492dc
commit 8e86952ee4
7 changed files with 234 additions and 56 deletions


@ -12,6 +12,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
* New features:
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
* Reverse-mode autodiff functions ({func}`jax.grad`,
{func}`jax.value_and_grad`, {func}`jax.vjp`, and
{func}`jax.linear_transpose`) support a parameter that indicates which named
axes should be summed over in the backward pass if they were broadcasted
over in the forward pass. This enables use of these APIs in a
non-per-example way inside maps (initially only
{func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).
## jaxlib 0.1.69 (unreleased)


@ -578,9 +578,9 @@ def xla_computation(fun: Callable,
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
@ -597,7 +597,7 @@ def xla_computation(fun: Callable,
``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
wrapped function returns a pair where the first element is the XLA
Computation and the second element is a pytree representing the structure,
shapes, and dtypes of the output of ``fun``.
shapes, dtypes, and named shapes of the output of ``fun``.
Concrete example arguments are not always necessary. For those arguments not
indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
@ -750,7 +750,8 @@ def xla_computation(fun: Callable,
shapes = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d]
warn(f"Some donated buffers were not usable: {', '.join(shapes)}")
built = c.build(out_tuple)
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, xla.ShapedArray):
@ -767,7 +768,8 @@ def xla_computation(fun: Callable,
def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False) -> Callable:
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = ()) -> Callable:
"""Creates a function which evaluates the gradient of ``fun``.
Args:
@ -787,6 +789,13 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
function that computes the total gradient while ``grad(f)`` will create
one that computes the per-example gradient.
Returns:
A function with the same arguments as ``fun``, that evaluates the gradient
@ -806,7 +815,8 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
"""
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int)
allow_int=allow_int,
reduce_axes=reduce_axes)
docstr = ("Gradient of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
@ -829,7 +839,8 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False) -> Callable[..., Tuple[Any, Any]]:
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
) -> Callable[..., Tuple[Any, Any]]:
"""Create a function which evaluates both ``fun`` and the gradient of ``fun``.
Args:
@ -847,6 +858,14 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``value_and_grad(f, reduce_axes=('batch',))`` will
create a function that computes the total gradient while
``value_and_grad(f)`` will create one that computes the per-example
gradient.
Returns:
A function with the same arguments as ``fun`` that evaluates both ``fun``
@ -879,9 +898,10 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
f_partial, dyn_args = argnums_partial(f, argnums, args)
tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
if not has_aux:
ans, vjp_py = _vjp(f_partial, *dyn_args)
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
else:
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
ans, vjp_py, aux = _vjp(
f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
_check_scalar(ans)
dtype = dtypes.result_type(ans)
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
@ -1891,12 +1911,14 @@ if sys.version_info >= (3, 8):
@overload # type: ignore
def vjp(fun: Callable[..., T],
*primals: Any,
has_aux: Literal[False] = False) -> Tuple[T, Callable]:
has_aux: Literal[False] = False,
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
...
@overload
def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
has_aux: Literal[True]) -> Tuple[T, Callable, U]:
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
...
else:
@ -1907,12 +1929,14 @@ else:
@overload
def vjp(
fun: Callable[..., Any], *primals: Any,
has_aux: bool) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
has_aux: bool,
reduce_axes: Sequence[AxisName] = ()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
...
def vjp( # type: ignore
fun: Callable, *primals, has_aux: bool = False,
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
"""Compute a (reverse-mode) vector-Jacobian product of ``fun``.
@ -1929,6 +1953,13 @@ def vjp( # type: ignore
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
VJP will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
create a VJP function that sums over the batch while ``vjp(f, *args)``
will create a per-example VJP.
Returns:
If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where
@ -1953,19 +1984,22 @@ def vjp( # type: ignore
-0.2524413
"""
_check_callable(fun)
return _vjp(lu.wrap_init(fun), *primals, has_aux=has_aux)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux, reduce_axes=reduce_axes)
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
"""Variant of vjp() that takes an lu.WrappedFun."""
primals_flat, in_tree = tree_flatten(primals)
for arg in primals_flat: _check_arg(arg)
if not has_aux:
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
out_primal, out_vjp = ad.vjp(
flat_fun, primals_flat, reduce_axes=reduce_axes)
out_tree = out_tree()
else:
flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
out_primal, out_vjp, aux = ad.vjp(
flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes)
out_tree, aux_tree = out_aux_trees()
out_primal_py = tree_unflatten(out_tree, out_primal)
ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal]
@ -1981,7 +2015,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)
def linear_transpose(fun: Callable, *primals) -> Callable:
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
"""Transpose a function that is promised to be linear.
For linear functions, this transformation is equivalent to ``vjp``, but
@ -2002,6 +2036,14 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
is not required: only the ``shape`` and ``dtype`` attributes are accessed.
See below for an example. (Note that the duck-typed objects cannot be
namedtuples because those are treated as standard Python containers.)
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding cotangent. Otherwise, the
transposed function will be per-example over named axes. For example, if
``'batch'`` is a named batch axis, ``linear_transpose(f, *args,
reduce_axes=('batch',))`` will create a transpose function that sums over
the batch while ``linear_transpose(f, *args)`` will create a per-example
transpose.
Returns:
A callable that calculates the transpose of ``fun``. Valid input into this
@ -2046,7 +2088,7 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
in_cotangents = map(
ad.instantiate_zeros,
ad.backward_pass(jaxpr, consts, dummies, out_cotangents))
ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents))
return tree_unflatten(in_tree, in_cotangents)
return transposed_fun
@ -2071,11 +2113,11 @@ def make_jaxpr(fun: Callable,
specifies the axis name/size environment that would be set up by
applications of :py:func:`jax.pmap`.
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the ``jaxpr``
and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
wrapped function returns a pair where the first element is the ``jaxpr``
and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
@ -2083,7 +2125,7 @@ def make_jaxpr(fun: Callable,
argument ``return_shape`` is ``True``, then the returned function instead
returns a pair where the first element is the ``ClosedJaxpr``
representation of ``fun`` and the second element is a pytree representing
the structure, shape, and dtypes of the output of ``fun``.
the structure, shape, dtypes, and named shapes of the output of ``fun``.
A ``jaxpr`` is JAX's intermediate representation for program traces. The
``jaxpr`` language is based on the simply-typed first-order lambda calculus
@ -2135,7 +2177,8 @@ def make_jaxpr(fun: Callable,
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
if return_shape:
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
return closed_jaxpr


@ -369,11 +369,12 @@ xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
num_consts):
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
jvp_jaxpr_thunk, num_consts):
del jvp_jaxpr_thunk, num_consts
return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts)
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
return ad.backward_pass(
fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts)
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
### VJPs


@ -1020,7 +1020,7 @@ def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())
def _transpose_cond_jaxpr(jaxpr, num_res):
def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = _map(raise_to_shaped, primal_avals)
@ -1029,19 +1029,19 @@ def _transpose_cond_jaxpr(jaxpr, num_res):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
cts_in = ad.backward_pass(
jaxpr.jaxpr, jaxpr.consts, primals, cts_out)
jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
def _cond_transpose(cts, *args, branches, linear):
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
index, *ops = args
in_avals = _map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)
branches_trans = tuple(
_transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
_transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
lin_in_avals = [raise_to_shaped(a, weak_type=False)
for a, l in zip(in_avals, linear) if l]
assert all(core.typematch(out_aval, lin_in_aval)
@ -1131,7 +1131,7 @@ cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_transposes[cond_p] = _cond_transpose
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.initial_style_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule
@ -1673,8 +1673,8 @@ def _maybe_device_put(x):
else:
return x
def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,
num_carry, jaxpr, linear, unroll):
# we've only implemented transposing scans with specific lin/nonlin patterns
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
num_ires = len(consts_lin) - sum(consts_lin)
@ -1702,7 +1702,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
# jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
# jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
jaxpr_trans = _transpose_scan_jaxpr(
num_ires, num_consts - num_ires, num_eres, jaxpr)
num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes)
linear_trans = ([False] * num_ires +
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
[False] * num_eres)
@ -1717,8 +1717,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
# -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
# TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
# if an axis isn't reduced
res1_avals, c_avals, a_avals, res2_avals = split_list(
jaxpr.in_avals, [num_res1, num_c, num_a])
num_b = len(jaxpr.out_avals)
@ -1730,7 +1732,8 @@ def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts,
primals, b_bar)
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
c_bar = _map(ad.instantiate_zeros_aval, c_avals,
@ -1879,7 +1882,7 @@ scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl)
batching.initial_style_batchers[scan_p] = _scan_batching_rule


@ -17,13 +17,14 @@ import functools
import itertools as it
from typing import Any, Callable, Dict
import jax
from . import partial_eval as pe
from ..config import config
from .. import core
from .._src.dtypes import dtype, float0
from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
from .._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_aval, zeros_like_p, Zero)
from .._src.util import (unzip2, safe_map, safe_zip, partial, split_list,
wrap_name, as_hashable_function)
@ -31,7 +32,7 @@ from ..tree_util import register_pytree_node
from .. import linear_util as lu
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, Partial
from jax._src import source_info_util
from .._src import source_info_util
zip = safe_zip
map = safe_map
@ -109,7 +110,7 @@ def linearize(traceable, *primals, **kwargs):
else:
return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
def vjp(traceable, primals, has_aux=False):
def vjp(traceable, primals, has_aux=False, reduce_axes=()):
if not has_aux:
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
else:
@ -118,7 +119,7 @@ def vjp(traceable, primals, has_aux=False):
def unbound_vjp(pvals, jaxpr, consts, *cts):
cts = tuple(map(ignore_consts, cts, pvals))
dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
return map(instantiate_zeros, arg_cts)
# Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
@ -160,7 +161,7 @@ def recast_to_float0(primal, tangent):
return tangent
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents_in):
if all(type(ct) is Zero for ct in cotangents_in):
return map(lambda v: Zero(v.aval), jaxpr.invars)
@ -173,6 +174,11 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
# FIXME: This triggers a lot of failures!
# assert v.aval == ct.aval, (prim, v.aval, ct.aval)
return
axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
if axis_name in core.get_aval(ct).named_shape
and axis_name not in v.aval.named_shape)
if axes_to_reduce:
ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
if config.jax_enable_checks:
ct_aval = core.get_aval(ct_env[v])
@ -213,7 +219,10 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
cts_in_avals = [v.aval for v in eqn.outvars]
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in, cts_in_avals)
params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
elif eqn.primitive in reducing_transposes:
cts_out = reducing_transposes[eqn.primitive](
reduce_axes, cts_in, *invals, **eqn.params)
else:
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
**eqn.params)
@ -413,6 +422,8 @@ call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}
primitive_jvps : Dict[core.Primitive, Callable] = {}
primitive_transposes: Dict[core.Primitive, Callable] = {}
# transpose rules that internally perform reductions over the given named axes
reducing_transposes: Dict[core.Primitive, Callable] = {}
def deflinear(primitive, transpose_rule):
@ -530,9 +541,9 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
yield out_flat, tree_def
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
update_params = call_transpose_param_updaters.get(primitive)
@ -544,7 +555,8 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _):
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
cotangent_in_avals, reduce_axes):
# backward_pass can only transpose linear computations, but the call_jaxpr embedded in
# remat contains primal (non-linear) equations too. Hence, we have to eliminate those
# (in this case via partial_eval) before we call into backward_pass again.
@ -558,7 +570,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
# should all work out, because we're only computing the primal part here.
residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
# Now that we have a purely linear jaxpr, we can transpose it
cotangents_out = backward_pass(tangent_jaxpr.jaxpr, (), primals_in + residuals, cotangents_in)
cotangents_out = backward_pass(
tangent_jaxpr.jaxpr, reduce_axes, (), primals_in + residuals, cotangents_in)
# backward_pass will return cotangents computed for all invars, but some of them
# are residuals appended by partial eval, so we need to skip those before we return.
return cotangents_out[:len(primals_in)]
@ -575,9 +588,9 @@ def nonzero_outputs(*args, **kwargs):
yield results, [type(r) is not Zero for r in results]
def map_transpose(primitive, params, call_jaxpr, args, ct, _):
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
fun, nz_arg_cts = nonzero_outputs(fun)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).


@ -3324,12 +3324,43 @@ class JaxprTest(jtu.JaxTestCase):
def f(x):
return x - lax.psum(x, 'i')
x = types.SimpleNamespace(
x = api.ShapeDtypeStruct(
shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10})
jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x)
named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars]
self.assertEqual(named_shapes, [{'i': 10}, {}])
@parameterized.parameters(True, False)
def test_vjp_reduce_axes_jaxpr(self, gy_batched):
def f(w, x):
return jnp.sin(jnp.dot(x, w))
w = api.ShapeDtypeStruct(
shape=(3, 4), dtype=jnp.float32, named_shape={})
x = api.ShapeDtypeStruct(
shape=(3,), dtype=jnp.float32, named_shape={'batch': 2})
gy = api.ShapeDtypeStruct(
shape=(4,), dtype=jnp.float32,
named_shape={'batch': 2} if gy_batched else {})
# per-example
jaxpr, shapes = api.make_jaxpr(
lambda w, x, gy: api.vjp(f, w, x)[1](gy), axis_env=[('batch', 2)],
return_shape=True)(w, x, gy)
expected = (api.ShapeDtypeStruct(
shape=(3, 4), dtype=jnp.float32, named_shape={'batch': 2}), x)
self.assertEqual(shapes, expected)
self.assertNotIn('psum', str(jaxpr))
# reduced
jaxpr, shapes = api.make_jaxpr(
lambda w, x, gy: api.vjp(f, w, x, reduce_axes=('batch',))[1](gy),
axis_env=[('batch', 2)],
return_shape=True)(w, x, gy)
expected = (w, x)
self.assertEqual(shapes, expected)
self.assertIn('psum', str(jaxpr))
class CustomJVPTest(jtu.JaxTestCase):


@ -22,7 +22,7 @@ import re
import unittest
from itertools import product, permutations
from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set,
Any, Hashable, Iterable, Iterator, Union)
Any, Hashable, Iterable, Iterator, Union, Optional)
from unittest import SkipTest, skip, skipIf
import numpy as np
@ -1251,5 +1251,85 @@ class XMapErrorTest(jtu.JaxTestCase):
fm(x)
class NamedAutodiffTests(jtu.JaxTestCase):
def testVjpReduceAxes(self):
def f(w, x):
return jnp.sin(jnp.dot(x, w))
def vjp_f(w, x, gy):
_, pullback = jax.vjp(f, w, x)
return pullback(gy)
def vjp_f_reduced(w, x, gy):
_, pullback = jax.vjp(f, w, x, reduce_axes=('batch',))
return pullback(gy)
w = np.arange(12, dtype=np.float32).reshape(3, 4)
x = np.arange(6, dtype=np.float32).reshape(2, 3)
gy = np.arange(8, dtype=np.float32).reshape(2, 4)
# per-example
error = (r"One of xmap results has an out_axes specification of {}, but is "
r"actually mapped along more axes defined by this xmap call: "
r"batch")
with self.assertRaisesRegex(TypeError, error):
xmap(vjp_f,
in_axes=({}, {0: 'batch'}, {0: 'batch'}),
out_axes=({}, {0: 'batch'}))(w, x, gy)
out = xmap(vjp_f,
in_axes=({}, {0: 'batch'}, {0: 'batch'}),
out_axes=({0: 'batch'}, {0: 'batch'}))(w, x, gy)
expected = vmap(vjp_f, in_axes=(None, 0, 0), out_axes=(0, 0))(w, x, gy)
self.assertAllClose(out, expected, check_dtypes=True)
# reduced
out = xmap(vjp_f_reduced,
in_axes=({}, {0: 'batch'}, {0: 'batch'}),
out_axes=({}, {0: 'batch'}))(w, x, gy)
# the reduced VJP is also the VJP when using a positional batch axis
expected = vjp_f(w, x, gy)
self.assertAllClose(out, expected, check_dtypes=True)
def testVjpReduceAxesCollective(self):
# lax.psum has the wrong transpose, so test with a corrected version for now
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def psum_idrev(x, axis_name: Optional[AxisNames] = None):
if axis_name is None:
return x
return jax.lax.psum(x, axis_name)
def psum_idrev_fwd(x, axis_name):
return psum_idrev(x, axis_name), None
def psum_idrev_bwd(axis_name, res, g):
del axis_name, res
return (g,)
psum_idrev.defvjp(psum_idrev_fwd, psum_idrev_bwd)
def f_named(w, x):
return psum_idrev(jnp.sin(jnp.dot(x, w)).sum(), 'batch')
def f_positional(w, x):
return jnp.sin(jnp.dot(x, w)).sum()
w = np.arange(12, dtype=np.float32).reshape(3, 4)
x = np.arange(6, dtype=np.float32).reshape(2, 3)
# forward
out = xmap(f_named, in_axes=({}, {0: 'batch'}), out_axes={})(w, x)
expected = f_positional(w, x)
self.assertAllClose(out, expected, check_dtypes=True)
# gradient
out = xmap(jax.grad(f_named, (0, 1), reduce_axes=('batch',)),
in_axes=({}, {0: 'batch'}),
out_axes=({}, {0: 'batch'}))(w, x)
expected = jax.grad(f_positional, (0, 1))(w, x)
self.assertAllClose(out, expected, check_dtypes=True)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())