custom_vjp symbolic zeros support, take two

This change re-introduces symbolic zero support for `custom_vjp`.

This time:

* The forward rule API is slightly different, accepting two-field
  records at pytree leaves rather than pairs.

* In the default setting where symbolic_zeros is not set, there are no
  new requirements from pytree node definitions that are involved in
  the primal arguments. This avoids any change in behavior on the
  default path. In particular, custom pytree node definitions that
  aren't completely polymorphic in unflattening can remain as is.

* There is an additional test involving a custom pytree node.
This commit is contained in:
Roy Frostig 2023-03-24 14:42:19 -07:00
parent 0e549ac4be
commit d51b8e6839
10 changed files with 443 additions and 86 deletions

View File

@ -932,7 +932,8 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
fwd_jaxpr_thunk, num_consts, bwd, out_trees):
fwd_jaxpr_thunk, num_consts, bwd, out_trees,
symbolic_zeros):
err_vals, err_tree = jtu.tree_flatten(in_err)
fun = lu.wrap_init(
functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
@ -940,15 +941,17 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun)
@lu.wrap_init
def fwd(*xs):
def fwd(*args):
# TODO(lenamartens, sharadmv): why not checkify here?
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()
xs, zeros = args[::2], args[1::2]
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
xs_without_consts = xs[num_consts:]
return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)
fwd, fwd_out_tree = flatten_fun_output(fwd)
all_outs = custom_derivatives.custom_vjp_call_p.bind(
fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees)
fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
if fst:
err_and_out_tree, _ = out_metadata

View File

@ -512,7 +512,8 @@ class Trace(Generic[TracerType]):
"to handle custom_transpose_call primitives")
raise NotImplementedError(msg)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees, symbolic_zeros):
msg = (f"{type(self)} must override process_custom_vjp_call "
"to handle custom_vjp primitives")
raise NotImplementedError(msg)

View File

@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from functools import update_wrapper, reduce, partial
import inspect
from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any)
from typing import (
Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar)
from jax._src import core
from jax._src import custom_api_util
@ -23,8 +25,8 @@ from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import traceback_util
from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval,
stop_gradient_p)
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.config import config
from jax._src.core import raise_to_shaped
@ -162,12 +164,13 @@ class custom_jvp(Generic[ReturnValue]):
and the second element is the tangent output. Elements of the input and
output tuples may be arrays or any nested tuples/lists/dicts thereof.
symbolic_zeros: boolean, indicating whether the rule should be passed
objects representing static symbolic zeros in its tangent tuple
argument; otherwise, only standard JAX types (e.g. array-likes) are
passed. Setting this option to True allows a JVP rule to detect whether
certain inputs are not involved in differentiation, but at the cost of
needing special handling for these objects (which e.g. can't be passed
into jax.numpy functions). Default False.
objects representing static symbolic zeros in its tangent argument in
correspondence with unperturbed values; otherwise, only standard JAX
types (e.g. array-likes) are passed. Setting this option to ``True``
allows a JVP rule to detect whether certain inputs are not involved in
differentiation, but at the cost of needing special handling for these
objects (which e.g. can't be passed into jax.numpy functions). Default
``False``.
Returns:
None.
@ -479,12 +482,15 @@ class custom_vjp(Generic[ReturnValue]):
self.nondiff_argnums = nondiff_argnums
self.fwd: Optional[Callable[..., Tuple[ReturnValue, Any]]] = None
self.bwd: Optional[Callable[..., Tuple[Any, ...]]] = None
self.symbolic_zeros = False
__getattr__ = custom_api_util.forward_attr
def defvjp(self,
fwd: Callable[..., Tuple[ReturnValue, Any]],
bwd: Callable[..., Tuple[Any, ...]]) -> None:
bwd: Callable[..., Tuple[Any, ...]],
symbolic_zeros: bool = False,
) -> None:
"""Define a custom VJP rule for the function represented by this instance.
Args:
@ -505,6 +511,38 @@ class custom_vjp(Generic[ReturnValue]):
function, and the tuple elements may be arrays or nested
tuples/lists/dicts thereof so as to match the structure of the primal
input arguments.
symbolic_zeros: boolean, determining whether to indicate symbolic zeros
to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom
derivative rules to detect when certain inputs, and when certain
output cotangents, are not involved in differentiation. If ``True``:
* ``fwd`` must accept, in place of each leaf value ``x`` in the pytree
comprising an argument to the original function, an object with two
attributes instead: ``value`` and ``perturbed``. The ``value`` field
is the original primal argument, and ``perturbed`` is a boolean.
The ``perturbed`` bit indicates whether the argument is involved in
differentiation (i.e., if it is ``False``, then the corresponding
Jacobian "column" is zero).
* ``bwd`` will be passed objects representing static symbolic zeros in
its cotangent argument in correspondence with unperturbed values;
otherwise, only standard JAX types (e.g. array-likes) are passed.
Setting this option to ``True`` allows these rules to detect whether
certain inputs and outputs are not involved in differentiation, but at
the cost of special handling. For instance:
* The signature of ``fwd`` changes, and the objects it is passed cannot
be output from the rule directly.
* The ``bwd`` rule is passed objects that are not entirely array-like,
and that cannot be passed to most ``jax.numpy`` functions.
* Any custom pytree nodes involved in the primal function's arguments
must accept, in their unflattening functions, the two-field record
objects that are given as input leaves to the ``fwd`` rule.
Default ``False``.
Returns:
None.
@ -526,6 +564,7 @@ class custom_vjp(Generic[ReturnValue]):
"""
self.fwd = fwd
self.bwd = bwd
self.symbolic_zeros = symbolic_zeros
@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
@ -533,7 +572,7 @@ class custom_vjp(Generic[ReturnValue]):
if not self.fwd or not self.bwd:
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = _resolve_kwargs(self.fun, args, kwargs)
if config.jax_enable_custom_vjp_by_custom_transpose:
if self.nondiff_argnums:
@ -557,14 +596,38 @@ class custom_vjp(Generic[ReturnValue]):
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
out_type)
flat_fwd, out_trees = _flatten_fwd(fwd, self.symbolic_zeros, primal_name,
fwd_name, in_tree, out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
*args_flat, out_trees=out_trees,
symbolic_zeros=self.symbolic_zeros)
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
return tree_unflatten(out_tree, out_flat)
@dataclasses.dataclass
class CustomVJPPrimal:
value: Any
perturbed: bool
def custom_vjp_primal_tree_values(tree):
"""Strips away perturbation information from forward rule arguments.
This is a helper function for user with the ``symbolic_zeros`` option to
the ``defvjp`` method of a ``custom_vjp``-decorated function.
In ``symbolic_zeros`` mode, the custom forward rule receives arguments
whose pytree leaves are records with a ``value`` attribute that carries
the primal argument. This is a way to convert such argument trees back to
their original form, replacing each such record with its carried value at
each leaf.
"""
def value(leaf):
if type(leaf) is not CustomVJPPrimal:
raise TypeError(f"unexpected leaf type {type(leaf)}")
return leaf.value
return tree_map(value, tree)
def _check_for_tracers(x):
for leaf in tree_leaves(x):
if isinstance(leaf, core.Tracer):
@ -578,7 +641,12 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)
@lu.transformation_with_aux
def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args):
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
*args):
if symbolic_zeros:
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
else:
args = args[::2]
py_args = tree_unflatten(in_tree, args)
pair_out = yield py_args, {}
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
@ -671,7 +739,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive
def bind(self, fun, fwd, bwd, *args, out_trees):
def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
fun, env_trace_todo1 = process_env_traces(
@ -681,7 +749,8 @@ class CustomVJPCallPrimitive(core.CallPrimitive):
tracers = map(top_trace.full_raise, args) # type: ignore
bwd_ = lambda *args: bwd(*args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
out_trees=out_trees)
out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if fst:
return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
@ -747,31 +816,34 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
def _custom_vjp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: Callable, out_trees: Callable, num_consts: int):
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
if any(type(t) is not Zero for t in consts_dot):
raise ad.CustomVJPException()
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers!
out_tree, res_tree = out_trees()
zeros = [type(t) is not Zero for t in args_dot]
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers!
_, res_tree = out_trees()
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
# currently handle float0s
args_dot = map(ad.replace_float0s, args, args_dot)
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out)
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: Callable, out_trees: Callable, num_consts: int):
def _custom_vjp_call_jaxpr_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]
@ -784,8 +856,8 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
out_dims2 = []
@pe._memoize
def batched_fwd_jaxpr_thunk():
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
def batched_fwd_jaxpr_thunk(*zeros):
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
main_type)
@ -794,17 +866,20 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
fwd_args_batched = [0 if b else not_mapped for b in args_batched]
fwd_out_dims = lambda: out_dims2[0]
batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
fwd_args_batched, main_type, spmd_axis_name)
batched_bwd = batching.batch_custom_vjp_bwd(
bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
spmd_axis_name)
batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr,
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
out_trees=out_trees, num_consts=num_consts)
num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
out_dims = out_dims2[0] if out_dims2 else out_dims1
return batched_outs, out_dims
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None)
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
_custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
_custom_vjp_call_jaxpr_vmap, None)
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)

View File

@ -27,9 +27,9 @@ from jax.tree_util import (tree_flatten, tree_unflatten,
from jax._src import core
from jax._src import source_info_util
from jax._src.ad_util import (
add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
zeros_like_p, Zero, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros)
add_jaxvals, add_jaxvals_p, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval,
zeros_like_jaxval, zeros_like_p)
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
@ -387,16 +387,21 @@ class JVPTrace(Trace):
def post_process_custom_jvp_call(self, out_tracers, _):
raise CustomJVPException()
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees):
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
symbolic_zeros):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
tangents_in = map(instantiate_zeros, tangents_in)
res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
out_tree, res_tree = out_trees()
fwd_in = [(core.full_lower(p), type(t) is not Zero)
for p, t in zip(primals_in, tangents_in)]
fwd_in = [x for pair in fwd_in for x in pair] # flatten
res_and_primals_out = fwd.call_wrapped(*fwd_in)
_, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
tangents_in = map(instantiate_zeros, tangents_in)
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out)
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)
@ -745,10 +750,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__):
"function.")
custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals,
symbolic_zeros):
res, _ = split_list(invals, [num_res])
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
if symbolic_zeros:
cts_out = map(replace_internal_symbolic_zeros, cts_out)
else:
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
cts_in = bwd(*res, *cts_out)
cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose

View File

@ -25,17 +25,18 @@ import jax
from jax.config import config
from jax._src import core
from jax._src import source_info_util
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src import linear_util as lu
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero, SymbolicZero,
replace_rule_output_symbolic_zeros, instantiate)
from jax._src import linear_util as lu
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)
from jax._src.interpreters import partial_eval as pe
Array = Any
map, unsafe_map = safe_map, map
@ -478,16 +479,19 @@ class BatchTrace(Trace):
return map(partial(BatchTracer, trace), vals, dims, srcs)
return vals, todo
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees): # pytype: disable=signature-mismatch
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees,
symbolic_zeros): # pytype: disable=signature-mismatch
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
if d is not not_mapped}
fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]]
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims)
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
out_dims2, in_dims, self.main.trace_type,
self.spmd_axis_name)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst:
_, res_tree = out_trees()
@ -784,8 +788,14 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
main_type, spmd_axis_name):
def new_bwd(*args):
in_dims_ = in_dims() if callable(in_dims) else in_dims
args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval))
if type(x) is SymbolicZero else x
for x, dim in zip(args, in_dims_)]
in_dims_ = [None if type(x) is SymbolicZero else d
for x, d in zip(args, in_dims_)]
bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd))
bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type,
bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type,
spmd_axis_name)
bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
out_dim_dests)
@ -802,7 +812,7 @@ def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
# Just like `matchaxis`, but handles symbolic zeros using ad_util.py
# TODO(mattjj): dedup with matchaxis
if isinstance(x, Zero):
if isinstance(x, (Zero, SymbolicZero)):
if src == dst:
return x
elif type(src) == type(dst) == int:

View File

@ -497,14 +497,16 @@ class JaxprTrace(Trace['JaxprTracer']):
for t in out_tracers: t.recipe = eqn
return out_tracers
def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees):
def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees,
symbolic_zeros):
# TODO(mattjj): after old remat is deleted, make this method trivial.
# Because we instantiate all tracers, in_knowns is all False.
tracers = map(self.instantiate_const_abstracted, tracers)
in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers])
f = trace_to_subjaxpr_nounits(f, self.main, True)
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees)
out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
out_knowns, out_avals, jaxpr, env = aux()
out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
res_tracers = map(self.new_instantiated_const, res)
@ -514,8 +516,9 @@ class JaxprTrace(Trace['JaxprTracer']):
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
@_memoize
def fwd_jaxpr_thunk():
fwd_ = trace_to_subjaxpr_nounits(fwd, self.main, True)
def fwd_jaxpr_thunk(*zeros):
fwd_ = _interleave_fun(fwd, zeros)
fwd_ = trace_to_subjaxpr_nounits(fwd_, self.main, True)
fwd_, aux = partial_eval_wrapper_nounits(
fwd_, tuple(in_knowns), tuple(in_avals))
with core.new_sublevel():
@ -532,7 +535,8 @@ class JaxprTrace(Trace['JaxprTracer']):
dict(fun_jaxpr=closed_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(res) + len(env),
bwd=bwd, out_trees=out_trees),
bwd=bwd, out_trees=out_trees,
symbolic_zeros=symbolic_zeros),
jaxpr.effects, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -1959,23 +1963,29 @@ class DynamicJaxprTrace(core.Trace):
def post_process_custom_jvp_call(self, out_tracers, _):
assert False # unreachable
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
in_avals = [t.aval for t in tracers]
with core.new_sublevel():
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
main_ = ref(self.main)
fwd_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2])
@_memoize
def fwd_jaxpr_from_zeros(*zeros):
fwd_ = _interleave_fun(fwd, zeros)
return trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals)[::2]
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
dict(fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
num_consts=len(consts),
bwd=bwd, out_trees=out_trees),
bwd=bwd, out_trees=out_trees,
symbolic_zeros=symbolic_zeros),
fun_jaxpr.effects,
source_info_util.current())
self.frame.add_eqn(eqn)
@ -2024,18 +2034,23 @@ class DynamicJaxprTrace(core.Trace):
custom_staging_rules: Dict[Primitive, Callable] = {}
def _memoize(thunk):
cell = []
@lu.transformation
def _interleave_fun(every_others, *args, **kwargs):
args_ = [x for pair in zip(args, every_others) for x in pair]
yield (yield (args_, kwargs))
def _memoize(fn):
cells = {}
saved_state = [core.thread_local_state.trace_state.copy()]
def memoized():
if not cell:
def memoized(*args):
if args not in cells:
prev_state = core.thread_local_state.trace_state
core.thread_local_state.trace_state = saved_state.pop()
core.thread_local_state.trace_state = saved_state[0]
try:
cell.append(thunk())
cells[args] = fn(*args)
finally:
core.thread_local_state.trace_state = prev_state
return cell[0]
return cells[args]
return memoized
# TODO(mattjj): remove this DebugInfo and helper functions, replace with

View File

@ -918,10 +918,11 @@ class MapTrace(core.Trace):
return self.process_primitive(fake_primitive, tracers, {})
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees):
out_trees, symbolic_zeros):
bind = HashableFunction(
lambda *args, **kwargs: primitive.bind(fun, fwd, bwd, *args,
out_trees=out_trees, **kwargs),
lambda *args, **kwargs: primitive.bind(
fun, fwd, bwd, *args, out_trees=out_trees,
symbolic_zeros=symbolic_zeros, **kwargs),
(primitive, fun, fwd, bwd))
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, {})

View File

@ -27,6 +27,8 @@ from jax._src.custom_derivatives import (
custom_vjp as custom_vjp,
custom_vjp_call_p as custom_vjp_call_p,
custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p,
custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
CustomVJPPrimal as CustomVJPPrimal,
linear_call as linear_call,
)

View File

@ -1254,11 +1254,12 @@ class TensorFlowTrace(core.Trace):
def post_process_custom_jvp_call(self, out_tracers, _):
assert False # unreachable assuming jax2tf runs with clean trace state
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
# Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so
# there are no more JAX differentiation transformations to be applied.
del fwd, bwd, out_trees # Unused.
del fwd, bwd, out_trees, symbolic_zeros # Unused.
return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_vjp_call(self, out_tracers, _):

View File

@ -7415,10 +7415,10 @@ class CustomJVPTest(jtu.JaxTestCase):
primal_outs, tangent_outs = run(primal_ins, tangent_ins)
primal_out1, primal_out2 = primal_outs
tangent_out1, tangent_out2 = tangent_outs
scalar_dtype = jax.Array if maybe_jit or maybe_vmap else float
self.assertIsInstance(primal_out1, scalar_dtype)
scalar_type = jax.Array if maybe_jit or maybe_vmap else float
self.assertIsInstance(primal_out1, scalar_type)
self.assertAllClose(primal_out1, 5.)
self.assertIsInstance(tangent_out1, scalar_dtype)
self.assertIsInstance(tangent_out1, scalar_type)
self.assertAllClose(tangent_out1, 91.)
self.assertIsInstance(primal_out2, jax.Array)
self.assertArraysAllClose(primal_out2, jnp.array([7., 9.]))
@ -8496,6 +8496,245 @@ class CustomVJPTest(jtu.JaxTestCase):
f_vjp(jnp.array([3.]))
f_vjp(jnp.array([3.])) # doesn't crash
def test_symbolic_zero_custom_vjp_basic(self):
ZERO = custom_derivatives_public.SymbolicZero
@jax.custom_vjp
def f(x, y, z):
return x, x
def fwd(x, y, z):
self.assertTrue(x.perturbed)
self.assertFalse(y.perturbed)
self.assertFalse(z.perturbed)
return (x.value, x.value), None
def fwd_all(x, y, z):
self.assertTrue(x.perturbed)
self.assertTrue(y.perturbed)
self.assertTrue(z.perturbed)
return (x.value, x.value), None
def bwd_all(_, g):
x1, x2 = g
self.assertFalse(type(x1) is ZERO)
self.assertFalse(type(x2) is ZERO)
return x1, x1, x2
def bwd_fst(_, g):
x1, x2 = g
self.assertFalse(type(x1) is ZERO)
self.assertIs(type(x2), ZERO)
return x1, x1, x2
def bwd_snd(_, g):
x1, x2 = g
self.assertIs(type(x1), ZERO)
self.assertFalse(type(x2) is ZERO)
return x1, x1, x2
x, y, z = 4., 5., 6.
i = np.array(7, np.int32)
zero = np.array(0.)
f.defvjp(fwd, bwd_all, symbolic_zeros=True)
h = jax.jit(f)
jax.jacrev(h)(x, y, z)
jax.jacrev(lambda x: h(x, y, z))(x)
jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i)
f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True)
fst_f = lambda *xs: f(*xs)[0]
_, vjp = jax.vjp(fst_f, x, y, z)
_, _, gz = vjp(x)
self.assertArraysAllClose(gz, zero)
f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True)
snd_f = lambda *xs: f(*xs)[1]
_, vjp = jax.vjp(snd_f, x, y, z)
gx, gy, _ = vjp(x)
self.assertArraysAllClose(gx, zero)
self.assertArraysAllClose(gy, zero)
f.defvjp(fwd, bwd_snd, symbolic_zeros=True)
_, vjp = jax.vjp(lambda x: snd_f(x, y, z), x)
gx, = vjp(x)
self.assertArraysAllClose(gx, zero)
@parameterized.named_parameters(
('jit_vmap', True, True),
('jit', True, False),
('vmap', False, True),
('', False, False),
)
def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap):
# below:
# * static_scalar will be static in and out
# * static_array will be static in, but dynamic out
# * dyn_scalar and dyn_array will be dynamic in and out
ZERO = custom_derivatives_public.SymbolicZero
def f(static_scalar, static_array, dyn_scalar, dyn_array):
out1 = static_scalar + dyn_scalar
out2 = static_array + dyn_array
return static_scalar, static_array, out1, out2
def _pack(x):
return lax.broadcast(x, (1,))
def _unpack(x):
(x,) = x
return x
def _vmap(fun):
def _fun(*args):
args = tree_util.tree_map(_pack, args)
out = jax.vmap(fun)(*args)
out = tree_util.tree_map(_unpack, out)
return out
return _fun
f = jax.custom_vjp(f)
def fwd(*args):
xs, pert = [x.value for x in args], [x.perturbed for x in args]
self.assertFalse(pert[0])
self.assertFalse(pert[1])
self.assertTrue(pert[2])
self.assertTrue(pert[3])
return f(*xs), xs
def bwd(res, g):
static_scalar, *_ = res
t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g
self.assertIs(type(t_static), ZERO)
self.assertFalse(type(t_static_arr) is ZERO)
self.assertFalse(type(t_dyn_scalar) is ZERO)
self.assertFalse(type(t_dyn_array) is ZERO)
self.assertEqual(t_static.shape, ())
self.assertEqual(t_static_arr.shape, (2,))
return (static_scalar + 90,
t_static_arr + 91,
t_dyn_scalar + 92,
t_dyn_array + 93)
f.defvjp(fwd, bwd, symbolic_zeros=True)
def g(dyn_scalar, dyn_array):
if maybe_vmap:
f_ = _vmap(f)
else:
f_ = f
outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array)
return outs[1:]
def run(primal_ins, cotangent_outs):
primal_outs, vjp = jax.vjp(g, *primal_ins)
cotangent_ins = vjp(cotangent_outs)
return primal_outs, cotangent_ins
if maybe_jit:
run = jax.jit(run)
scalar_type = jax.Array if maybe_jit or maybe_vmap else float
primal_ins = (4., jnp.array([5., 6.]))
cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.]))
primal_outs, cotangent_ins = run(primal_ins, cotangent_outs)
primal_out1, primal_out2, primal_out3 = primal_outs
self.assertIsInstance(primal_out1, jax.Array)
self.assertAllClose(primal_out1, jnp.array([2., 3.]))
self.assertIsInstance(primal_out2, scalar_type)
self.assertAllClose(primal_out2, 5.)
self.assertIsInstance(primal_out3, jax.Array)
self.assertAllClose(primal_out3, jnp.array([7., 9.]))
ct_in1, ct_in2 = cotangent_ins
self.assertIsInstance(ct_in1, scalar_type)
self.assertAllClose(ct_in1, 99.)
self.assertIsInstance(ct_in2, jax.Array)
self.assertArraysAllClose(ct_in2, jnp.array([101., 102.]))
def test_symbolic_zero_custom_vjp_vmap_output(self):
@jax.custom_vjp
def f(x, y):
return x, y
def fwd(x, y):
self.assertTrue(x.perturbed)
self.assertFalse(y.perturbed)
return f(x.value, y.value), None
def bwd(_, g):
_, ct_y = g
self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero)
return g
f.defvjp(fwd, bwd, symbolic_zeros=True)
jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3))
def test_symbolic_zero_custom_vjp_custom_pytree(self):
tree_values = custom_derivatives_public.custom_vjp_primal_tree_values
@tree_util.register_pytree_node_class
class Box:
def __init__(self_, strict, val):
if strict:
# make sure we aren't getting special arguments that should only
# come up when symbolic_zeros is True
self.assertFalse(hasattr(val, 'perturbed'))
self_.strict = strict
self_.x = val
def tree_flatten(self_):
return [self_.x], self_.strict
@classmethod
def tree_unflatten(cls, strict, xs):
x, = xs
return cls(strict, x)
x, y = Box(False, jnp.array(72.)), jnp.array(73.)
@jax.custom_vjp
def f(box, y):
return box.x * y
def fwd0(box, y):
self.assertTrue(box.x.perturbed)
self.assertFalse(y.perturbed)
box, y = map(tree_values, [box, y])
return f(box, y), (box, y)
def bwd0(res, g):
box, y = res
return y * g, box.x * g
def fwd1(box, y):
self.assertFalse(box.x.perturbed)
self.assertTrue(y.perturbed)
box, y = map(tree_values, [box, y])
return f(box, y), (box, y)
def bwd1(res, g):
box, y = res
return y * g, box.x * g
f.defvjp(fwd0, bwd0, symbolic_zeros=True)
jax.grad(f, argnums=0)(x, y)
f.defvjp(fwd1, bwd1, symbolic_zeros=True)
jax.grad(f, argnums=1)(x, y)
def fwd_strict(box, y):
return f(box, y), (box, y)
def bwd_strict(res, g):
box, y = res
return y * g, box.x * g
f.defvjp(fwd_strict, bwd_strict)
jax.grad(f)(x, y)
def transpose_unary(f, x_example):
def transposed(y):