diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 167110b95..829907058 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -932,7 +932,8 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
 error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
 
 def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
-                               fwd_jaxpr_thunk, num_consts, bwd, out_trees):
+                               fwd_jaxpr_thunk, num_consts, bwd, out_trees,
+                               symbolic_zeros):
   err_vals, err_tree = jtu.tree_flatten(in_err)
   fun = lu.wrap_init(
       functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
@@ -940,15 +941,17 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
   fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun)
 
   @lu.wrap_init
-  def fwd(*xs):
+  def fwd(*args):
     # TODO(lenamartens, sharadmv): why not checkify here?
-    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()
+    xs, zeros = args[::2], args[1::2]
+    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
    xs_without_consts = xs[num_consts:]
     return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)
 
   fwd, fwd_out_tree = flatten_fun_output(fwd)
   all_outs = custom_derivatives.custom_vjp_call_p.bind(
-      fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees)
+      fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees,
+      symbolic_zeros=symbolic_zeros)
   fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
   if fst:
     err_and_out_tree, _ = out_metadata
diff --git a/jax/_src/core.py b/jax/_src/core.py
index fb3311aff..27e356c14 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -512,7 +512,8 @@ class Trace(Generic[TracerType]):
           "to handle custom_transpose_call primitives")
     raise NotImplementedError(msg)
 
-  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
+  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
+                              out_trees, symbolic_zeros):
     msg = (f"{type(self)} must override process_custom_vjp_call "
            "to handle custom_vjp primitives")
     raise NotImplementedError(msg)
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 78f64a798..f2c75473e 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dataclasses
 from functools import update_wrapper, reduce, partial
 import inspect
-from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any)
+from typing import (
+    Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar)
 
 from jax._src import core
 from jax._src import custom_api_util
@@ -23,8 +25,8 @@ from jax._src import dtypes
 from jax._src import effects
 from jax._src import linear_util as lu
 from jax._src import traceback_util
-from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval,
-                              stop_gradient_p)
+from jax._src.ad_util import (
+    stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
 from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
 from jax._src.config import config
 from jax._src.core import raise_to_shaped
@@ -162,12 +164,13 @@ class custom_jvp(Generic[ReturnValue]):
         and the second element is the tangent output. Elements of the input
         and output tuples may be arrays or any nested tuples/lists/dicts thereof.
       symbolic_zeros: boolean, indicating whether the rule should be passed
-        objects representing static symbolic zeros in its tangent tuple
-        argument; otherwise, only standard JAX types (e.g. array-likes) are
-        passed. Setting this option to True allows a JVP rule to detect whether
-        certain inputs are not involved in differentiation, but at the cost of
-        needing special handling for these objects (which e.g. can't be passed
-        into jax.numpy functions). Default False.
+        objects representing static symbolic zeros in its tangent argument in
+        correspondence with unperturbed values; otherwise, only standard JAX
+        types (e.g. array-likes) are passed. Setting this option to ``True``
+        allows a JVP rule to detect whether certain inputs are not involved in
+        differentiation, but at the cost of needing special handling for these
+        objects (which e.g. can't be passed into jax.numpy functions). Default
+        ``False``.
 
     Returns:
       None.
@@ -479,12 +482,15 @@ class custom_vjp(Generic[ReturnValue]):
     self.nondiff_argnums = nondiff_argnums
     self.fwd: Optional[Callable[..., Tuple[ReturnValue, Any]]] = None
     self.bwd: Optional[Callable[..., Tuple[Any, ...]]] = None
+    self.symbolic_zeros = False
 
   __getattr__ = custom_api_util.forward_attr
 
   def defvjp(self, fwd: Callable[..., Tuple[ReturnValue, Any]],
-             bwd: Callable[..., Tuple[Any, ...]]) -> None:
+             bwd: Callable[..., Tuple[Any, ...]],
+             symbolic_zeros: bool = False,
+             ) -> None:
     """Define a custom VJP rule for the function represented by this instance.
 
     Args:
@@ -505,6 +511,38 @@ class custom_vjp(Generic[ReturnValue]):
         function, and the tuple elements may be arrays or nested
         tuples/lists/dicts thereof so as to match the structure of the primal
         input arguments.
+      symbolic_zeros: boolean, determining whether to indicate symbolic zeros
+        to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom
+        derivative rules to detect when certain inputs, and when certain
+        output cotangents, are not involved in differentiation. If ``True``:
+
+        * ``fwd`` must accept, in place of each leaf value ``x`` in the pytree
+          comprising an argument to the original function, an object with two
+          attributes: ``value`` and ``perturbed``. The ``value`` field is the
+          original primal argument, and the ``perturbed`` field is a boolean
+          indicating whether the argument is involved in differentiation
+          (i.e., if it is ``False``, then the corresponding Jacobian "column"
+          is zero).
+
+        * ``bwd`` will be passed objects representing static symbolic zeros in
+          its cotangent argument in correspondence with unperturbed values;
+          otherwise, only standard JAX types (e.g. array-likes) are passed.
+
+        Setting this option to ``True`` allows these rules to detect whether
+        certain inputs and outputs are not involved in differentiation, but at
+        the cost of special handling. For instance:
+
+        * The signature of ``fwd`` changes, and the objects it is passed cannot
+          be output from the rule directly.
+
+        * The ``bwd`` rule is passed objects that are not entirely array-like,
+          and that cannot be passed to most ``jax.numpy`` functions.
+
+        * Any custom pytree nodes involved in the primal function's arguments
+          must accept, in their unflattening functions, the two-field record
+          objects that are given as input leaves to the ``fwd`` rule.
+
+        Default ``False``.
 
     Returns:
      None.
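A minimal usage sketch of the option documented above, for review purposes
only (not part of the patch). It is distilled from the new cases added to
``api_test.py`` later in this diff; the names ``f``, ``fwd``, and ``bwd`` are
illustrative, and ``SymbolicZero`` is the existing export those tests use:

    import jax
    from jax.custom_derivatives import SymbolicZero

    @jax.custom_vjp
    def f(x, y):
      return x, x  # y never affects the output

    def fwd(x, y):
      # Each argument arrives as a two-field record: `.value` holds the
      # primal, `.perturbed` is True iff the input is being differentiated.
      assert x.perturbed and not y.perturbed
      return (x.value, x.value), None

    def bwd(_, cts):
      ct0, ct1 = cts
      # The cotangent of the undifferentiated second output arrives as a
      # SymbolicZero; the rule may pass it back out as y's cotangent.
      assert type(ct1) is SymbolicZero
      return ct0, ct1

    f.defvjp(fwd, bwd, symbolic_zeros=True)
    jax.grad(lambda x, y: f(x, y)[0])(4., 5.)  # 1.0; only x is perturbed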
@@ -526,6 +564,7 @@ class custom_vjp(Generic[ReturnValue]):
     """
     self.fwd = fwd
     self.bwd = bwd
+    self.symbolic_zeros = symbolic_zeros
 
   @traceback_util.api_boundary
   def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
@@ -533,7 +572,7 @@ class custom_vjp(Generic[ReturnValue]):
     if not self.fwd or not self.bwd:
       msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
       raise AttributeError(msg)
-    fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
+    fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if config.jax_enable_custom_vjp_by_custom_transpose:
       if self.nondiff_argnums:
@@ -557,14 +596,38 @@ class custom_vjp(Generic[ReturnValue]):
       args_flat, in_tree = tree_flatten(dyn_args)
       in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
       flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
-      flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
-                                         out_type)
+      flat_fwd, out_trees = _flatten_fwd(fwd, self.symbolic_zeros, primal_name,
+                                         fwd_name, in_tree, out_type)
       flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
       out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
-                                        *args_flat, out_trees=out_trees)
+                                        *args_flat, out_trees=out_trees,
+                                        symbolic_zeros=self.symbolic_zeros)
       _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
       return tree_unflatten(out_tree, out_flat)
 
+@dataclasses.dataclass
+class CustomVJPPrimal:
+  value: Any
+  perturbed: bool
+
+def custom_vjp_primal_tree_values(tree):
+  """Strips away perturbation information from forward rule arguments.
+
+  This is a helper function for use with the ``symbolic_zeros`` option to
+  the ``defvjp`` method of a ``custom_vjp``-decorated function.
+
+  In ``symbolic_zeros`` mode, the custom forward rule receives arguments
+  whose pytree leaves are records with a ``value`` attribute that carries
+  the primal argument. This function converts such argument trees back to
+  their original form, replacing each such record with its carried value
+  at each leaf.
+  """
+  def value(leaf):
+    if type(leaf) is not CustomVJPPrimal:
+      raise TypeError(f"unexpected leaf type {type(leaf)}")
+    return leaf.value
+  return tree_map(value, tree)
+
 def _check_for_tracers(x):
   for leaf in tree_leaves(x):
     if isinstance(leaf, core.Tracer):
@@ -578,7 +641,12 @@ def _check_for_tracers(x):
       raise UnexpectedTracerError(msg)
 
 @lu.transformation_with_aux
-def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args):
+def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
+                 *args):
+  if symbolic_zeros:
+    args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
+  else:
+    args = args[::2]
   py_args = tree_unflatten(in_tree, args)
   pair_out = yield py_args, {}
   if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
@@ -671,7 +739,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
 class CustomVJPCallPrimitive(core.CallPrimitive):
   initial_style: core.Primitive
 
-  def bind(self, fun, fwd, bwd, *args, out_trees):
+  def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros):
     args = map(core.full_lower, args)
     top_trace = core.find_top_trace(args)
     fun, env_trace_todo1 = process_env_traces(
@@ -681,7 +749,8 @@ class CustomVJPCallPrimitive(core.CallPrimitive):
     tracers = map(top_trace.full_raise, args)  # type: ignore
     bwd_ = lambda *args: bwd(*args)
     outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
-                                             out_trees=out_trees)
+                                             out_trees=out_trees,
+                                             symbolic_zeros=symbolic_zeros)
     fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
     if fst:
       return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
@@ -747,31 +816,34 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
 
 def _custom_vjp_call_jaxpr_jvp(
     primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
-    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
-    bwd: Callable, out_trees: Callable, num_consts: int):
+    fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
+    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
   _, args = split_list(primals, [num_consts])
   consts_dot, args_dot = split_list(tangents, [num_consts])
   if any(type(t) is not Zero for t in consts_dot):
     raise ad.CustomVJPException()
-  fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()  # consts can be tracers!
-  out_tree, res_tree = out_trees()
+  zeros = [type(t) is not Zero for t in args_dot]
+  fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)  # consts can be tracers!
+  _, res_tree = out_trees()
+  res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
+  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
+  avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
   args_dot = map(ad.instantiate_zeros, args_dot)
   # Cast float0 to zeros with the primal dtype because custom vjp rules don't
   # currently handle float0s
   args_dot = map(ad.replace_float0s, args, args_dot)
-  res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
-  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
-  avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
   tangents_out = ad.custom_lin_p.bind(
-      *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out)
+      *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
+      out_avals=avals_out, symbolic_zeros=symbolic_zeros)
   tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
   return primals_out, tangents_out
 ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
 
-def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
-    axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
-    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
-    bwd: Callable, out_trees: Callable, num_consts: int):
+def _custom_vjp_call_jaxpr_vmap(
+    spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
+    fun_jaxpr: core.ClosedJaxpr,
+    fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
+    num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
   args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
           else x for x, d in zip(args, in_dims)]
@@ -784,8 +856,8 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
   out_dims2 = []
 
   @pe._memoize
-  def batched_fwd_jaxpr_thunk():
-    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
+  def batched_fwd_jaxpr_thunk(*zeros):
+    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros))  # consts can be tracers
     batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
         fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
         main_type)
@@ -794,17 +866,20 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
   fwd_args_batched = [0 if b else not_mapped for b in args_batched]
   fwd_out_dims = lambda: out_dims2[0]
 
-  batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
-      fwd_args_batched, main_type, spmd_axis_name)
+  batched_bwd = batching.batch_custom_vjp_bwd(
+      bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
+      spmd_axis_name)
 
   batched_outs = custom_vjp_call_jaxpr_p.bind(
       *args, fun_jaxpr=batched_fun_jaxpr,
       fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
-      out_trees=out_trees, num_consts=num_consts)
+      num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
   out_dims = out_dims2[0] if out_dims2 else out_dims1
   return batched_outs, out_dims
 
-batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
-batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None)
+batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
+    _custom_vjp_call_jaxpr_vmap
+batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
+    _custom_vjp_call_jaxpr_vmap, None)
 
 xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index d6e6cf2b9..d2b5133e0 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -27,9 +27,9 @@ from jax.tree_util import (tree_flatten, tree_unflatten,
 from jax._src import core
 from jax._src import source_info_util
 from jax._src.ad_util import (
-    add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
-    zeros_like_p, Zero, replace_internal_symbolic_zeros,
-    replace_rule_output_symbolic_zeros)
+    add_jaxvals, add_jaxvals_p, replace_internal_symbolic_zeros,
+    replace_rule_output_symbolic_zeros, Zero, zeros_like_aval,
+    zeros_like_jaxval, zeros_like_p)
 from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
 from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
                            raise_to_shaped)
@@ -387,16 +387,21 @@ class JVPTrace(Trace):
   def post_process_custom_jvp_call(self, out_tracers, _):
     raise CustomJVPException()
 
-  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees):
+  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
+                              symbolic_zeros):
     primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
-    tangents_in = map(instantiate_zeros, tangents_in)
-    res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
-    out_tree, res_tree = out_trees()
+    fwd_in = [(core.full_lower(p), type(t) is not Zero)
+              for p, t in zip(primals_in, tangents_in)]
+    fwd_in = [x for pair in fwd_in for x in pair]  # flatten
+    res_and_primals_out = fwd.call_wrapped(*fwd_in)
+    _, res_tree = out_trees()
     res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
     avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
+    # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
+    tangents_in = map(instantiate_zeros, tangents_in)
     tangents_out = custom_lin_p.bind(
         *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
-        out_avals=avals_out)
+        out_avals=avals_out, symbolic_zeros=symbolic_zeros)
     tangents_out = map(recast_to_float0, primals_out, tangents_out)
     return map(partial(JVPTracer, self), primals_out, tangents_out)
 
@@ -745,10 +750,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__):
                   "function.")
 custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)
 
-def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
+def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals,
+                          symbolic_zeros):
   res, _ = split_list(invals, [num_res])
-  cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
+  if symbolic_zeros:
+    cts_out = map(replace_internal_symbolic_zeros, cts_out)
+  else:
+    cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
   cts_in = bwd(*res, *cts_out)
+  cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
   return [None] * num_res + list(cts_in)
 primitive_transposes[custom_lin_p] = _custom_lin_transpose
 
diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index 978d4cd51..5acbd69df 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -25,17 +25,18 @@ import jax
 from jax.config import config
 from jax._src import core
 from jax._src import source_info_util
-from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
-from jax._src.tree_util import (tree_unflatten, tree_flatten,
-                                register_pytree_node)
+from jax._src import linear_util as lu
 from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
                               zeros_like_p, Zero, SymbolicZero,
                               replace_rule_output_symbolic_zeros, instantiate)
-from jax._src import linear_util as lu
+from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
+from jax._src.interpreters import partial_eval as pe
+from jax._src.tree_util import (tree_unflatten, tree_flatten,
+                                register_pytree_node)
 from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
                            canonicalize_axis, moveaxis, as_hashable_function,
                            curry, memoize, weakref_lru_cache)
-from jax._src.interpreters import partial_eval as pe
+
 
 Array = Any
 map, unsafe_map = safe_map, map
@@ -478,16 +479,19 @@ class BatchTrace(Trace):
       return map(partial(BatchTracer, trace), vals, dims, srcs)
     return vals, todo
 
-  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):  # pytype: disable=signature-mismatch
+  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees,
+                              symbolic_zeros):  # pytype: disable=signature-mismatch
     in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
     axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                   if d is not not_mapped}
+    fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]]
     fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
-    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
+    fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims)
     bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, out_dims2,
                                in_dims, self.main.trace_type, self.spmd_axis_name)
-    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
+    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees,
+                         symbolic_zeros=symbolic_zeros)
     fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
     if not fst:
       _, res_tree = out_trees()
@@ -784,8 +788,14 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
 def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
                          main_type, spmd_axis_name):
   def new_bwd(*args):
+    in_dims_ = in_dims() if callable(in_dims) else in_dims
+    args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval))
+            if type(x) is SymbolicZero else x
+            for x, dim in zip(args, in_dims_)]
+    in_dims_ = [None if type(x) is SymbolicZero else d
+                for x, d in zip(args, in_dims_)]
     bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd))
-    bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type,
+    bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type,
                         spmd_axis_name)
     bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
                                out_dim_dests)
@@ -802,7 +812,7 @@ def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in
 def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
   # Just like `matchaxis`, but handles symbolic zeros using ad_util.py
   # TODO(mattjj): dedup with matchaxis
-  if isinstance(x, Zero):
+  if isinstance(x, (Zero, SymbolicZero)):
     if src == dst:
       return x
     elif type(src) == type(dst) == int:
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 587bef00b..8aecd4eef 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -497,14 +497,16 @@ class JaxprTrace(Trace['JaxprTracer']):
       for t in out_tracers: t.recipe = eqn
     return out_tracers
 
-  def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees):
+  def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees,
+                              symbolic_zeros):
     # TODO(mattjj): after old remat is deleted, make this method trivial.
     # Because we instantiate all tracers, in_knowns is all False.
     tracers = map(self.instantiate_const_abstracted, tracers)
     in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers])
     f = trace_to_subjaxpr_nounits(f, self.main, True)
     f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
-    out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees)
+    out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees,
+                         symbolic_zeros=symbolic_zeros)
     out_knowns, out_avals, jaxpr, env = aux()
     out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
     res_tracers = map(self.new_instantiated_const, res)
@@ -514,8 +516,9 @@ class JaxprTrace(Trace['JaxprTracer']):
     closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
 
     @_memoize
-    def fwd_jaxpr_thunk():
-      fwd_ = trace_to_subjaxpr_nounits(fwd, self.main, True)
+    def fwd_jaxpr_thunk(*zeros):
+      fwd_ = _interleave_fun(fwd, zeros)
+      fwd_ = trace_to_subjaxpr_nounits(fwd_, self.main, True)
      fwd_, aux = partial_eval_wrapper_nounits(
          fwd_, tuple(in_knowns), tuple(in_avals))
      with core.new_sublevel():
@@ -532,7 +535,8 @@ class JaxprTrace(Trace['JaxprTracer']):
                          dict(fun_jaxpr=closed_jaxpr,
                               fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                               num_consts=len(res) + len(env),
-                              bwd=bwd, out_trees=out_trees),
+                              bwd=bwd, out_trees=out_trees,
+                              symbolic_zeros=symbolic_zeros),
                          jaxpr.effects, source)
     for t in out_tracers: t.recipe = eqn
     return merge_lists(out_knowns, out_tracers, out_consts)
@@ -1959,23 +1963,29 @@ class DynamicJaxprTrace(core.Trace):
   def post_process_custom_jvp_call(self, out_tracers, _):
     assert False  # unreachable
 
-  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
+  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
+                              symbolic_zeros):
     in_avals = [t.aval for t in tracers]
     with core.new_sublevel():
       fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
     closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
+
     main_ = ref(self.main)
-    fwd_jaxpr_thunk = _memoize(
-        lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2])
+
+    @_memoize
+    def fwd_jaxpr_from_zeros(*zeros):
+      fwd_ = _interleave_fun(fwd, zeros)
+      return trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals)[::2]
+
     out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
     invars = map(self.getvar, tracers)
     constvars = map(self.getvar, map(self.instantiate_const, consts))
     outvars = map(self.makevar, out_tracers)
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                         dict(fun_jaxpr=closed_fun_jaxpr,
-                             fwd_jaxpr_thunk=fwd_jaxpr_thunk,
+                             fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
                              num_consts=len(consts),
-                             bwd=bwd, out_trees=out_trees),
+                             bwd=bwd, out_trees=out_trees,
+                             symbolic_zeros=symbolic_zeros),
                         fun_jaxpr.effects,
                         source_info_util.current())
     self.frame.add_eqn(eqn)
@@ -2024,18 +2034,23 @@ class DynamicJaxprTrace(core.Trace):
 
 custom_staging_rules: Dict[Primitive, Callable] = {}
 
-def _memoize(thunk):
-  cell = []
+@lu.transformation
+def _interleave_fun(every_others, *args, **kwargs):
+  args_ = [x for pair in zip(args, every_others) for x in pair]
+  yield (yield (args_, kwargs))
+
+def _memoize(fn):
+  cells = {}
   saved_state = [core.thread_local_state.trace_state.copy()]
-  def memoized():
-    if not cell:
+  def memoized(*args):
+    if args not in cells:
       prev_state = core.thread_local_state.trace_state
-      core.thread_local_state.trace_state = saved_state.pop()
+      core.thread_local_state.trace_state = saved_state[0]
       try:
-        cell.append(thunk())
+        cells[args] = fn(*args)
       finally:
         core.thread_local_state.trace_state = prev_state
-    return cell[0]
+    return cells[args]
   return memoized
 
 # TODO(mattjj): remove this DebugInfo and helper functions, replace with
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 666a2b22e..81bd753d2 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -918,10 +918,11 @@ class MapTrace(core.Trace):
     return self.process_primitive(fake_primitive, tracers, {})
 
   def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
-                              out_trees):
+                              out_trees, symbolic_zeros):
     bind = HashableFunction(
-        lambda *args, **kwargs: primitive.bind(fun, fwd, bwd, *args,
-                                               out_trees=out_trees, **kwargs),
+        lambda *args, **kwargs: primitive.bind(
+            fun, fwd, bwd, *args, out_trees=out_trees,
+            symbolic_zeros=symbolic_zeros, **kwargs),
         (primitive, fun, fwd, bwd))
     fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
     return self.process_primitive(fake_primitive, tracers, {})
diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py
index c5a17d1bb..f6b8ff9e9 100644
--- a/jax/custom_derivatives.py
+++ b/jax/custom_derivatives.py
@@ -27,6 +27,8 @@ from jax._src.custom_derivatives import (
   custom_vjp as custom_vjp,
   custom_vjp_call_p as custom_vjp_call_p,
   custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p,
+  custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
+  CustomVJPPrimal as CustomVJPPrimal,
   linear_call as linear_call,
 )
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 11824a90b..b148b3ca2 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1254,11 +1254,12 @@ class TensorFlowTrace(core.Trace):
   def post_process_custom_jvp_call(self, out_tracers, _):
     assert False  # unreachable assuming jax2tf runs with clean trace state
 
-  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
+  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
+                              symbolic_zeros):
     # Drop the custom differentiation rule and act like a call primitive. This
     # behavior is desirable because jax2tf stages code out of the JAX system, so
     # there are no more JAX differentiation transformations to be applied.
-    del fwd, bwd, out_trees  # Unused.
+    del fwd, bwd, out_trees, symbolic_zeros  # Unused.
     return self.process_call(core.call_p, fun, tracers, {})
 
   def post_process_custom_vjp_call(self, out_tracers, _):
diff --git a/tests/api_test.py b/tests/api_test.py
index ee736c4b1..c8d19c985 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -7415,10 +7415,10 @@ class CustomJVPTest(jtu.JaxTestCase):
     primal_outs, tangent_outs = run(primal_ins, tangent_ins)
     primal_out1, primal_out2 = primal_outs
     tangent_out1, tangent_out2 = tangent_outs
-    scalar_dtype = jax.Array if maybe_jit or maybe_vmap else float
-    self.assertIsInstance(primal_out1, scalar_dtype)
+    scalar_type = jax.Array if maybe_jit or maybe_vmap else float
+    self.assertIsInstance(primal_out1, scalar_type)
     self.assertAllClose(primal_out1, 5.)
-    self.assertIsInstance(tangent_out1, scalar_dtype)
+    self.assertIsInstance(tangent_out1, scalar_type)
     self.assertAllClose(tangent_out1, 91.)
     self.assertIsInstance(primal_out2, jax.Array)
     self.assertArraysAllClose(primal_out2, jnp.array([7., 9.]))
 
@@ -8496,6 +8496,245 @@ class CustomVJPTest(jtu.JaxTestCase):
     f_vjp(jnp.array([3.]))
     f_vjp(jnp.array([3.]))  # doesn't crash
 
+  def test_symbolic_zero_custom_vjp_basic(self):
+    ZERO = custom_derivatives_public.SymbolicZero
+
+    @jax.custom_vjp
+    def f(x, y, z):
+      return x, x
+
+    def fwd(x, y, z):
+      self.assertTrue(x.perturbed)
+      self.assertFalse(y.perturbed)
+      self.assertFalse(z.perturbed)
+      return (x.value, x.value), None
+
+    def fwd_all(x, y, z):
+      self.assertTrue(x.perturbed)
+      self.assertTrue(y.perturbed)
+      self.assertTrue(z.perturbed)
+      return (x.value, x.value), None
+
+    def bwd_all(_, g):
+      x1, x2 = g
+      self.assertFalse(type(x1) is ZERO)
+      self.assertFalse(type(x2) is ZERO)
+      return x1, x1, x2
+
+    def bwd_fst(_, g):
+      x1, x2 = g
+      self.assertFalse(type(x1) is ZERO)
+      self.assertIs(type(x2), ZERO)
+      return x1, x1, x2
+
+    def bwd_snd(_, g):
+      x1, x2 = g
+      self.assertIs(type(x1), ZERO)
+      self.assertFalse(type(x2) is ZERO)
+      return x1, x1, x2
+
+    x, y, z = 4., 5., 6.
+    i = np.array(7, np.int32)
+    zero = np.array(0.)
+
+    f.defvjp(fwd, bwd_all, symbolic_zeros=True)
+    h = jax.jit(f)
+    jax.jacrev(h)(x, y, z)
+    jax.jacrev(lambda x: h(x, y, z))(x)
+    jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i)
+
+    f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True)
+    fst_f = lambda *xs: f(*xs)[0]
+    _, vjp = jax.vjp(fst_f, x, y, z)
+    _, _, gz = vjp(x)
+    self.assertArraysAllClose(gz, zero)
+
+    f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True)
+    snd_f = lambda *xs: f(*xs)[1]
+    _, vjp = jax.vjp(snd_f, x, y, z)
+    gx, gy, _ = vjp(x)
+    self.assertArraysAllClose(gx, zero)
+    self.assertArraysAllClose(gy, zero)
+
+    f.defvjp(fwd, bwd_snd, symbolic_zeros=True)
+    _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x)
+    gx, = vjp(x)
+    self.assertArraysAllClose(gx, zero)
+
+  @parameterized.named_parameters(
+      ('jit_vmap', True, True),
+      ('jit', True, False),
+      ('vmap', False, True),
+      ('', False, False),
+  )
+  def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap):
+    # below:
+    # * static_scalar will be static in and out
+    # * static_array will be static in, but dynamic out
+    # * dyn_scalar and dyn_array will be dynamic in and out
+
+    ZERO = custom_derivatives_public.SymbolicZero
+
+    def f(static_scalar, static_array, dyn_scalar, dyn_array):
+      out1 = static_scalar + dyn_scalar
+      out2 = static_array + dyn_array
+      return static_scalar, static_array, out1, out2
+
+    def _pack(x):
+      return lax.broadcast(x, (1,))
+
+    def _unpack(x):
+      (x,) = x
+      return x
+
+    def _vmap(fun):
+      def _fun(*args):
+        args = tree_util.tree_map(_pack, args)
+        out = jax.vmap(fun)(*args)
+        out = tree_util.tree_map(_unpack, out)
+        return out
+      return _fun
+
+    f = jax.custom_vjp(f)
+
+    def fwd(*args):
+      xs, pert = [x.value for x in args], [x.perturbed for x in args]
+      self.assertFalse(pert[0])
+      self.assertFalse(pert[1])
+      self.assertTrue(pert[2])
+      self.assertTrue(pert[3])
+      return f(*xs), xs
+
+    def bwd(res, g):
+      static_scalar, *_ = res
+      t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g
+      self.assertIs(type(t_static), ZERO)
+      self.assertFalse(type(t_static_arr) is ZERO)
+      self.assertFalse(type(t_dyn_scalar) is ZERO)
+      self.assertFalse(type(t_dyn_array) is ZERO)
+      self.assertEqual(t_static.shape, ())
+      self.assertEqual(t_static_arr.shape, (2,))
+      return (static_scalar + 90,
+              t_static_arr + 91,
+              t_dyn_scalar + 92,
+              t_dyn_array + 93)
+
+    f.defvjp(fwd, bwd, symbolic_zeros=True)
+
+    def g(dyn_scalar, dyn_array):
+      if maybe_vmap:
+        f_ = _vmap(f)
+      else:
+        f_ = f
+      outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array)
+      return outs[1:]
+
+    def run(primal_ins, cotangent_outs):
+      primal_outs, vjp = jax.vjp(g, *primal_ins)
+      cotangent_ins = vjp(cotangent_outs)
+      return primal_outs, cotangent_ins
+
+    if maybe_jit:
+      run = jax.jit(run)
+
+    scalar_type = jax.Array if maybe_jit or maybe_vmap else float
+    primal_ins = (4., jnp.array([5., 6.]))
+    cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.]))
+    primal_outs, cotangent_ins = run(primal_ins, cotangent_outs)
+
+    primal_out1, primal_out2, primal_out3 = primal_outs
+    self.assertIsInstance(primal_out1, jax.Array)
+    self.assertAllClose(primal_out1, jnp.array([2., 3.]))
+    self.assertIsInstance(primal_out2, scalar_type)
+    self.assertAllClose(primal_out2, 5.)
+    self.assertIsInstance(primal_out3, jax.Array)
+    self.assertAllClose(primal_out3, jnp.array([7., 9.]))
+
+    ct_in1, ct_in2 = cotangent_ins
+    self.assertIsInstance(ct_in1, scalar_type)
+    self.assertAllClose(ct_in1, 99.)
+    self.assertIsInstance(ct_in2, jax.Array)
+    self.assertArraysAllClose(ct_in2, jnp.array([101., 102.]))
+
+  def test_symbolic_zero_custom_vjp_vmap_output(self):
+    @jax.custom_vjp
+    def f(x, y):
+      return x, y
+
+    def fwd(x, y):
+      self.assertTrue(x.perturbed)
+      self.assertFalse(y.perturbed)
+      return f(x.value, y.value), None
+
+    def bwd(_, g):
+      _, ct_y = g
+      self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero)
+      return g
+
+    f.defvjp(fwd, bwd, symbolic_zeros=True)
+    jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3))
+
+  def test_symbolic_zero_custom_vjp_custom_pytree(self):
+    tree_values = custom_derivatives_public.custom_vjp_primal_tree_values
+
+    @tree_util.register_pytree_node_class
+    class Box:
+      def __init__(self_, strict, val):
+        if strict:
+          # make sure we aren't getting special arguments that should only
+          # come up when symbolic_zeros is True
+          self.assertFalse(hasattr(val, 'perturbed'))
+        self_.strict = strict
+        self_.x = val
+
+      def tree_flatten(self_):
+        return [self_.x], self_.strict
+
+      @classmethod
+      def tree_unflatten(cls, strict, xs):
+        x, = xs
+        return cls(strict, x)
+
+    x, y = Box(False, jnp.array(72.)), jnp.array(73.)
+
+    @jax.custom_vjp
+    def f(box, y):
+      return box.x * y
+
+    def fwd0(box, y):
+      self.assertTrue(box.x.perturbed)
+      self.assertFalse(y.perturbed)
+      box, y = map(tree_values, [box, y])
+      return f(box, y), (box, y)
+
+    def bwd0(res, g):
+      box, y = res
+      return y * g, box.x * g
+
+    def fwd1(box, y):
+      self.assertFalse(box.x.perturbed)
+      self.assertTrue(y.perturbed)
+      box, y = map(tree_values, [box, y])
+      return f(box, y), (box, y)
+
+    def bwd1(res, g):
+      box, y = res
+      return y * g, box.x * g
+
+    f.defvjp(fwd0, bwd0, symbolic_zeros=True)
+    jax.grad(f, argnums=0)(x, y)
+    f.defvjp(fwd1, bwd1, symbolic_zeros=True)
+    jax.grad(f, argnums=1)(x, y)
+
+    def fwd_strict(box, y):
+      return f(box, y), (box, y)
+
+    def bwd_strict(res, g):
+      box, y = res
+      return y * g, box.x * g
+
+    f.defvjp(fwd_strict, bwd_strict)
+    jax.grad(f)(x, y)
 
 def transpose_unary(f, x_example):
   def transposed(y):