Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute a gradient (or VJP, etc.) for each axis index in the map". For instance, `vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching tracer terms, this "elementwise" behavior means that if any inputs to a function being transposed are mapped, the cotangents of all inputs, even unmapped ones, will also be mapped. But a user might want them to be unmapped (if, for instance, they're interested in a total gradient rather than a per-example gradient). They could always reduce (`psum`) the cotangents afterwards, but computing mapped cotangents in the first place would likely be an unacceptable waste of memory and can't necessarily be optimized away.

If we want to fuse these reductions into reverse-mode autodiff itself, we need the `backward_pass` logic and/or transpose rules to know whether primal values are mapped or unmapped. This is made possible by avals-with-names, which encodes that information in the avals of the primal jaxpr.

Putting things together, **this change adds an option to reverse-mode AD APIs that indicates which named axes should be reduced over in the backward pass in situations where they were broadcasted over in the forward pass**. All other named axes will be treated in the current elementwise way. This has the effect of making APIs like `grad` behave akin to collectives like `psum`: they act collectively over axes that are named explicitly, and elementwise otherwise. Since avals-with-names is currently enabled only in `xmap`, this behavior is only available in that context for now.

It's also missing some optimizations:

- Reductions aren't fused into any first-order primitives (e.g. a `pdot` should have a named contracting axis added rather than being followed by a `psum`; this can be implemented by putting these primitives into `reducing_transposes`).
- Reductions are performed eagerly, even over axes that are mapped to hardware resources (the optimal thing to do would be to reduce eagerly over any vectorized axis component while delaying the reduction over any hardware-mapped component until the end of the overall backward pass; this would require a way to represent these partially-reduced values).

PiperOrigin-RevId: 383685336
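As a usage sketch of the option described above (not taken from the commit itself; the `loss` function, array shapes, and the `'batch'` axis name are illustrative, and `xmap` is the experimental frontend this change targets):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import xmap

def loss(w, x):
  # Per-example scalar: `w` carries no 'batch' axis and is implicitly
  # broadcast over the 'batch' axis carried by `x`.
  return jnp.sin(jnp.dot(x, w)).sum()

w = np.arange(12, dtype=np.float32).reshape(3, 4)
x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Default (elementwise) behavior: a per-example gradient of `w`, mapped over
# 'batch', so it needs a mapped output axis.
per_example_grad = xmap(jax.grad(loss),
                        in_axes=({}, {0: 'batch'}),
                        out_axes={0: 'batch'})(w, x)   # shape (2, 3, 4)

# With reduce_axes=('batch',): the backward pass psums the `w` cotangent over
# 'batch', so the result is the total gradient and stays unmapped.
total_grad = xmap(jax.grad(loss, reduce_axes=('batch',)),
                  in_axes=({}, {0: 'batch'}),
                  out_axes={})(w, x)                    # shape (3, 4)

# `total_grad` should match the ordinary positional-batch gradient:
#   jax.grad(lambda w: jnp.sin(jnp.dot(x, w)).sum())(w)
```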
This commit is contained in: parent f0c30492dc, commit 8e86952ee4
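The same `reduce_axes` option is threaded through `jax.vjp` (and `jax.value_and_grad` / `jax.linear_transpose`). The sketch below loosely mirrors the `testVjpReduceAxes` case added in the diff that follows; it is illustrative rather than an excerpt of the commit:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import xmap

def f(w, x):
  return jnp.sin(jnp.dot(x, w))

def vjp_f_reduced(w, x, gy):
  # Ask the backward pass to psum cotangents that were broadcast over 'batch'.
  _, pullback = jax.vjp(f, w, x, reduce_axes=('batch',))
  return pullback(gy)

w = np.arange(12, dtype=np.float32).reshape(3, 4)
x = np.arange(6, dtype=np.float32).reshape(2, 3)
gy = np.arange(8, dtype=np.float32).reshape(2, 4)

# The 'batch'-summed `w` cotangent comes out unmapped (out_axes {}), while the
# `x` cotangent remains per-example.
w_bar, x_bar = xmap(vjp_f_reduced,
                    in_axes=({}, {0: 'batch'}, {0: 'batch'}),
                    out_axes=({}, {0: 'batch'}))(w, x, gy)
```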
@@ -12,6 +12,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
* New features:
  * New SciPy function {py:func}`jax.scipy.special.sph_harm`.
  * Reverse-mode autodiff functions ({func}`jax.grad`,
    {func}`jax.value_and_grad`, {func}`jax.vjp`, and
    {func}`jax.linear_transpose`) support a parameter that indicates which named
    axes should be summed over in the backward pass if they were broadcasted
    over in the forward pass. This enables use of these APIs in a
    non-per-example way inside maps (initially only
    {func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).

## jaxlib 0.1.69 (unreleased)

@ -578,9 +578,9 @@ def xla_computation(fun: Callable,
|
||||
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
|
||||
wrapped function returns a pair where the first element is the XLA
|
||||
computation and the second element is a pytree with the same structure as
|
||||
the output of ``fun`` and where the leaves are objects with ``shape`` and
|
||||
``dtype`` attributes representing the corresponding types of the output
|
||||
leaves.
|
||||
the output of ``fun`` and where the leaves are objects with ``shape``,
|
||||
``dtype``, and ``named_shape`` attributes representing the corresponding
|
||||
types of the output leaves.
|
||||
donate_argnums: Specify which arguments are "donated" to the computation.
|
||||
It is safe to donate arguments if you no longer need them once the
|
||||
computation has finished. In some cases XLA can make use of donated
|
||||
@ -597,7 +597,7 @@ def xla_computation(fun: Callable,
|
||||
``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
|
||||
wrapped function returns a pair where the first element is the XLA
|
||||
Computation and the second element is a pytree representing the structure,
|
||||
shapes, and dtypes of the output of ``fun``.
|
||||
shapes, dtypes, and named shapes of the output of ``fun``.
|
||||
|
||||
Concrete example arguments are not always necessary. For those arguments not
|
||||
indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
|
||||
@ -750,7 +750,8 @@ def xla_computation(fun: Callable,
|
||||
shapes = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d]
|
||||
warn(f"Some donated buffers were not usable: {', '.join(shapes)}")
|
||||
built = c.build(out_tuple)
|
||||
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
|
||||
out_shapes_flat = [
|
||||
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
|
||||
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
|
||||
for out_aval in out_avals:
|
||||
if not isinstance(out_aval, xla.ShapedArray):
|
||||
@ -767,7 +768,8 @@ def xla_computation(fun: Callable,
|
||||
|
||||
def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
has_aux: bool = False, holomorphic: bool = False,
|
||||
allow_int: bool = False) -> Callable:
|
||||
allow_int: bool = False,
|
||||
reduce_axes: Sequence[AxisName] = ()) -> Callable:
|
||||
"""Creates a function which evaluates the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
@ -787,6 +789,13 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
respect to integer valued inputs. The gradient of an integer input will
|
||||
have a trivial vector-space dtype (float0). Default False.
|
||||
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
|
||||
``fun`` implicitly broadcasts a value over that axis, the backward pass
|
||||
will perform a ``psum`` of the corresponding gradient. Otherwise, the
|
||||
gradient will be per-example over named axes. For example, if ``'batch'``
|
||||
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
|
||||
function that computes the total gradient while ``grad(f)`` will create
|
||||
one that computes the per-example gradient.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as ``fun``, that evaluates the gradient
|
||||
@ -806,7 +815,8 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
"""
|
||||
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
|
||||
holomorphic=holomorphic,
|
||||
allow_int=allow_int)
|
||||
allow_int=allow_int,
|
||||
reduce_axes=reduce_axes)
|
||||
|
||||
docstr = ("Gradient of {fun} with respect to positional argument(s) "
|
||||
"{argnums}. Takes the same arguments as {fun} but returns the "
|
||||
@ -829,7 +839,8 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
|
||||
def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
has_aux: bool = False, holomorphic: bool = False,
|
||||
allow_int: bool = False) -> Callable[..., Tuple[Any, Any]]:
|
||||
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
|
||||
) -> Callable[..., Tuple[Any, Any]]:
|
||||
"""Create a function which evaluates both ``fun`` and the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
@ -847,6 +858,14 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
allow_int: Optional, bool. Whether to allow differentiating with
|
||||
respect to integer valued inputs. The gradient of an integer input will
|
||||
have a trivial vector-space dtype (float0). Default False.
|
||||
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
|
||||
``fun`` implicitly broadcasts a value over that axis, the backward pass
|
||||
will perform a ``psum`` of the corresponding gradient. Otherwise, the
|
||||
gradient will be per-example over named axes. For example, if ``'batch'``
|
||||
is a named batch axis, ``value_and_grad(f, reduce_axes=('batch',))`` will
|
||||
create a function that computes the total gradient while
|
||||
``value_and_grad(f)`` will create one that computes the per-example
|
||||
gradient.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as ``fun`` that evaluates both ``fun``
|
||||
@ -879,9 +898,10 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
|
||||
if not has_aux:
|
||||
ans, vjp_py = _vjp(f_partial, *dyn_args)
|
||||
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
|
||||
else:
|
||||
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
|
||||
ans, vjp_py, aux = _vjp(
|
||||
f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
|
||||
_check_scalar(ans)
|
||||
dtype = dtypes.result_type(ans)
|
||||
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
|
||||
@ -1891,12 +1911,14 @@ if sys.version_info >= (3, 8):
|
||||
@overload # type: ignore
|
||||
def vjp(fun: Callable[..., T],
|
||||
*primals: Any,
|
||||
has_aux: Literal[False] = False) -> Tuple[T, Callable]:
|
||||
has_aux: Literal[False] = False,
|
||||
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
|
||||
has_aux: Literal[True]) -> Tuple[T, Callable, U]:
|
||||
has_aux: Literal[True],
|
||||
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
|
||||
...
|
||||
else:
|
||||
|
||||
@ -1907,12 +1929,14 @@ else:
|
||||
@overload
|
||||
def vjp(
|
||||
fun: Callable[..., Any], *primals: Any,
|
||||
has_aux: bool) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
|
||||
has_aux: bool,
|
||||
reduce_axes: Sequence[AxisName] = ()
|
||||
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
|
||||
...
|
||||
|
||||
|
||||
def vjp( # type: ignore
|
||||
fun: Callable, *primals, has_aux: bool = False,
|
||||
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
|
||||
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
|
||||
"""Compute a (reverse-mode) vector-Jacobian product of ``fun``.
|
||||
|
||||
@ -1929,6 +1953,13 @@ def vjp( # type: ignore
|
||||
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
|
||||
``fun`` implicitly broadcasts a value over that axis, the backward pass
|
||||
will perform a ``psum`` of the corresponding gradient. Otherwise, the
|
||||
VJP will be per-example over named axes. For example, if ``'batch'``
|
||||
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
|
||||
create a VJP function that sums over the batch while ``vjp(f, *args)``
|
||||
will create a per-example VJP.
|
||||
|
||||
Returns:
|
||||
If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where
|
||||
@ -1953,19 +1984,22 @@ def vjp( # type: ignore
|
||||
-0.2524413
|
||||
"""
|
||||
_check_callable(fun)
|
||||
return _vjp(lu.wrap_init(fun), *primals, has_aux=has_aux)
|
||||
return _vjp(
|
||||
lu.wrap_init(fun), *primals, has_aux=has_aux, reduce_axes=reduce_axes)
|
||||
|
||||
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
|
||||
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
|
||||
"""Variant of vjp() that takes an lu.WrappedFun."""
|
||||
primals_flat, in_tree = tree_flatten(primals)
|
||||
for arg in primals_flat: _check_arg(arg)
|
||||
if not has_aux:
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
|
||||
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
|
||||
out_primal, out_vjp = ad.vjp(
|
||||
flat_fun, primals_flat, reduce_axes=reduce_axes)
|
||||
out_tree = out_tree()
|
||||
else:
|
||||
flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
|
||||
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
|
||||
out_primal, out_vjp, aux = ad.vjp(
|
||||
flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes)
|
||||
out_tree, aux_tree = out_aux_trees()
|
||||
out_primal_py = tree_unflatten(out_tree, out_primal)
|
||||
ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal]
|
||||
@ -1981,7 +2015,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
|
||||
return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)
|
||||
|
||||
|
||||
def linear_transpose(fun: Callable, *primals) -> Callable:
|
||||
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
|
||||
"""Transpose a function that is promised to be linear.
|
||||
|
||||
For linear functions, this transformation is equivalent to ``vjp``, but
|
||||
@ -2002,6 +2036,14 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
|
||||
is not required: only the ``shape`` and ``dtype`` attributes are accessed.
|
||||
See below for an example. (Note that the duck-typed objects cannot be
|
||||
namedtuples because those are treated as standard Python containers.)
|
||||
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
|
||||
``fun`` implicitly broadcasts a value over that axis, the backward pass
|
||||
will perform a ``psum`` of the corresponding cotangent. Otherwise, the
|
||||
transposed function will be per-example over named axes. For example, if
|
||||
``'batch'`` is a named batch axis, ``linear_transpose(f, *args,
|
||||
reduce_axes=('batch',))`` will create a transpose function that sums over
|
||||
the batch while ``linear_transpose(f, args)`` will create a per-example
|
||||
transpose.
|
||||
|
||||
Returns:
|
||||
A callable that calculates the transpose of ``fun``. Valid input into this
|
||||
@ -2046,7 +2088,7 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
|
||||
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
|
||||
in_cotangents = map(
|
||||
ad.instantiate_zeros,
|
||||
ad.backward_pass(jaxpr, consts, dummies, out_cotangents))
|
||||
ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents))
|
||||
return tree_unflatten(in_tree, in_cotangents)
|
||||
|
||||
return transposed_fun
|
||||
@ -2071,11 +2113,11 @@ def make_jaxpr(fun: Callable,
|
||||
specifies the axis name/size environment that would be set up by
|
||||
applications of :py:func:`jax.pmap`.
|
||||
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
|
||||
wrapped function returns a pair where the first element is the ``jaxpr``
|
||||
and the second element is a pytree with the same structure as
|
||||
the output of ``fun`` and where the leaves are objects with ``shape`` and
|
||||
``dtype`` attributes representing the corresponding types of the output
|
||||
leaves.
|
||||
wrapped function returns a pair where the first element is the XLA
|
||||
computation and the second element is a pytree with the same structure as
|
||||
the output of ``fun`` and where the leaves are objects with ``shape``,
|
||||
``dtype``, and ``named_shape`` attributes representing the corresponding
|
||||
types of the output leaves.
|
||||
|
||||
Returns:
|
||||
A wrapped version of ``fun`` that when applied to example arguments returns
|
||||
@ -2083,7 +2125,7 @@ def make_jaxpr(fun: Callable,
|
||||
argument ``return_shape`` is ``True``, then the returned function instead
|
||||
returns a pair where the first element is the ``ClosedJaxpr``
|
||||
representation of ``fun`` and the second element is a pytree representing
|
||||
the structure, shape, and dtypes of the output of ``fun``.
|
||||
the structure, shape, dtypes, and named shapes of the output of ``fun``.
|
||||
|
||||
A ``jaxpr`` is JAX's intermediate representation for program traces. The
|
||||
``jaxpr`` language is based on the simply-typed first-order lambda calculus
|
||||
@ -2135,7 +2177,8 @@ def make_jaxpr(fun: Callable,
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
if return_shape:
|
||||
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
|
||||
out_shapes_flat = [
|
||||
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
|
||||
return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
|
||||
return closed_jaxpr
|
||||
|
||||
|
@ -369,11 +369,12 @@ xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
|
||||
# If a (multi)linear function is defined with a custom jvp, then
|
||||
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
|
||||
# already been linearized, we can drop the jvp rule.
|
||||
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
|
||||
num_consts):
|
||||
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
|
||||
jvp_jaxpr_thunk, num_consts):
|
||||
del jvp_jaxpr_thunk, num_consts
|
||||
return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts)
|
||||
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
|
||||
return ad.backward_pass(
|
||||
fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts)
|
||||
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
|
||||
|
||||
|
||||
### VJPs
|
||||
|
@ -1020,7 +1020,7 @@ def _ordered_unique(xs):
|
||||
d = collections.OrderedDict((x, None) for x in xs)
|
||||
return list(d.keys())
|
||||
|
||||
def _transpose_cond_jaxpr(jaxpr, num_res):
|
||||
def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
|
||||
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
|
||||
primal_avals = _map(raise_to_shaped, primal_avals)
|
||||
|
||||
@ -1029,19 +1029,19 @@ def _transpose_cond_jaxpr(jaxpr, num_res):
|
||||
res, cts_out = split_list(args, [num_res])
|
||||
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
|
||||
cts_in = ad.backward_pass(
|
||||
jaxpr.jaxpr, jaxpr.consts, primals, cts_out)
|
||||
jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out)
|
||||
_, cts_in = split_list(cts_in, [num_res])
|
||||
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
|
||||
|
||||
return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)
|
||||
|
||||
def _cond_transpose(cts, *args, branches, linear):
|
||||
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
|
||||
index, *ops = args
|
||||
in_avals = _map(raise_to_shaped, branches[0].in_avals)
|
||||
num_res = len(ops) - sum(linear)
|
||||
|
||||
branches_trans = tuple(
|
||||
_transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
|
||||
_transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
|
||||
lin_in_avals = [raise_to_shaped(a, weak_type=False)
|
||||
for a, l in zip(in_avals, linear) if l]
|
||||
assert all(core.typematch(out_aval, lin_in_aval)
|
||||
@ -1131,7 +1131,7 @@ cond_p.def_impl(partial(xla.apply_primitive, cond_p))
|
||||
cond_p.def_abstract_eval(_cond_abstract_eval)
|
||||
cond_p.def_custom_bind(cond_bind)
|
||||
ad.primitive_jvps[cond_p] = _cond_jvp
|
||||
ad.primitive_transposes[cond_p] = _cond_transpose
|
||||
ad.reducing_transposes[cond_p] = _cond_transpose
|
||||
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
|
||||
batching.initial_style_batchers[cond_p] = _cond_batching_rule
|
||||
xla.initial_style_translations[cond_p] = _cond_translation_rule
|
||||
@ -1673,8 +1673,8 @@ def _maybe_device_put(x):
|
||||
else:
|
||||
return x
|
||||
|
||||
def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
|
||||
linear, unroll):
|
||||
def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,
|
||||
num_carry, jaxpr, linear, unroll):
|
||||
# we've only implemented transposing scans with specific lin/nonlin patterns
|
||||
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
|
||||
num_ires = len(consts_lin) - sum(consts_lin)
|
||||
@ -1702,7 +1702,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
|
||||
# jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
|
||||
# jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
|
||||
jaxpr_trans = _transpose_scan_jaxpr(
|
||||
num_ires, num_consts - num_ires, num_eres, jaxpr)
|
||||
num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes)
|
||||
linear_trans = ([False] * num_ires +
|
||||
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
|
||||
[False] * num_eres)
|
||||
@ -1717,8 +1717,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
|
||||
|
||||
# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
|
||||
# -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
|
||||
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
|
||||
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
|
||||
num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
|
||||
# TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
|
||||
# if an axis isn't reduced
|
||||
res1_avals, c_avals, a_avals, res2_avals = split_list(
|
||||
jaxpr.in_avals, [num_res1, num_c, num_a])
|
||||
num_b = len(jaxpr.out_avals)
|
||||
@ -1730,7 +1732,8 @@ def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
|
||||
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
|
||||
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
|
||||
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
|
||||
cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
|
||||
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts,
|
||||
primals, b_bar)
|
||||
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
|
||||
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
|
||||
c_bar = _map(ad.instantiate_zeros_aval, c_avals,
|
||||
@ -1879,7 +1882,7 @@ scan_p.def_custom_bind(scan_bind)
|
||||
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
|
||||
scan_p.def_abstract_eval(_scan_abstract_eval)
|
||||
ad.primitive_jvps[scan_p] = _scan_jvp
|
||||
ad.primitive_transposes[scan_p] = _scan_transpose
|
||||
ad.reducing_transposes[scan_p] = _scan_transpose
|
||||
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
|
||||
xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl)
|
||||
batching.initial_style_batchers[scan_p] = _scan_batching_rule
|
||||
|
@ -17,13 +17,14 @@ import functools
|
||||
import itertools as it
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import jax
|
||||
from . import partial_eval as pe
|
||||
from ..config import config
|
||||
from .. import core
|
||||
from .._src.dtypes import dtype, float0
|
||||
from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
|
||||
raise_to_shaped)
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
from .._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_aval, zeros_like_p, Zero)
|
||||
from .._src.util import (unzip2, safe_map, safe_zip, partial, split_list,
|
||||
wrap_name, as_hashable_function)
|
||||
@ -31,7 +32,7 @@ from ..tree_util import register_pytree_node
|
||||
from .. import linear_util as lu
|
||||
from ..api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from ..tree_util import tree_flatten, tree_unflatten, Partial
|
||||
from jax._src import source_info_util
|
||||
from .._src import source_info_util
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
@ -109,7 +110,7 @@ def linearize(traceable, *primals, **kwargs):
|
||||
else:
|
||||
return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
|
||||
|
||||
def vjp(traceable, primals, has_aux=False):
|
||||
def vjp(traceable, primals, has_aux=False, reduce_axes=()):
|
||||
if not has_aux:
|
||||
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
|
||||
else:
|
||||
@ -118,7 +119,7 @@ def vjp(traceable, primals, has_aux=False):
|
||||
def unbound_vjp(pvals, jaxpr, consts, *cts):
|
||||
cts = tuple(map(ignore_consts, cts, pvals))
|
||||
dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
|
||||
arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
|
||||
arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
|
||||
return map(instantiate_zeros, arg_cts)
|
||||
|
||||
# Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
|
||||
@ -160,7 +161,7 @@ def recast_to_float0(primal, tangent):
|
||||
return tangent
|
||||
|
||||
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
|
||||
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
|
||||
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, consts, primals_in, cotangents_in):
|
||||
if all(type(ct) is Zero for ct in cotangents_in):
|
||||
return map(lambda v: Zero(v.aval), jaxpr.invars)
|
||||
|
||||
@ -173,6 +174,11 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
|
||||
# FIXME: This triggers a lot of failures!
|
||||
# assert v.aval == ct.aval, (prim, v.aval, ct.aval)
|
||||
return
|
||||
axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
|
||||
if axis_name in core.get_aval(ct).named_shape
|
||||
and axis_name not in v.aval.named_shape)
|
||||
if axes_to_reduce:
|
||||
ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
|
||||
ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
|
||||
if config.jax_enable_checks:
|
||||
ct_aval = core.get_aval(ct_env[v])
|
||||
@ -213,7 +219,10 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
|
||||
cts_in_avals = [v.aval for v in eqn.outvars]
|
||||
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
params, call_jaxpr, invals, cts_in, cts_in_avals)
|
||||
params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
|
||||
elif eqn.primitive in reducing_transposes:
|
||||
cts_out = reducing_transposes[eqn.primitive](
|
||||
reduce_axes, cts_in, *invals, **eqn.params)
|
||||
else:
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
|
||||
**eqn.params)
|
||||
@ -413,6 +422,8 @@ call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}
|
||||
primitive_jvps : Dict[core.Primitive, Callable] = {}
|
||||
|
||||
primitive_transposes: Dict[core.Primitive, Callable] = {}
|
||||
# transpose rules that internally perform reductions over the given named axes
|
||||
reducing_transposes: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
|
||||
def deflinear(primitive, transpose_rule):
|
||||
@ -530,9 +541,9 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
|
||||
yield out_flat, tree_def
|
||||
|
||||
|
||||
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
|
||||
update_params = call_transpose_param_updaters.get(primitive)
|
||||
@ -544,7 +555,8 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
|
||||
|
||||
|
||||
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):
|
||||
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
|
||||
cotangent_in_avals, reduce_axes):
|
||||
# backward_pass can only transpose linear computations, but the call_jaxpr embedded in
|
||||
# remat contains primal (non-linear) equations too. Hence, we have to eliminate those
|
||||
# (in this case via partial_eval) before we call into backward_pass again.
|
||||
@ -558,7 +570,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
|
||||
# should all work out, because we're only computing the primal part here.
|
||||
residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
|
||||
# Now that we have a purely linear jaxpr, we can transpose it
|
||||
cotangents_out = backward_pass(tangent_jaxpr.jaxpr, (), primals_in + residuals, cotangents_in)
|
||||
cotangents_out = backward_pass(
|
||||
tangent_jaxpr.jaxpr, reduce_axes, (), primals_in + residuals, cotangents_in)
|
||||
# backward_pass will return cotangents computed for all invars, but some of them
|
||||
# are residuals appended by partial eval, so we need to skip those before we return.
|
||||
return cotangents_out[:len(primals_in)]
|
||||
@ -575,9 +588,9 @@ def nonzero_outputs(*args, **kwargs):
|
||||
yield results, [type(r) is not Zero for r in results]
|
||||
|
||||
|
||||
def map_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
|
||||
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
|
||||
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
|
||||
fun, nz_arg_cts = nonzero_outputs(fun)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
# Preserve axis for primal arguments, skip tangents (represented as undefined primals).
|
||||
|
@ -3324,12 +3324,43 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return x - lax.psum(x, 'i')
|
||||
|
||||
x = types.SimpleNamespace(
|
||||
x = api.ShapeDtypeStruct(
|
||||
shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10})
|
||||
jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x)
|
||||
named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars]
|
||||
self.assertEqual(named_shapes, [{'i': 10}, {}])
|
||||
|
||||
@parameterized.parameters(True, False)
|
||||
def test_vjp_reduce_axes_jaxpr(self, gy_batched):
|
||||
def f(w, x):
|
||||
return jnp.sin(jnp.dot(x, w))
|
||||
|
||||
w = api.ShapeDtypeStruct(
|
||||
shape=(3, 4), dtype=jnp.float32, named_shape={})
|
||||
x = api.ShapeDtypeStruct(
|
||||
shape=(3,), dtype=jnp.float32, named_shape={'batch': 2})
|
||||
gy = api.ShapeDtypeStruct(
|
||||
shape=(4,), dtype=jnp.float32,
|
||||
named_shape={'batch': 2} if gy_batched else {})
|
||||
|
||||
# per-example
|
||||
jaxpr, shapes = api.make_jaxpr(
|
||||
lambda w, x, gy: api.vjp(f, w, x)[1](gy), axis_env=[('batch', 2)],
|
||||
return_shape=True)(w, x, gy)
|
||||
expected = (api.ShapeDtypeStruct(
|
||||
shape=(3, 4), dtype=jnp.float32, named_shape={'batch': 2}), x)
|
||||
self.assertEqual(shapes, expected)
|
||||
self.assertNotIn('psum', str(jaxpr))
|
||||
|
||||
# reduced
|
||||
jaxpr, shapes = api.make_jaxpr(
|
||||
lambda w, x, gy: api.vjp(f, w, x, reduce_axes=('batch',))[1](gy),
|
||||
axis_env=[('batch', 2)],
|
||||
return_shape=True)(w, x, gy)
|
||||
expected = (w, x)
|
||||
self.assertEqual(shapes, expected)
|
||||
self.assertIn('psum', str(jaxpr))
|
||||
|
||||
|
||||
class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
||||
|
@ -22,7 +22,7 @@ import re
|
||||
import unittest
|
||||
from itertools import product, permutations
|
||||
from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set,
|
||||
Any, Hashable, Iterable, Iterator, Union)
|
||||
Any, Hashable, Iterable, Iterator, Union, Optional)
|
||||
from unittest import SkipTest, skip, skipIf
|
||||
|
||||
import numpy as np
|
||||
@ -1251,5 +1251,85 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
fm(x)
|
||||
|
||||
|
||||
class NamedAutodiffTests(jtu.JaxTestCase):
|
||||
|
||||
def testVjpReduceAxes(self):
|
||||
def f(w, x):
|
||||
return jnp.sin(jnp.dot(x, w))
|
||||
|
||||
def vjp_f(w, x, gy):
|
||||
_, pullback = jax.vjp(f, w, x)
|
||||
return pullback(gy)
|
||||
|
||||
def vjp_f_reduced(w, x, gy):
|
||||
_, pullback = jax.vjp(f, w, x, reduce_axes=('batch',))
|
||||
return pullback(gy)
|
||||
|
||||
w = np.arange(12, dtype=np.float32).reshape(3, 4)
|
||||
x = np.arange(6, dtype=np.float32).reshape(2, 3)
|
||||
gy = np.arange(8, dtype=np.float32).reshape(2, 4)
|
||||
|
||||
# per-example
|
||||
error = (r"One of xmap results has an out_axes specification of {}, but is "
|
||||
r"actually mapped along more axes defined by this xmap call: "
|
||||
r"batch")
|
||||
with self.assertRaisesRegex(TypeError, error):
|
||||
xmap(vjp_f,
|
||||
in_axes=({}, {0: 'batch'}, {0: 'batch'}),
|
||||
out_axes=({}, {0: 'batch'}))(w, x, gy)
|
||||
out = xmap(vjp_f,
|
||||
in_axes=({}, {0: 'batch'}, {0: 'batch'}),
|
||||
out_axes=({0: 'batch'}, {0: 'batch'}))(w, x, gy)
|
||||
expected = vmap(vjp_f, in_axes=(None, 0, 0), out_axes=(0, 0))(w, x, gy)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
# reduced
|
||||
out = xmap(vjp_f_reduced,
|
||||
in_axes=({}, {0: 'batch'}, {0: 'batch'}),
|
||||
out_axes=({}, {0: 'batch'}))(w, x, gy)
|
||||
# the reduced VJP is also the VJP when using a positional batch axis
|
||||
expected = vjp_f(w, x, gy)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
def testVjpReduceAxesCollective(self):
|
||||
|
||||
# lax.psum has the wrong transpose, so test with a corrected version for now
|
||||
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
|
||||
def psum_idrev(x, axis_name: Optional[AxisNames] = None):
|
||||
if axis_name is None:
|
||||
return x
|
||||
return jax.lax.psum(x, axis_name)
|
||||
|
||||
def psum_idrev_fwd(x, axis_name):
|
||||
return psum_idrev(x, axis_name), None
|
||||
|
||||
def psum_idrev_bwd(axis_name, res, g):
|
||||
del axis_name, res
|
||||
return (g,)
|
||||
|
||||
psum_idrev.defvjp(psum_idrev_fwd, psum_idrev_bwd)
|
||||
|
||||
def f_named(w, x):
|
||||
return psum_idrev(jnp.sin(jnp.dot(x, w)).sum(), 'batch')
|
||||
|
||||
def f_positional(w, x):
|
||||
return jnp.sin(jnp.dot(x, w)).sum()
|
||||
|
||||
w = np.arange(12, dtype=np.float32).reshape(3, 4)
|
||||
x = np.arange(6, dtype=np.float32).reshape(2, 3)
|
||||
|
||||
# forward
|
||||
out = xmap(f_named, in_axes=({}, {0: 'batch'}), out_axes={})(w, x)
|
||||
expected = f_positional(w, x)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
# gradient
|
||||
out = xmap(jax.grad(f_named, (0, 1), reduce_axes=('batch',)),
|
||||
in_axes=({}, {0: 'batch'}),
|
||||
out_axes=({}, {0: 'batch'}))(w, x)
|
||||
expected = jax.grad(f_positional, (0, 1))(w, x)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())