# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
from functools import update_wrapper, reduce, partial, wraps
from typing import Any, Generic, TypeVar

from jax._src import config
from jax._src import core
from jax._src import custom_api_util
from jax._src.custom_transpose import custom_transpose
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import traceback_util
from jax._src.ad_util import (
    stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
    argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.tree_util import (
    tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
    register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr,
    treedef_children)
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
                           unzip2)


traceback_util.register_exclusion(__file__)

map = safe_map
zip = safe_zip

### util

def _initial_style_jaxpr(fun, in_avals):
  jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  return jaxpr, consts

def _close_jaxpr(jaxpr):
  return pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))

def _sum_tangents(_, x, *xs):
  return reduce(ad.add_tangents, xs, x)

def _zeros_like_pytree(x):
  return tree_map(Zero.from_primal_value, x)

_stop_gradient = partial(
    tree_map,
    lambda x: stop_gradient_p.bind(x) if isinstance(x, core.Tracer) else x,
)


# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux
def _flatten_fun_nokwargs(in_tree, *args_flat):
  py_args = tree_unflatten(in_tree, args_flat)
  ans = yield py_args, {}
  ans_flat, ans_tree = tree_flatten(ans)
  ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat]
  yield ans_flat, (ans_tree, ans_avals)

### JVPs
ReturnValue = TypeVar('ReturnValue')

@custom_api_util.register_custom_decorator_type
class custom_jvp(Generic[ReturnValue]):
  """Set up a JAX-transformable function for a custom JVP rule definition.

  This class is meant to be used as a function decorator. Instances are
  callables that behave similarly to the underlying function to which the
  decorator was applied, except when a differentiation transformation (like
  :py:func:`jax.jvp` or :py:func:`jax.grad`) is applied, in which case a custom
  user-supplied JVP rule function is used instead of tracing into and
  performing automatic differentiation of the underlying function's
  implementation.

  There are two instance methods available for defining the custom JVP rule:
  :py:func:`~jax.custom_jvp.defjvp` for defining a *single* custom JVP rule for
  all the function's inputs, and for convenience
  :py:func:`~jax.custom_jvp.defjvps`, which wraps
  :py:func:`~jax.custom_jvp.defjvp`, and allows you to provide separate
  definitions for the partial derivatives of the function w.r.t. each of its
  arguments.

  For example::

    @jax.custom_jvp
    def f(x, y):
      return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
      x, y = primals
      x_dot, y_dot = tangents
      primal_out = f(x, y)
      tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
      return primal_out, tangent_out

  For a more detailed introduction, see the tutorial_.

  .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
  """
  fun: Callable[..., ReturnValue]
  nondiff_argnums: Sequence[int]
  jvp: Callable[..., tuple[ReturnValue, ReturnValue]] | None = None
  symbolic_zeros: bool = False

  def __init__(self,
               fun: Callable[..., ReturnValue],
               nondiff_argnums: Sequence[int] = (),
               ):
    update_wrapper(self, fun)
    self.fun = fun
    self.nondiff_argnums = nondiff_argnums

  __getattr__ = custom_api_util.forward_attr

  def defjvp(self,
             jvp: Callable[..., tuple[ReturnValue, ReturnValue]],
             symbolic_zeros: bool = False,
             ) -> Callable[..., tuple[ReturnValue, ReturnValue]]:
    """Define a custom JVP rule for the function represented by this instance.

    Args:
      jvp: a Python callable representing the custom JVP rule. When there are no
        ``nondiff_argnums``, the ``jvp`` function should accept two arguments,
        where the first is a tuple of primal inputs and the second is a tuple of
        tangent inputs. The lengths of both tuples are equal to the number of
        parameters of the :class:`~jax.custom_jvp` function. The ``jvp`` function
        should produce as output a pair where the first element is the primal
        output and the second element is the tangent output. Elements of the
        input and output tuples may be arrays or any nested tuples/lists/dicts
        thereof.
      symbolic_zeros: boolean, indicating whether the rule should be passed
        objects representing static symbolic zeros in its tangent argument in
        correspondence with unperturbed values; otherwise, only standard JAX
        types (e.g. array-likes) are passed. Setting this option to ``True``
        allows a JVP rule to detect whether certain inputs are not involved in
        differentiation, but at the cost of needing special handling for these
        objects (which e.g. can't be passed into jax.numpy functions). Default
        ``False``.

    Returns:
      Returns ``jvp`` so that ``defjvp`` can be used as a decorator.

    Examples:

      >>> @jax.custom_jvp
      ... def f(x, y):
      ...   return jnp.sin(x) * y
      ...
      >>> @f.defjvp
      ... def f_jvp(primals, tangents):
      ...   x, y = primals
      ...   x_dot, y_dot = tangents
      ...   primal_out = f(x, y)
      ...   tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
      ...   return primal_out, tangent_out

      >>> x = jnp.float32(1.0)
      >>> y = jnp.float32(2.0)
      >>> with jnp.printoptions(precision=2):
      ...   print(jax.value_and_grad(f)(x, y))
      (Array(1.68, dtype=float32), Array(1.08, dtype=float32))
    """
    self.jvp = jvp
    self.symbolic_zeros = symbolic_zeros
    return jvp

  def defjvps(self, *jvps: Callable[..., ReturnValue] | None) -> None:
    """Convenience wrapper for defining JVPs for each argument separately.

    This convenience wrapper cannot be used together with ``nondiff_argnums``.

    Args:
      *jvps: a sequence of functions, one for each positional argument of the
        :class:`~jax.custom_jvp` function. Each function takes as arguments
        the tangent value for the corresponding primal input, the primal
        output, and the primal inputs. See the example below.

    Returns:
      None.

    Examples:

      >>> @jax.custom_jvp
      ... def f(x, y):
      ...   return jnp.sin(x) * y
      ...
      >>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
      ...           lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

      >>> x = jnp.float32(1.0)
      >>> y = jnp.float32(2.0)
      >>> with jnp.printoptions(precision=2):
      ...   print(jax.value_and_grad(f)(x, y))
      (Array(1.68, dtype=float32), Array(1.08, dtype=float32))
    """
    if self.nondiff_argnums:
      raise TypeError("Can't use ``defjvps`` with ``nondiff_argnums``.")

    def jvp(primals, tangents):
      primal_out = self(*primals)
      zeros = _zeros_like_pytree(primal_out)
      all_tangents_out = [jvp(t, primal_out, *primals) if jvp else zeros
                          for t, jvp in zip(tangents, jvps)]
      tangent_out = tree_map(_sum_tangents, primal_out, *all_tangents_out)
      return primal_out, tangent_out

    self.defjvp(jvp)

  @traceback_util.api_boundary
  def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
    primal_name = getattr(self.fun, '__name__', str(self.fun))
    if not self.jvp:
      msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
      raise AttributeError(msg)
    jvp_name = getattr(self.jvp, '__name__', str(self.jvp))
    args = resolve_kwargs(self.fun, args, kwargs)
    if self.nondiff_argnums:
      nondiff_argnums = set(self.nondiff_argnums)
      args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
                   for i, x in enumerate(args))
      diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
      f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
                                     require_static_args_hashable=False)
      static_args = [args[i] for i in self.nondiff_argnums]
      jvp = _add_args(lu.wrap_init(self.jvp), static_args)
    else:
      f_, dyn_args = lu.wrap_init(self.fun), args
      jvp = lu.wrap_init(self.jvp)
    args_flat, in_tree = tree_flatten(dyn_args)
    flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree)
    flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree,
                                       out_type1)
    out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat,
                                      symbolic_zeros=self.symbolic_zeros)
    _, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2)
    return tree_unflatten(out_tree, out_flat)

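# Helpers that prepend the static (nondiff_argnums) values back onto a wrapped
# rule. Wrapping each value in Unhashable appears to be so that arbitrary,
# possibly unhashable static values can be threaded through the linear_util
# transformation parameters.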
def _add_args(f, extra_args):
  return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args))

@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
  extra_args = tuple(arg.val for arg in extra_args)
  all_args = (extra_args + args)
  yield (yield all_args, kwargs)

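# _flatten_jvp adapts a user-written jvp rule (pytree-in/pytree-out) to flat
# lists: the flat args are the primals followed by the tangents, and the flat
# result is primals_out + tangents_out, with pytree-structure and aval
# agreement checked against the primal function's output type when available.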
@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
  primals_in, tangents_in = split_list(args, [len(args) // 2])
  py_primals = tree_unflatten(in_tree, primals_in)
  py_tangents = tree_unflatten(in_tree, tangents_in)
  pair_out = yield (py_primals, py_tangents), {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
           "must produce a pair (list or tuple of length two) representing "
           f"primal and tangent outputs, but got {pair_out}.")
    raise TypeError(msg)
  py_primals_out, py_tangents_out = pair_out
  primals_out, out_tree = tree_flatten(py_primals_out)
  tangents_out, out_tree2 = tree_flatten(py_tangents_out)
  primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
  if out_tree != out_tree2:
    msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
           "produce primal and tangent outputs with equal container (pytree) "
           f"structures, but got {out_tree} and {out_tree2} respectively.")
    raise TypeError(msg)
  # If the primal function already ran, check out_tree agreement.
  try: out_type_ = maybe_out_type()
  except lu.StoreException: out_type_ = None
  if out_type_ is not None:
    out_tree_, primal_avals_ = out_type_
    ty_tree  = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
    ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
    if out_tree_ != out_tree:
      m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
           "produce a pair (list or tuple of length two) "
           "where the first element represents the primal output "
           "(equal in value to the output of the custom_jvp-decorated function "
           f"{primal_name}, "
           "and in particular of the same container/pytree structure), but "
           "instead the JVP rule output's first element had container/pytree "
           "structure:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_jvp-decorated function {primal_name} had output "
           "container/pytree structure:\n"
           f"""    {str(ty_tree_).replace("'", "")}.""")
      raise TypeError(m)
    if not all(map(core.typematch, primal_avals, primal_avals_)):
      m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
           "produce a pair (list or tuple of length two) "
           "where the first element represents the primal output "
           "(equal in value to the output of the custom_jvp-decorated function "
           f"{primal_name}, "
           "and in particular with leaves of the same shape/dtype), but "
           "instead the JVP rule output's first element had shapes/dtypes of:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_jvp-decorated function {primal_name} had output "
           "shapes/dtypes of:\n"
           f"""    {str(ty_tree_).replace("'", "")}""")
      raise TypeError(m)
  primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
  expected_tangent_avals_out = [
    raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval()
    for x in primals_out]
  tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
                       if type(t) is not SymbolicZero else t.aval.strip_weak_type()
                       for t in tangents_out]
  if expected_tangent_avals_out != tangent_avals_out:
    if len(expected_tangent_avals_out) == 1:
      (av_p,), (av_et,), (av_t,) = primal_avals_out, expected_tangent_avals_out, tangent_avals_out
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "corresponding shapes and dtypes. Expected {} (tangent type of {}) but got {}.")
      raise TypeError(msg.format(av_et.str_short(), av_p.str_short(), av_t.str_short()))
    else:
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "corresponding shapes and dtypes, but got:\n{}")
      disagreements = (
          f"  primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
          for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
          if av_et != av_t)

      raise TypeError(msg.format('\n'.join(disagreements)))
  yield primals_out + tangents_out, (out_tree, primal_avals)

class CustomJVPCallPrimitive(core.Primitive):
  multiple_results = True

  def bind(self, fun, jvp, *args, symbolic_zeros):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    fun, env_trace_todo1 = process_env_traces(
        fun, self, top_trace and top_trace.level, False)
    jvp, env_trace_todo2 = process_env_traces(
        jvp, self, top_trace and top_trace.level, True)
    tracers = map(top_trace.full_raise, args)
    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers,
                                             symbolic_zeros=symbolic_zeros)
    _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    return core.apply_todos(env_trace_todo, map(core.full_lower, outs))

  def impl(self, fun, _, *args):
    with core.new_sublevel():
      return fun.call_wrapped(*args)

  def post_process(self, trace, out_tracers, jvp_was_run: bool):
    return trace.post_process_custom_jvp_call(out_tracers, jvp_was_run)

  def get_bind_params(self, params):
    new_params = dict(params)
    call_jaxpr = new_params.pop('call_jaxpr')
    num_consts = new_params.pop('num_consts')
    jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
    fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr))
    jvp = lift_jvp(num_consts, jvp_jaxpr_thunk)
    return [fun, jvp], new_params

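# lift_jvp rebuilds a jvp rule (as a lu.WrappedFun) from a staged-out jvp
# jaxpr thunk, re-inserting SymbolicZero tangent outputs at the positions the
# thunk reported as symbolically zero.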
def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable) -> lu.WrappedFun:
  @lu.wrap_init
  def jvp(*xs):
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    primals, tangents = xs[num_consts:n], xs[n+num_consts:]
    zeros = [type(t) is SymbolicZero for t in tangents]
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
    nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
    out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
    out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
    nz_out_tangents_ = iter(nz_out_tangents)
    out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval())
                    if z else next(nz_out_tangents_)
                    for p, z in zip(out_primals, out_zeros)]
    assert next(nz_out_tangents_, None) is None
    return [*out_primals, *out_tangents]
  return jvp

@partial(lu.transformation_with_aux, use_eq_store=True)
def process_env_traces(primitive, level: int, jvp_was_run: bool, *args):
  outs = yield args, {}
  todo = []
  while True:
    tracers = [x for x in outs if isinstance(x, core.Tracer)
               and (level is None or x._trace.level > level)]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo = primitive.post_process(trace, outs, jvp_was_run)
    todo.append(cur_todo)
  yield outs, tuple(todo)  # Ensure the aux output is immutable


effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
                               num_consts, symbolic_zeros):
  # TODO(mattjj): could do more checking here...
  del in_avals, jvp_jaxpr_thunk, num_consts
  disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `custom_jvp`: {disallowed_effects}')
  return call_jaxpr.out_avals, call_jaxpr.effects
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck

def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
                                      num_consts, symbolic_zeros):
  del jvp_jaxpr_thunk, num_consts, symbolic_zeros
  consts = mlir._ir_consts(call_jaxpr.consts)
  out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
                                   ctx.name_stack, ctx.tokens_in, consts,
                                   *args, dim_var_values=ctx.dim_var_values)
  ctx.set_tokens_out(tokens)
  return out
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)

# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_ can appear in jaxprs to be transposed. Since it's already
# been linearized, we can drop the jvp rule.
def _custom_jvp_call_transpose(params, jaxpr, args, ct, _):
  del params
  return ad.backward_pass(jaxpr.jaxpr, None, jaxpr.consts, args, ct)
ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose


### VJPs

@custom_api_util.register_custom_decorator_type
class custom_vjp(Generic[ReturnValue]):
  """Set up a JAX-transformable function for a custom VJP rule definition.

  This class is meant to be used as a function decorator. Instances are
  callables that behave similarly to the underlying function to which the
  decorator was applied, except when a reverse-mode differentiation
  transformation (like :py:func:`jax.grad`) is applied, in which case a custom
  user-supplied VJP rule function is used instead of tracing into and performing
  automatic differentiation of the underlying function's implementation. There
  is a single instance method, :py:func:`~jax.custom_vjp.defvjp`, which may be
  used to define the custom VJP rule.

  This decorator precludes the use of forward-mode automatic differentiation.

  For example::

    @jax.custom_vjp
    def f(x, y):
      return jnp.sin(x) * y

    def f_fwd(x, y):
      return f(x, y), (jnp.cos(x), jnp.sin(x), y)

    def f_bwd(res, g):
      cos_x, sin_x, y = res
      return (cos_x * g * y, sin_x * g)

    f.defvjp(f_fwd, f_bwd)

  For a more detailed introduction, see the tutorial_.

  .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
  """

  def __init__(self,
               fun: Callable[..., ReturnValue],
               nondiff_argnums: Sequence[int] = ()):
    update_wrapper(self, fun)
    self.fun = fun
    self.nondiff_argnums = nondiff_argnums
    self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None
    self.bwd: Callable[..., tuple[Any, ...]] | None = None
    self.symbolic_zeros = False
    self.optimize_remat = False

  __getattr__ = custom_api_util.forward_attr

  def defvjp(self,
             fwd: Callable[..., tuple[ReturnValue, Any]],
             bwd: Callable[..., tuple[Any, ...]],
             symbolic_zeros: bool = False,
             optimize_remat: bool = False,
             ) -> None:
    """Define a custom VJP rule for the function represented by this instance.

    Args:
      fwd: a Python callable representing the forward pass of the custom VJP
        rule. When there are no ``nondiff_argnums``, the ``fwd`` function has
        the same input signature as the underlying primal function. It should
        return as output a pair, where the first element represents the primal
        output and the second element represents any "residual" values to store
        from the forward pass for use on the backward pass by the function
        ``bwd``. Input arguments and elements of the output pair may be arrays
        or nested tuples/lists/dicts thereof.
      bwd: a Python callable representing the backward pass of the custom VJP
        rule. When there are no ``nondiff_argnums``, the ``bwd`` function takes
        two arguments, where the first is the "residual" values produced on the
        forward pass by ``fwd``, and the second is the output cotangent with the
        same structure as the primal function output. The output of ``bwd`` must
        be a tuple of length equal to the number of arguments of the primal
        function, and the tuple elements may be arrays or nested
        tuples/lists/dicts thereof so as to match the structure of the primal
        input arguments.
      symbolic_zeros: boolean, determining whether to indicate symbolic zeros
        to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom
        derivative rules to detect when certain inputs, and when certain
        output cotangents, are not involved in differentiation. If ``True``:

        * ``fwd`` must accept, in place of each leaf value ``x`` in
          the pytree comprising an argument to the original function,
          an object (of type
          ``jax.custom_derivatives.CustomVJPPrimal``) with two
          attributes instead: ``value`` and ``perturbed``. The
          ``value`` field is the original primal argument, and
          ``perturbed`` is a boolean. The ``perturbed`` bit indicates
          whether the argument is involved in differentiation (i.e.,
          if it is ``False``, then the corresponding Jacobian "column"
          is zero).

        * ``bwd`` will be passed objects representing static symbolic zeros in
          its cotangent argument in correspondence with unperturbed values;
          otherwise, only standard JAX types (e.g. array-likes) are passed.

        Setting this option to ``True`` allows these rules to detect whether
        certain inputs and outputs are not involved in differentiation, but at
        the cost of special handling. For instance:

        * The signature of ``fwd`` changes, and the objects it is passed cannot
          be output from the rule directly.

        * The ``bwd`` rule is passed objects that are not entirely array-like,
          and that cannot be passed to most ``jax.numpy`` functions.

        * Any custom pytree nodes involved in the primal function's arguments
          must accept, in their unflattening functions, the two-field record
          objects that are given as input leaves to the ``fwd`` rule.

        Default ``False``.
      optimize_remat: boolean, an experimental flag to enable an automatic
        optimization when this function is used under :func:`jax.remat`. This
        will be most useful when the ``fwd`` rule is an opaque call such as a
        Pallas kernel or a custom call. Default ``False``.

    Returns:
      None.

    Examples:

      >>> @jax.custom_vjp
      ... def f(x, y):
      ...   return jnp.sin(x) * y
      ...
      >>> def f_fwd(x, y):
      ...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
      ...
      >>> def f_bwd(res, g):
      ...   cos_x, sin_x, y = res
      ...   return (cos_x * g * y, sin_x * g)
      ...
      >>> f.defvjp(f_fwd, f_bwd)

      >>> x = jnp.float32(1.0)
      >>> y = jnp.float32(2.0)
      >>> with jnp.printoptions(precision=2):
      ...   print(jax.value_and_grad(f)(x, y))
      (Array(1.68, dtype=float32), Array(1.08, dtype=float32))
    """
    self.fwd = fwd
    self.bwd = bwd
    self.symbolic_zeros = symbolic_zeros
    self.optimize_remat = optimize_remat
    if self.symbolic_zeros and self.optimize_remat:
      raise NotImplementedError(
          "remat optimization for custom_vjp does not support symbolic zeros")

  @traceback_util.api_boundary
  def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
    primal_name = getattr(self.fun, '__name__', str(self.fun))
    if not self.fwd or not self.bwd:
      msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
      raise AttributeError(msg)
    fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
    args = resolve_kwargs(self.fun, args, kwargs)
    if self.optimize_remat:
      fwd = optimize_remat_of_custom_vjp_fwd(
          self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums,
          symbolic_zeros=self.symbolic_zeros)
    else:
      fwd = self.fwd
    if config.enable_custom_vjp_by_custom_transpose.value:
      if self.nondiff_argnums:
        raise NotImplementedError(
            'nondiff_argnums not implemented for new custom_vjp')
      return custom_vjp_by_custom_transpose(self.fun, self.fwd, self.bwd)(*args)
    else:
      if self.nondiff_argnums:
        for i in self.nondiff_argnums: _check_for_tracers(args[i])
        nondiff_argnums = set(self.nondiff_argnums)
        dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
        f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
                                       args, require_static_args_hashable=False)
        static_args = [args[i] for i in self.nondiff_argnums]
        fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
                                  require_static_args_hashable=False)
        bwd = _add_args(lu.wrap_init(self.bwd), static_args)
      else:
        f_, dyn_args = lu.wrap_init(self.fun), args
        fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
      args_flat, in_tree = tree_flatten(dyn_args)
      in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
      flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
      flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
                                         fwd_name, in_tree, out_type)
      flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
      out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                        *args_flat, out_trees=out_trees,
                                        symbolic_zeros=self.symbolic_zeros)
      _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
      return tree_unflatten(out_tree, out_flat)

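# Each leaf passed to a symbolic_zeros=True fwd rule is a CustomVJPPrimal
# record. A minimal sketch of a fwd rule written against that convention
# (hypothetical user code, not part of this module):
#
#   def f_fwd(x, y):              # x and y are CustomVJPPrimal records
#     primal_out = f(x.value, y.value)
#     return primal_out, (x.value, y.perturbed)  # residuals may use .perturbed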
@dataclasses.dataclass
class CustomVJPPrimal:
  """Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
  value: Any
  perturbed: bool

def custom_vjp_primal_tree_values(tree):
  """Strips away perturbation information from forward rule arguments.

  This is a helper function for use with the ``symbolic_zeros`` option to
  the ``defvjp`` method of a ``custom_vjp``-decorated function.

  In ``symbolic_zeros`` mode, the custom forward rule receives arguments
  whose pytree leaves are records with a ``value`` attribute that carries
  the primal argument. This is a way to convert such argument trees back to
  their original form, replacing each such record with its carried value at
  each leaf.
  """
  def value(leaf):
    if type(leaf) is not CustomVJPPrimal:
      raise TypeError(f"unexpected leaf type {type(leaf)}")
    return leaf.value
  return tree_map(value, tree)

def _check_for_tracers(x):
  for leaf in tree_leaves(x):
    if isinstance(leaf, core.Tracer):
      msg = ("Found a JAX Tracer object passed as an argument to a custom_vjp "
             "function in a position indicated by nondiff_argnums as "
             "non-differentiable. Tracers cannot be passed as non-differentiable "
             "arguments to custom_vjp functions; instead, nondiff_argnums should "
             "only be used for arguments that can't be or contain JAX tracers, "
             "e.g. function-valued arguments. In particular, array-valued "
             "arguments should typically not be indicated as nondiff_argnums.")
      raise UnexpectedTracerError(msg)

@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
                 *args):
  if symbolic_zeros:
    args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
  else:
    args = args[::2]
  py_args = tree_unflatten(in_tree, args)
  pair_out = yield py_args, {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
           "must produce a pair (list or tuple of length two) where the first "
           "element represents the primal output (equal to those of the "
           f"custom_vjp-decorated function {primal_name}) and the "
           "second element represents residuals (i.e. values stored from the "
           "forward pass for use on the backward pass), but "
           f"instead of a pair the fwd rule {fwd_name} produced {pair_out}.")
    raise TypeError(msg)
  py_primals_out, res = pair_out
  primals_out, out_tree = tree_flatten(py_primals_out)
  res, res_tree = tree_flatten(res)
  primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
  # If the primal function already ran, check out_tree agreement.
  try: out_type_ = maybe_out_type()
  except lu.StoreException: out_type_ = None
  if out_type_ is not None:
    out_tree_, primal_avals_ = out_type_
    ty_tree  = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
    ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
    if out_tree_ != out_tree:
      m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
           "must produce a pair (list or tuple of length two) where the first "
           "element represents the primal output "
           "(equal to the output of the custom_vjp-decorated function "
           f"{primal_name}) and the "
           "second element represents residuals (i.e. values stored from the "
           "forward pass for use on the backward pass), but "
           "instead the fwd rule output's first element had container/pytree "
           "structure:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_vjp-decorated function {primal_name} had output "
           "container/pytree structure:\n"
           f"""    {str(ty_tree_).replace("'", "")}.""")
      raise TypeError(m)
    if not all(map(core.typematch, primal_avals, primal_avals_)):
      m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} must "
           "produce a pair (list or tuple of length two) "
           "where the first element represents the primal output "
           "(equal to the output of the custom_vjp-decorated function "
           f"{primal_name}) and the second element represents residuals "
           "(i.e. values stored from the forward pass for use on the "
           "backward pass), but "
           "instead the fwd rule output's first element had shapes/dtypes of:\n"
           f"""    {str(ty_tree ).replace("'", "")}\n"""
           f"while the custom_vjp-decorated function {primal_name} had output "
           "shapes/dtypes of:\n"
           f"""    {str(ty_tree_).replace("'", "")}""")
      raise TypeError(m)
  yield (*res, *primals_out), (out_tree, res_tree)

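# _flatten_bwd adapts a user bwd rule to flat cotangents: residuals and output
# cotangents are unflattened for the rule, and the rule's per-argument results
# are re-flattened, with None and SymbolicZero entries converted to Zero values
# and shape/dtype-checked against the primal input avals.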
@lu.transformation
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
  out_tree, res_tree = out_trees()
  assert len(args) == res_tree.num_leaves + out_tree.num_leaves
  res, cts_out = split_list(args, [res_tree.num_leaves])
  py_res = tree_unflatten(res_tree, res)
  py_cts_out = tree_unflatten(out_tree, cts_out)
  py_cts_in = yield (py_res, py_cts_out), {}
  if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)):
    py_cts_in = tuple(py_cts_in)
  # For each None in py_cts_in, indicating an argument for which the rule
  # produces no cotangent, we replace it with a pytree with the structure of the
  # corresponding subtree of in_tree and with leaves of a non-pytree sentinel
  # object, to be replaced with Nones in the final returned result.
  zero = object()  # non-pytree sentinel to replace Nones in py_cts_in
  dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
  keypaths, _ = unzip2(tree_flatten_with_path(dummy)[0])
  cts_in_flat = []
  def append(x, d):
    num_leaves = len(tree_flatten(d)[0])
    if x is None and d is not None:
      cts_in_flat.extend([zero] * num_leaves)
    elif x is not None:
      cts_in_flat.extend([x] * num_leaves)
    return x
  try:
    if not isinstance(py_cts_in, tuple):
      raise ValueError
    tree_map(append, py_cts_in, dummy, is_leaf=lambda x: x is None)
  except ValueError:
    _, in_tree2 = tree_flatten(py_cts_in)
    msg = ("Custom VJP bwd rule must produce an output with the same container "
           "(pytree) structure as the args tuple of the primal function, "
           "and in particular must produce a tuple of length equal to the "
           "number of arguments to the primal function, but got bwd output "
           "structure {} for primal input structure {}.")
    raise TypeError(msg.format(in_tree2, in_tree)) from None
  results = []
  for kp, a, ct in zip(keypaths, in_avals, cts_in_flat):
    if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0:
      results.append(Zero(a.to_tangent_aval()))
    elif type(ct) is SymbolicZero:
      if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval):
        msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype "
               "that does not match the corresponding input tangent shape/dtype: "
               f"at output{keystr(kp)} the SymbolicZero had shape/dtype "
               f"{a_.str_short()} while the "
               f"corresponding input had shape/dtype {a.str_short()}. "
               "Consider just returning a None here instead of a SymbolicZero "
               "object.")
        raise ValueError(msg)
      results.append(Zero(ct.aval))
    else:
      if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct))
          and not (_temporary_dtype_exception(a, a_) or
                   _temporary_shape_exception(a, a_))):
        msg = ("Custom VJP bwd rule must produce an output with the same "
               "shape/dtypes as the args tuple of the primal function, but at "
               f"output{keystr(kp)} the bwd rule produced an output of "
               f"shape/dtype {raise_to_shaped(a_).str_short()} corresponding "
               f"to an input of shape/dtype {a.str_short()}.")
        raise ValueError(msg)
      results.append(ct)
  yield results

# TODO(mattjj): remove both these exceptions to cotangent compatibility check
def _temporary_dtype_exception(a, a_) -> bool:
  if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray):
    return (a.shape == a_.shape and
            (dtypes.issubdtype(a_.dtype, dtypes.extended) or
             dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
  return False

# TODO(mattjj): remove both these exceptions to cotangent compatibility check
def _temporary_shape_exception(a, a_) -> bool:
  return config.custom_vjp_disable_shape_check.value

class CustomVJPCallPrimitive(core.CallPrimitive):
  initial_style: core.Primitive

  def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    fun, env_trace_todo1 = process_env_traces(
        fun, self, top_trace and top_trace.level, False)
    fwd, env_trace_todo2 = process_env_traces_fwd(
        fwd, top_trace and top_trace.level, out_trees)
    tracers = map(top_trace.full_raise, args)
    bwd_ = lambda *args: bwd(*args)
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
                                             out_trees=out_trees,
                                             symbolic_zeros=symbolic_zeros)
    fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    if fst:
      return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
    else:
      env_trace_todo, bwd_transform = env_trace_todo
      bwd = _apply_bwd_transform(bwd_transform, bwd)
      return core.apply_todos(env_trace_todo, map(core.full_lower, outs))

  def impl(self, fun, fwd, bwd, *args, out_trees):
    del fwd, bwd, out_trees
    with core.new_sublevel():
      return fun.call_wrapped(*args)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_custom_vjp_call(out_tracers, params)
custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')

@partial(lu.transformation_with_aux, use_eq_store=True)
def process_env_traces_fwd(level: int, out_trees, *args):
  outs = yield args, {}
  todo = []
  bwd_transforms = []
  while True:
    tracers = [x for x in outs if isinstance(x, core.Tracer)
               and (level is None or x._trace.level > level)]
    if tracers:
      ans = max(tracers, key=lambda x: x._trace.level)
    else:
      break
    trace = ans._trace.main.with_cur_sublevel()
    outs = map(trace.full_raise, outs)
    outs, cur_todo, bwd_xform = trace.post_process_custom_vjp_call_fwd(outs, out_trees)
    todo.append(cur_todo)
    bwd_transforms.append(bwd_xform)
  yield outs, (tuple(todo), tuple(bwd_transforms))


def _apply_bwd_transform(todos, bwd):
  todos_list = list(todos)
  while todos_list:
    bwd = todos_list.pop()(bwd)
  return bwd

def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
  return core.jaxpr_as_fun(fun_jaxpr)(*args)

def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
  disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fun_jaxpr.effects)
  if disallowed_effects:
    raise NotImplementedError(
        f'Effects not supported in `custom_vjp`: {disallowed_effects}')
  return fun_jaxpr.out_avals, fun_jaxpr.effects

custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr')
custom_vjp_call_jaxpr_p.multiple_results = True
custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p

mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
    _custom_vjp_call_jaxpr_impl, multiple_results=True))

def _custom_vjp_call_jaxpr_jvp(
    primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
  _, args = split_list(primals, [num_consts])
  consts_dot, args_dot = split_list(tangents, [num_consts])
  if any(type(t) is not Zero for t in consts_dot):
    raise ad.CustomVJPException()
  zeros = [type(t) is not Zero for t in args_dot]
  fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)  # consts can be tracers!
  _, res_tree = out_trees()
  res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
  avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
  args_dot = map(ad.instantiate_zeros, args_dot)
  tangents_out = ad.custom_lin_p.bind(
      *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
      out_avals=avals_out, symbolic_zeros=symbolic_zeros)
  tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
  return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

def _custom_vjp_call_jaxpr_vmap(
    spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
    fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]

  in_batched = [d is not not_mapped for d in in_dims]
  _, args_batched = split_list(in_batched, [num_consts])
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
      fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name,
      main_type)
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []

  @pe._memoize
  def batched_fwd_jaxpr_thunk(*zeros):
    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros))  # consts can be tracers
    batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
        fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
        main_type)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

  fwd_args_batched = [0 if b else not_mapped for b in args_batched]
  fwd_out_dims = lambda: out_dims2[0]
  batched_bwd = batching.batch_custom_vjp_bwd(
      bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
      spmd_axis_name)

  batched_outs = custom_vjp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
      num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
    _custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
    _custom_vjp_call_jaxpr_vmap, None)

xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)

batching.primitive_batchers[ad.custom_lin_p] = ad.raise_custom_vjp_error_on_jvp
mlir.register_lowering(ad.custom_lin_p, ad.raise_custom_vjp_error_on_jvp)


def custom_gradient(fun):
  """Convenience function for defining custom VJP rules (aka custom gradients).

  While the canonical way to define custom VJP rules is via ``jax.custom_vjp``,
  the ``custom_gradient`` convenience wrapper follows TensorFlow's
  ``tf.custom_gradient`` API. The difference here is that ``custom_gradient``
  can be used as a decorator on one function that returns both the primal value
  (representing the output of the mathematical function to be differentiated)
  and the VJP (gradient) function. See
  https://www.tensorflow.org/api_docs/python/tf/custom_gradient.

  If the mathematical function to be differentiated has Haskell-like signature
  ``a -> b``, then the Python callable ``fun`` should have the signature
  ``a -> (b, CT b --o CT a)`` where we use ``CT x`` to denote a cotangent type
  for ``x`` and the ``--o`` arrow to denote a linear function. See the example
  below. That is, ``fun`` should return a pair where the first element
  represents the value of the mathematical function to be differentiated and the
  second element is a function to be called on the backward pass of reverse-mode
  automatic differentiation (i.e. the "custom gradient" function).

  The function returned as the second element of the output of ``fun`` can close
  over intermediate values computed when evaluating the function to be
  differentiated. That is, use lexical closure to share work between the forward
  pass and the backward pass of reverse-mode automatic differentiation. However,
  it cannot perform Python control flow which depends on the values of the
  closed-over intermediate values or its cotangent arguments; if the function
  includes such control flow, an error is raised.

  Args:
    fun: a Python callable specifying both the mathematical function to be
      differentiated and its reverse-mode differentiation rule. It should return
      a pair consisting of an output value and a Python callable that represents
      the custom gradient function.

  Returns:
    A Python callable that accepts the same arguments as ``fun`` and returns the
    output value specified by the first element of ``fun``'s output pair.

  For example:

  >>> @jax.custom_gradient
  ... def f(x):
  ...   return x ** 2, lambda g: (g * x,)
  ...
  >>> print(f(3.))
  9.0
  >>> print(jax.grad(f)(3.))
  3.0

  An example with a function on two arguments, so that the VJP function must
  return a tuple of length two:

  >>> @jax.custom_gradient
  ... def f(x, y):
  ...   return x * y, lambda g: (g * y, g * x)
  ...
  >>> print(f(3., 4.))
  12.0
  >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
  (Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
  """
  @custom_vjp
  def wrapped_fun(*args, **kwargs):
    ans, _ = fun(*args, **kwargs)
    return ans

  def fwd(*args, **kwargs):
    ans, rule = fun(*args, **kwargs)
    ans_flat, out_tree = tree_flatten((ans,))
    rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
    ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat]
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
    return ans, Residuals(jaxpr, in_tree(), out_tree, consts)

  def bwd(res, cts):
    jaxpr, in_tree, out_tree, consts = res
    cts_flat, out_tree_ = tree_flatten((cts,))
    if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
    cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
    cts_out = tree_unflatten(in_tree, cts_out)
    if treedef_is_leaf(in_tree):
      cts_out = (cts_out,)
    return cts_out

  wrapped_fun.defvjp(fwd, bwd)
  return wrapped_fun

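# Residuals packages the staged-out bwd jaxpr and pytree metadata produced by
# custom_gradient's fwd rule; registering it as a pytree node class lets it
# flow through custom_vjp's residual machinery like any other container.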
@register_pytree_node_class
class Residuals:
  def __init__(self, jaxpr, in_tree, out_tree, consts):
    self.jaxpr = jaxpr
    self.in_tree = in_tree
    self.out_tree = out_tree
    self.consts = consts
  def __iter__(self):
    return iter((self.jaxpr, self.in_tree, self.out_tree, self.consts))
  def tree_flatten(self):
    return self.consts, (self.jaxpr, self.in_tree, self.out_tree)
  @classmethod
  def tree_unflatten(cls, aux, consts):
    jaxpr, in_tree, out_tree = aux
    return cls(jaxpr, in_tree, out_tree, consts)


def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
  """Closure conversion utility, for use with higher-order custom derivatives.

  To define custom derivatives such as with ``jax.custom_vjp(f)``, the target
  function ``f`` must take, as formal arguments, all values involved in
  differentiation. If ``f`` is a higher-order function, in that it accepts as an
  argument a Python function ``g``, then values stored away in ``g``'s closure
  will not be visible to the custom derivative rules, and attempts at AD
  involving these values will fail. One way around this is to convert the
  closure by extracting these values, and to pass them as explicit formal
  arguments across the custom derivative boundary. This utility carries out that
  conversion. More precisely, it closure-converts the function ``fun``
  specialized to the types of the arguments given in ``example_args``.

  When we refer here to "values in the closure" of ``fun``, we do not mean the
  values that are captured by Python directly when ``fun`` is defined (e.g. the
  Python objects in ``fun.__closure__``, if the attribute exists). Rather, we
  mean values encountered during the execution of ``fun`` on ``example_args``
  that determine its output. This may include, for instance, arrays captured
  transitively in Python closures, i.e. in the Python closure of functions
  called by ``fun``, the closures of the functions that they call, and so forth.

  The function ``fun`` must be a pure function.

  Example usage::

    def minimize(objective_fn, x0):
      converted_fn, aux_args = closure_convert(objective_fn, x0)
      return _minimize(converted_fn, x0, *aux_args)

    @partial(custom_vjp, nondiff_argnums=(0,))
    def _minimize(objective_fn, x0, *args):
      z = objective_fn(x0, *args)
      # ... find minimizer x_opt ...
      return x_opt

    def fwd(objective_fn, x0, *args):
      y = _minimize(objective_fn, x0, *args)
      return y, (y, args)

    def rev(objective_fn, res, g):
      y, args = res
      y_bar = g
      # ... custom reverse-mode AD ...
      return x0_bar, *args_bars

    _minimize.defvjp(fwd, rev)

  Args:
    fun: Python callable to be converted. Must be a pure function.
    example_args: Arrays, scalars, or (nested) standard Python
      containers (tuples, lists, dicts, namedtuples, i.e., pytrees)
      thereof, used to determine the types of the formal arguments to
      ``fun``. This type-specialized form of ``fun`` is the function
      that will be closure converted.

  Returns:
    A pair comprising (i) a Python callable, accepting the same
    arguments as ``fun`` followed by arguments corresponding to the
    values hoisted from its closure, and (ii) a list of values hoisted
    from the closure.
  """
  flat_args, in_tree = tree_flatten(example_args)
  in_avals = tuple(map(abstractify, flat_args))
  if config.check_tracer_leaks.value:
    return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
  else:
    return _closure_convert_for_avals(fun, in_tree, in_avals)

def _maybe_perturbed(x: Any) -> bool:
  # False if x can't represent an AD-perturbed value (i.e. a value
  # with a nontrivial tangent attached), up to heuristics, and True otherwise.
  # See https://github.com/jax-ml/jax/issues/6415 for motivation.
  x = core.full_lower(x)
  if not isinstance(x, core.Tracer):
    # If x is not a Tracer, it can't be perturbed.
    return False
  elif isinstance(x, pe.DynamicJaxprTracer):
    # If x is a DynamicJaxprTracer then we're staging out; differentiation could
    # happen later, but some types always have trivial tangents.
    vspace = x.aval.to_tangent_aval()
    return not (vspace is core.abstract_token or
                getattr(vspace, 'dtype', None) == dtypes.float0)
  elif not isinstance(x, ad.JVPTracer):
    # If x is not a JVPTracer, recursively check its contents.
    return any(_maybe_perturbed(attr) for name, attr in x._contents())
  else:
    return True  # We can't be sure!

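# _closure_convert_for_avals traces fun to a jaxpr at the given avals and then
# splits the resulting constants: those that might be AD-perturbed (per
# _maybe_perturbed) are hoisted into explicit arguments of converted_fun, while
# the rest remain baked in as closure constants.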
@cache()
def _closure_convert_for_avals(fun, in_tree, in_avals):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  out_tree = out_tree()

  (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(*args_hconsts):
    num_args = len(args_hconsts) - num_consts
    args, hoisted_consts = split_list(args_hconsts, [num_args])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten(tuple(args))
    if in_tree != in_tree2:
      msg = ("The inputs to the closure produced by closure_convert must have "
             "the same Pytree structure as the example arguments passed when "
             f"closure_convert was called. Expected {in_tree}, but got "
             f"{in_tree2}")
      raise TypeError(msg)
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts

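# partition_list splits lst according to a boolean-valued choice function and
# returns the two parts along with a merge function that restores the original
# interleaving. An illustrative sketch (not part of the library API):
#
#   (falses, trues), merge = partition_list(lambda n: n % 2 == 1, [1, 2, 3, 4])
#   # falses == [2, 4], trues == [1, 3], merge(falses, trues) == [1, 2, 3, 4]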
def partition_list(choice, lst):
  out = [], []
  which = [out[choice(elt)].append(elt) or choice(elt) for elt in lst]
  def merge(l1, l2):
    i1, i2 = iter(l1), iter(l2)
    return [next(i2 if snd else i1) for snd in which]
  return out, merge

def abstractify(x):
  return core.raise_to_shaped(core.get_aval(x))


### Custom transposition

def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
  """Call a linear function, with a custom implementation for its transpose.

  The `Haskell-like type signatures`_ of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the
  residual input type and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as
  transposes of one another. Specifically, the transpose of a
  ``linear_call`` primitive is another ``linear_call`` to
  ``fun_transpose``, with ``fun`` as its custom transposition.

  For example:

  >>> def f(r, x):
  ...   return x / r

  >>> def t(r, t):
  ...   return t / r

  >>> def div_add(x, denom):
  ...   return x + linear_call(f, t, denom, x)

  >>> def transpose(f, x_example):
  ...   def transposed(y):
  ...     x, = jax.linear_transpose(f, x_example)(y)
  ...     return x
  ...   return transposed

  >>> div_add(9., 3.)
  Array(12., dtype=float32, weak_type=True)

  >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
  Array(24., dtype=float32, weak_type=True)

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  Array(24., dtype=float32, weak_type=True)

  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
  ``x``) but not the other (the divisor ``r``).

  As another example:

  >>> def custom_id(x):
  ...   def f(_, x): return x
  ...   def t(_, t): return 7.
  ...   return linear_call(f, t, (), x)
  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should
      take two arguments: one of "residual" inputs (type ``r``),
      i.e. inputs in which the function is not necessarily linear, and
      one of "linear" inputs (type ``a``). It should return output
      whose components are linear in the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear
      function that is the transpose of ``fun`` with respect to its
      linear inputs. Its first argument is the same residual inputs
      (``r``) as ``fun``. Its second argument is of type
      ``b``. Finally, its output is of type ``a`` and each of its
      components is linear in its second argument (the ``b`` inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are
      not necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are
      linear and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.

  .. _Haskell-like type signatures: https://wiki.haskell.org/Type_signature
  """
  operands_res, res_tree = tree_flatten(residual_args)
  operands_lin, lin_tree = tree_flatten(linear_args)

  f_in_tree = treedef_tuple((res_tree, lin_tree))
  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

  res_avals = map(abstractify, operands_res)
  lin_avals = map(abstractify, operands_lin)
  f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
  f_jaxpr = _close_jaxpr(f_jaxpr)
  out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

  t_in_tree = treedef_tuple((res_tree, out_tree()))
  t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)

  t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
  t_jaxpr = _close_jaxpr(t_jaxpr)

  if t_out_tree() != lin_tree:
    raise TypeError(
        'transpose output pytree structure must match that of linear inputs, '
        f'got output structure {t_out_tree()} '
        f'and input structure {lin_tree}.')

  out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin,
                           callee=f_jaxpr,
                           transpose=t_jaxpr,
                           num_callee_consts=len(f_consts),
                           num_transpose_consts=len(t_consts),
                           num_res=len(operands_res))

  return tree_unflatten(out_tree(), out)

def _linear_call_impl(*args, callee, transpose, num_callee_consts,
                      num_transpose_consts, num_res):
  del transpose
  consts, _, operands_res, operands_lin = split_list(
      args, [num_callee_consts, num_transpose_consts, num_res])
  return core.eval_jaxpr(callee.jaxpr, (), *consts, *operands_res, *operands_lin)

def _linear_call_transpose_rule(cts, *args, callee, transpose,
                                num_callee_consts,
                                num_transpose_consts, num_res):
  f_consts, t_consts, operands_res, operands_lin = split_list(
      args, [num_callee_consts, num_transpose_consts, num_res])
  _, _, cts_avals = split_list(
      transpose.in_avals, [num_transpose_consts, num_res])

  assert all(ad.is_undefined_primal(x) for x in operands_lin)
  assert all(not ad.is_undefined_primal(x) for x in operands_res)

  cts = [zeros_like_aval(a) if type(ct) is Zero else ct
         for ct, a in zip(cts, cts_avals)]

  cts_out = linear_call_p.bind(*t_consts, *f_consts, *operands_res, *cts,
                               callee=transpose,
                               transpose=callee,
                               num_callee_consts=len(t_consts),
                               num_transpose_consts=len(f_consts),
                               num_res=len(operands_res))

  return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out

def _linear_call_abstract_eval(*args, **kwargs):
  return map(core.raise_to_shaped, kwargs['callee'].out_avals)

linear_call_p = core.Primitive('linear_call')
linear_call_p.multiple_results = True
linear_call_p.def_impl(_linear_call_impl)
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
xla.register_initial_style_primitive(linear_call_p)
mlir.register_lowering(linear_call_p, mlir.lower_fun(
    _linear_call_impl, multiple_results=True))


# A stageable primitive that fails when evaluated
unreachable_p: core.Primitive = core.Primitive('unreachable')
unreachable_p.multiple_results = True

def unreachable_impl(*_, out_avals, exc_type, message):
  del out_avals
  raise exc_type(message)

# Evaluation raises an exception
unreachable_p.def_impl(unreachable_impl)

# Translation raises an exception
# TODO(frostig,mattjj): We have no good way to translate a function
# that errs. Since MLIR lowering over-approximates concrete evaluation,
# we err on MLIR lowering for the time being.
mlir.register_lowering(unreachable_p, unreachable_impl)

# Abstract evaluation proceeds without issue, to allow for staging
unreachable_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)

def unreachable(*args, out_avals=None, exc_type=TypeError,
                message='unreachable'):
  """Fail when evaluated concretely (but allow for staging).

  This function allows one to assert an impossibility of
  evaluation. It can be used to guarantee that evaluation does not
  "reach" a certain point in the sense that it does not execute, but
  it can nonetheless be staged out by JAX without error.

  Args:
    *args: The arbitrary pytree of arguments to the function.
    out_avals: Optional specification of the output types of this
      function invocation from the point of view of staging. If
      ``None``, these are chosen as equal to types of input arguments.
    exc_type: Optional constructor for the Python exception raised if
      evaluated.
    message: Optional string message for the Python exception raised
      if evaluated.

  """
  if out_avals is None:
    out_avals = tree_map(core.get_aval, args)

  args_flat, in_tree = tree_flatten(args)
  out_avals_flat, out_tree = tree_flatten(out_avals)
  out = unreachable_p.bind(*args_flat, out_avals=out_avals_flat,
                           exc_type=exc_type, message=message)
  return tree_unflatten(out_tree, out)
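
# A minimal sketch of the intended behavior (hypothetical usage, not part of
# the public API): `unreachable` can be staged out, e.g. under
# `jax.make_jaxpr`, because its abstract evaluation simply forwards
# `out_avals`, while evaluating it concretely raises `exc_type(message)`:
#
#   import jax
#   jax.make_jaxpr(lambda x: unreachable(x, message='nope'))(1.)  # stages fine
#   unreachable(1., message='nope')                               # raises TypeError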

disallow_jvp = partial(
    unreachable,
    exc_type=TypeError,
    message="can't apply forward-mode autodiff (jvp) to a custom_vjp function.")


def custom_vjp_by_custom_transpose(fun, fwd, bwd):
  fun = custom_jvp(fun)

  @fun.defjvp
  def jvp(primals, tangents):
    outs, residuals = fwd(*primals)
    tan_out_types = tree_map(lambda o: core.get_aval(o).to_tangent_aval(), outs)
    tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types))
    tan_fn.def_transpose(bwd)
    return outs, tan_fn(tan_out_types, residuals, tangents)

  return fun
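
# A rough usage sketch (hypothetical `f`/`f_fwd`/`f_bwd`, following the usual
# custom_vjp calling convention, where `fwd` returns ``(out, residuals)`` and
# `bwd` maps ``(residuals, out_cotangent)`` to a tuple of input cotangents):
#
#   import jax
#   import jax.numpy as jnp
#   def f(x): return jnp.sin(x)
#   def f_fwd(x): return jnp.sin(x), jnp.cos(x)
#   def f_bwd(cos_x, ct): return (cos_x * ct,)
#   f_ct = custom_vjp_by_custom_transpose(f, f_fwd, f_bwd)
#
# Reverse-mode differentiation of `f_ct` is then meant to reach `f_bwd` by
# transposing the `custom_transpose`-wrapped tangent map, while evaluating the
# tangent map directly (plain jvp) hits the `disallow_jvp` error above.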


# TODO(mattjj): remove these stubs, which exist to avoid breaking internal users
custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr")


# The following is a helper for optimizing the behavior of custom_vjp when used
# under remat. This is really only useful when the `fwd` function to custom_vjp
# executes a black box kernel. Otherwise, DCE will perform this optimization
# automatically.
#
# TODO(dfm): Eventually this should probably be the default behavior for
# custom_vjp, if we can make it so that it is a no-op for most cases. Right now,
# it is written in "initial-style" so it doesn't support eager mode. This was
# a reasonable compromise when written because it made the implementation
# simpler, but it would be worth revisiting this.
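#
# A rough usage sketch (hypothetical `fn`/`fn_fwd`/`fn_bwd`; exactly how this
# helper is wired into the public custom_vjp API is an assumption here): the
# wrapped forward keeps the same ``(out, residuals)`` contract as `fwd`, so it
# can stand in for `fwd` when defining the VJP:
#
#   fn_fwd_opt = optimize_remat_of_custom_vjp_fwd(fn, fn_fwd)
#   fn = jax.custom_vjp(fn)
#   fn.defvjp(fn_fwd_opt, fn_bwd)
#
# Under `jax.checkpoint`/`jax.remat`, the custom DCE rule below can then swap
# the forward pass for the plain primal computation whenever none of its
# residuals turn out to be needed.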
def optimize_remat_of_custom_vjp_fwd(
    fun: Callable[..., ReturnValue],
    fwd: Callable[..., tuple[ReturnValue, Any]],
    nondiff_argnums: Sequence[int] = (),
    symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:
  if symbolic_zeros:
    # TODO(dfm): This probably shouldn't be too hard to support.
    raise NotImplementedError(
        "remat optimization for custom_vjp does not support symbolic zeros")

  @wraps(fwd)
  def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
    # TODO(dfm): This initial logic is duplicated from custom_vjp.__call__
    # above and it would be good to consolidate it.
    primal_name = getattr(fun, "__name__", str(fun))
    fwd_name = getattr(fwd, "__name__", str(fwd))
    # Note: we use `fun` instead of `fwd` here for consistency with
    # custom_vjp.__call__ above.
    args = resolve_kwargs(fun, args, kwargs)
    if nondiff_argnums:
      for i in nondiff_argnums: _check_for_tracers(args[i])
      nondiff_argnums_ = set(nondiff_argnums)
      dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums_]
      f_, dyn_args = argnums_partial(lu.wrap_init(fun), dyn_argnums,
                                     args, require_static_args_hashable=False)
      fwd_, _ = argnums_partial(lu.wrap_init(fwd), dyn_argnums, args,
                                require_static_args_hashable=False)
    else:
      f_, dyn_args = lu.wrap_init(fun), args
      fwd_ = lu.wrap_init(fwd)
    args_flat, in_tree = tree_flatten(dyn_args)
    flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
    flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
                                       in_tree, out_type)
    flat_fwd = _fix_fwd_args(flat_fwd)

    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals)
    fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr))
    prim_tree, res_tree = out_trees()
    num_res = res_tree.num_leaves

    if fwd_jaxpr.effects:
      raise NotImplementedError(
          "remat optimization for custom_vjp does not support forward "
          f"functions with side effects, but {fwd_name} has the following "
          f"effects: {fwd_jaxpr.effects}")

    @pe._memoize
    def fun_jaxpr_thunk():
      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
      return jaxpr, consts

    out_flat = remat_opt_p.bind(*consts, *args_flat,
                                num_consts=len(consts),
                                num_res=num_res,
                                fwd_jaxpr=fwd_jaxpr,
                                fun_jaxpr_thunk=fun_jaxpr_thunk)
    res, out_flat = split_list(out_flat, [num_res])
    out_tree = treedef_tuple((prim_tree, res_tree))
    return tree_unflatten(out_tree, (*out_flat, *res))

  return wrapped_fwd

@lu.transformation
def _fix_fwd_args(*args):
  args = [(x, True) for x in args]
  args = [x for pair in args for x in pair]
  yield (yield args, {})

def _remat_opt_impl(
    *args,
    num_consts: int,
    num_res: int,
    fwd_jaxpr: core.ClosedJaxpr,
    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
  del num_consts, num_res, fun_jaxpr_thunk  # unused
  return core.jaxpr_as_fun(fwd_jaxpr)(*args)

def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
  del args
  return fwd_jaxpr.out_avals, fwd_jaxpr.effects

def _remat_opt_vmap(
    spmd_axis_name, axis_size, axis_name, main_type, args, in_dims,
    *,
    num_consts: int,
    num_res: int,
    fwd_jaxpr: core.ClosedJaxpr,
    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]

  in_batched = [d is not not_mapped for d in in_dims]
  batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
      fwd_jaxpr, axis_size, in_batched, False,
      axis_name, spmd_axis_name, main_type)
  extra_consts = batched_fwd_jaxpr.consts
  batched_fwd_jaxpr = pe.close_jaxpr(
      pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr))
  out_dims = [0 if b else not_mapped for b in out_batched]

  _, prim_batched = split_list(in_batched, [num_consts])

  @pe._memoize
  def batched_fun_jaxpr_thunk():
    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
    batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
        fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name,
        main_type)
    return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts

  batched_outs = remat_opt_p.bind(*extra_consts, *args,
                                  num_consts=num_consts + len(extra_consts),
                                  num_res=num_res,
                                  fwd_jaxpr=batched_fwd_jaxpr,
                                  fun_jaxpr_thunk=batched_fun_jaxpr_thunk)

  return batched_outs, out_dims

def _remat_opt_jvp(
    primals,
    tangents,
    *,
    num_consts: int,
    num_res: int,
    fwd_jaxpr: core.ClosedJaxpr,
    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
  consts, primals = split_list(primals, [num_consts])
  consts_dot, tangents = split_list(tangents, [num_consts])
  # Tangents must be instantiated in case we end up DCEing later.
  tangents = map(ad.instantiate_zeros, tangents)
  consts_nz = [not isinstance(t, Zero) for t in consts_dot]
  consts_dot = [c for nz, c in zip(consts_nz, consts_dot) if nz]
  in_nz = consts_nz + [True] * len(tangents)
  fwd_jaxpr_jvp_, out_nz = ad.jvp_jaxpr(fwd_jaxpr, in_nz, True)
  num_out = len(out_nz) - num_res
  fwd_jaxpr_jvp_ = ad.rearrange_binders(
      fwd_jaxpr_jvp_, [num_consts, len(primals)],
      [len(consts_dot), len(tangents)], [num_res, num_out], [num_res, num_out])
  fwd_jaxpr_jvp = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr_jvp_.jaxpr))

  @pe._memoize
  def fun_jvp_jaxpr_thunk():
    fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
    in_nz = [True] * len(primals)
    fun_jvp_jaxpr, _ = ad.jvp_jaxpr(fun_jaxpr, in_nz, True)
    return fun_jvp_jaxpr.jaxpr, fun_jvp_jaxpr.consts

  new_num_consts = len(fwd_jaxpr_jvp_.consts) + num_consts + len(consts_dot)
  outs = remat_opt_p.bind(*fwd_jaxpr_jvp_.consts, *consts, *consts_dot,
                          *primals, *tangents, num_consts=new_num_consts,
                          num_res=2 * num_res, fwd_jaxpr=fwd_jaxpr_jvp,
                          fun_jaxpr_thunk=fun_jvp_jaxpr_thunk)
  res, res_dot, outs, outs_dot = split_list(outs, [num_res, num_res, num_out])
  return (*res, *outs), (*res_dot, *outs_dot)

def _remat_opt_transpose(
    cts, *args,
    num_consts: int,
    num_res: int,
    fwd_jaxpr: core.ClosedJaxpr,
    fun_jaxpr_thunk: Callable[[], core.ClosedJaxpr],
):
  # TODO(dfm): It shouldn't be too hard to implement this as needed in the
  # future.
  raise NotImplementedError(
      "remat optimization for custom_vjp does not support higher-order AD")

def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
  used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
  outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
  if any(used_res):
    # If any of the residuals are used, we still need to run fwd at this point,
    # but we may end up DCEing again in the future, so we must instantiate all
    # the input primals.
    instantiate = [False] * eqn.params["num_consts"]
    instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"])
    new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs,
                                       instantiate=instantiate)
    assert not new_jaxpr.constvars
    closed_jaxpr = pe.close_jaxpr(new_jaxpr)
    invars = [v for used, v in zip(used_ins, eqn.invars) if used]
    new_params = dict(eqn.params)
    new_num_consts = sum(split_list(used_ins, [eqn.params["num_consts"]])[0])
    new_params["num_consts"] = new_num_consts
    new_params["fwd_jaxpr"] = closed_jaxpr
    new_params["num_res"] = sum(used_res)
    new_eqn = pe.new_jaxpr_eqn(
        invars, outvars, remat_opt_p, new_params, closed_jaxpr.effects,
        eqn.source_info, eqn.ctx)
    return used_ins, new_eqn
  else:
    # If none of the residuals are used, we run the primal computation instead.
    # At this point we drop this custom DCE behavior, but since the primal might
    # have different consts than fwd, we build a new JaxprEqn with a closed_call
    # primitive.
    fun_jaxpr, consts = eqn.params["fun_jaxpr_thunk"]()
    new_jaxpr, used_consts, used_ins = pe.dce_jaxpr_consts(fun_jaxpr, used_prims)
    consts = [c for used, c in zip(used_consts, consts) if used]
    closed_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
    _, invars = split_list(eqn.invars, [eqn.params["num_consts"]])
    invars = [v for used, v in zip(used_ins, invars) if used]
    new_eqn = pe.new_jaxpr_eqn(
        invars, outvars, core.closed_call_p, dict(call_jaxpr=closed_jaxpr),
        closed_jaxpr.effects, eqn.source_info, eqn.ctx)
    used_ins = [False] * eqn.params["num_consts"] + used_ins
    return used_ins, new_eqn

remat_opt_p = core.Primitive("remat_opt")
remat_opt_p.multiple_results = True
remat_opt_p.def_impl(_remat_opt_impl)
remat_opt_p.def_effectful_abstract_eval(_remat_opt_abstract_eval)
xla.register_initial_style_primitive(remat_opt_p)
mlir.register_lowering(remat_opt_p, mlir.lower_fun(
    _remat_opt_impl, multiple_results=True))
batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None)
ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
pe.dce_rules[remat_opt_p] = _remat_opt_dce