mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
custom_vjp
symbolic zeros support, take two
This change re-introduces symbolic zero support for `custom_vjp`. This time:

* The forward rule API is slightly different, accepting two-field records at pytree leaves rather than pairs.
* In the default setting where `symbolic_zeros` is not set, there are no new requirements from pytree node definitions that are involved in the primal arguments. This avoids any change in behavior on the default path. In particular, custom pytree node definitions that aren't completely polymorphic in unflattening can remain as is.
* There is an additional test involving a custom pytree node.
This commit is contained in:
parent
0e549ac4be
commit
d51b8e6839
@ -932,7 +932,8 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
|
||||
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule
|
||||
|
||||
def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
|
||||
fwd_jaxpr_thunk, num_consts, bwd, out_trees):
|
||||
fwd_jaxpr_thunk, num_consts, bwd, out_trees,
|
||||
symbolic_zeros):
|
||||
err_vals, err_tree = jtu.tree_flatten(in_err)
|
||||
fun = lu.wrap_init(
|
||||
functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
|
||||
@ -940,15 +941,17 @@ def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
|
||||
fun, fun_metadata = _flatten_and_get_error_metadata_thunk(fun)
|
||||
|
||||
@lu.wrap_init
|
||||
def fwd(*xs):
|
||||
def fwd(*args):
|
||||
# TODO(lenamartens, sharadmv): why not checkify here?
|
||||
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()
|
||||
xs, zeros = args[::2], args[1::2]
|
||||
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
|
||||
xs_without_consts = xs[num_consts:]
|
||||
return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)
|
||||
|
||||
fwd, fwd_out_tree = flatten_fun_output(fwd)
|
||||
all_outs = custom_derivatives.custom_vjp_call_p.bind(
|
||||
fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees)
|
||||
fun, fwd, bwd, *err_vals, *in_vals, out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros)
|
||||
fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
|
||||
if fst:
|
||||
err_and_out_tree, _ = out_metadata
|
||||
|
@ -512,7 +512,8 @@ class Trace(Generic[TracerType]):
|
||||
"to handle custom_transpose_call primitives")
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
|
||||
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
|
||||
out_trees, symbolic_zeros):
|
||||
msg = (f"{type(self)} must override process_custom_vjp_call "
|
||||
"to handle custom_vjp primitives")
|
||||
raise NotImplementedError(msg)
|
||||
|
@ -12,9 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
from functools import update_wrapper, reduce, partial
|
||||
import inspect
|
||||
from typing import (Callable, Generic, List, Optional, Sequence, Tuple, TypeVar, Any)
|
||||
from typing import (
|
||||
Any, Callable, Generic, List, Optional, Sequence, Tuple, TypeVar)
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
@ -23,8 +25,8 @@ from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import traceback_util
|
||||
from jax._src.ad_util import (Zero, SymbolicZero, zeros_like_aval,
|
||||
stop_gradient_p)
|
||||
from jax._src.ad_util import (
|
||||
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
|
||||
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
|
||||
from jax._src.config import config
|
||||
from jax._src.core import raise_to_shaped
|
||||
@ -162,12 +164,13 @@ class custom_jvp(Generic[ReturnValue]):
|
||||
and the second element is the tangent output. Elements of the input and
|
||||
output tuples may be arrays or any nested tuples/lists/dicts thereof.
|
||||
symbolic_zeros: boolean, indicating whether the rule should be passed
|
||||
objects representing static symbolic zeros in its tangent tuple
|
||||
argument; otherwise, only standard JAX types (e.g. array-likes) are
|
||||
passed. Setting this option to True allows a JVP rule to detect whether
|
||||
certain inputs are not involved in differentiation, but at the cost of
|
||||
needing special handling for these objects (which e.g. can't be passed
|
||||
into jax.numpy functions). Default False.
|
||||
objects representing static symbolic zeros in its tangent argument in
|
||||
correspondence with unperturbed values; otherwise, only standard JAX
|
||||
types (e.g. array-likes) are passed. Setting this option to ``True``
|
||||
allows a JVP rule to detect whether certain inputs are not involved in
|
||||
differentiation, but at the cost of needing special handling for these
|
||||
objects (which e.g. can't be passed into jax.numpy functions). Default
|
||||
``False``.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
@ -479,12 +482,15 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
self.nondiff_argnums = nondiff_argnums
|
||||
self.fwd: Optional[Callable[..., Tuple[ReturnValue, Any]]] = None
|
||||
self.bwd: Optional[Callable[..., Tuple[Any, ...]]] = None
|
||||
self.symbolic_zeros = False
|
||||
|
||||
__getattr__ = custom_api_util.forward_attr
|
||||
|
||||
def defvjp(self,
|
||||
fwd: Callable[..., Tuple[ReturnValue, Any]],
|
||||
bwd: Callable[..., Tuple[Any, ...]]) -> None:
|
||||
bwd: Callable[..., Tuple[Any, ...]],
|
||||
symbolic_zeros: bool = False,
|
||||
) -> None:
|
||||
"""Define a custom VJP rule for the function represented by this instance.
|
||||
|
||||
Args:
|
||||
@ -505,6 +511,38 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
function, and the tuple elements may be arrays or nested
|
||||
tuples/lists/dicts thereof so as to match the structure of the primal
|
||||
input arguments.
|
||||
symbolic_zeros: boolean, determining whether to indicate symbolic zeros
|
||||
to the ``fwd`` and ``bwd`` rules. Enabling this option allows custom
|
||||
derivative rules to detect when certain inputs, and when certain
|
||||
output cotangents, are not involved in differentiation. If ``True``:
|
||||
|
||||
* ``fwd`` must accept, in place of each leaf value ``x`` in the pytree
|
||||
comprising an argument to the original function, an object with two
|
||||
attributes instead: ``value`` and ``perturbed``. The ``value`` field
|
||||
is the original primal argument, and ``perturbed`` is a boolean.
|
||||
The ``perturbed`` bit indicates whether the argument is involved in
|
||||
differentiation (i.e., if it is ``False``, then the corresponding
|
||||
Jacobian "column" is zero).
|
||||
|
||||
* ``bwd`` will be passed objects representing static symbolic zeros in
|
||||
its cotangent argument in correspondence with unperturbed values;
|
||||
otherwise, only standard JAX types (e.g. array-likes) are passed.
|
||||
|
||||
Setting this option to ``True`` allows these rules to detect whether
|
||||
certain inputs and outputs are not involved in differentiation, but at
|
||||
the cost of special handling. For instance:
|
||||
|
||||
* The signature of ``fwd`` changes, and the objects it is passed cannot
|
||||
be output from the rule directly.
|
||||
|
||||
* The ``bwd`` rule is passed objects that are not entirely array-like,
|
||||
and that cannot be passed to most ``jax.numpy`` functions.
|
||||
|
||||
* Any custom pytree nodes involved in the primal function's arguments
|
||||
must accept, in their unflattening functions, the two-field record
|
||||
objects that are given as input leaves to the ``fwd`` rule.
|
||||
|
||||
Default ``False``.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
@ -526,6 +564,7 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
"""
|
||||
self.fwd = fwd
|
||||
self.bwd = bwd
|
||||
self.symbolic_zeros = symbolic_zeros
|
||||
|
||||
@traceback_util.api_boundary
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
|
||||
@ -533,7 +572,7 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
if not self.fwd or not self.bwd:
|
||||
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
|
||||
raise AttributeError(msg)
|
||||
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
|
||||
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
|
||||
args = _resolve_kwargs(self.fun, args, kwargs)
|
||||
if config.jax_enable_custom_vjp_by_custom_transpose:
|
||||
if self.nondiff_argnums:
|
||||
@ -557,14 +596,38 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
|
||||
out_type)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd, self.symbolic_zeros, primal_name,
|
||||
fwd_name, in_tree, out_type)
|
||||
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
|
||||
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
|
||||
*args_flat, out_trees=out_trees)
|
||||
*args_flat, out_trees=out_trees,
|
||||
symbolic_zeros=self.symbolic_zeros)
|
||||
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
@dataclasses.dataclass
class CustomVJPPrimal:
  """Two-field record given to a ``custom_vjp`` forward rule at each input leaf.

  In ``symbolic_zeros`` mode, the ``fwd`` rule receives one of these records
  in place of each leaf value of its arguments.
  """
  # The original primal argument.
  value: Any
  # Whether the argument is involved in differentiation; if ``False``, the
  # corresponding Jacobian "column" is zero.
  perturbed: bool
|
||||
|
||||
def custom_vjp_primal_tree_values(tree):
  """Strips away perturbation information from forward rule arguments.

  This is a helper function for use with the ``symbolic_zeros`` option to
  the ``defvjp`` method of a ``custom_vjp``-decorated function.

  In ``symbolic_zeros`` mode, the custom forward rule receives arguments
  whose pytree leaves are records with a ``value`` attribute that carries
  the primal argument. This is a way to convert such argument trees back to
  their original form, replacing each such record with its carried value at
  each leaf.

  Args:
    tree: a pytree whose leaves are all ``CustomVJPPrimal`` records.

  Returns:
    The result of mapping each leaf record of ``tree`` to its ``value``.

  Raises:
    TypeError: if the type of any leaf is not exactly ``CustomVJPPrimal``.
  """
  def value(leaf):
    # Exact type match on purpose: anything else at a leaf means the tree did
    # not come from a ``symbolic_zeros`` forward rule's arguments.
    if type(leaf) is not CustomVJPPrimal:
      raise TypeError(f"unexpected leaf type {type(leaf)}")
    return leaf.value
  return tree_map(value, tree)
|
||||
|
||||
def _check_for_tracers(x):
|
||||
for leaf in tree_leaves(x):
|
||||
if isinstance(leaf, core.Tracer):
|
||||
@ -578,7 +641,12 @@ def _check_for_tracers(x):
|
||||
raise UnexpectedTracerError(msg)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args):
|
||||
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
|
||||
*args):
|
||||
if symbolic_zeros:
|
||||
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
|
||||
else:
|
||||
args = args[::2]
|
||||
py_args = tree_unflatten(in_tree, args)
|
||||
pair_out = yield py_args, {}
|
||||
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
|
||||
@ -671,7 +739,7 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
||||
class CustomVJPCallPrimitive(core.CallPrimitive):
|
||||
initial_style: core.Primitive
|
||||
|
||||
def bind(self, fun, fwd, bwd, *args, out_trees):
|
||||
def bind(self, fun, fwd, bwd, *args, out_trees, symbolic_zeros):
|
||||
args = map(core.full_lower, args)
|
||||
top_trace = core.find_top_trace(args)
|
||||
fun, env_trace_todo1 = process_env_traces(
|
||||
@ -681,7 +749,8 @@ class CustomVJPCallPrimitive(core.CallPrimitive):
|
||||
tracers = map(top_trace.full_raise, args) # type: ignore
|
||||
bwd_ = lambda *args: bwd(*args)
|
||||
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
|
||||
out_trees=out_trees)
|
||||
out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros)
|
||||
fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
|
||||
if fst:
|
||||
return core.apply_todos(env_trace_todo, map(core.full_lower, outs))
|
||||
@ -747,31 +816,34 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
|
||||
|
||||
def _custom_vjp_call_jaxpr_jvp(
|
||||
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
|
||||
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
|
||||
bwd: Callable, out_trees: Callable, num_consts: int):
|
||||
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
|
||||
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
|
||||
_, args = split_list(primals, [num_consts])
|
||||
consts_dot, args_dot = split_list(tangents, [num_consts])
|
||||
if any(type(t) is not Zero for t in consts_dot):
|
||||
raise ad.CustomVJPException()
|
||||
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers!
|
||||
out_tree, res_tree = out_trees()
|
||||
zeros = [type(t) is not Zero for t in args_dot]
|
||||
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers!
|
||||
_, res_tree = out_trees()
|
||||
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
|
||||
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
|
||||
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
|
||||
args_dot = map(ad.instantiate_zeros, args_dot)
|
||||
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
|
||||
# currently handle float0s
|
||||
args_dot = map(ad.replace_float0s, args, args_dot)
|
||||
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
|
||||
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
|
||||
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
|
||||
tangents_out = ad.custom_lin_p.bind(
|
||||
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out)
|
||||
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
|
||||
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
|
||||
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
|
||||
return primals_out, tangents_out
|
||||
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
|
||||
|
||||
def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
|
||||
axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
|
||||
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
|
||||
bwd: Callable, out_trees: Callable, num_consts: int):
|
||||
def _custom_vjp_call_jaxpr_vmap(
|
||||
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
|
||||
fun_jaxpr: core.ClosedJaxpr,
|
||||
fwd_jaxpr_thunk: Callable[..., Tuple[core.Jaxpr, Sequence[Any]]],
|
||||
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
|
||||
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
|
||||
else x for x, d in zip(args, in_dims)]
|
||||
|
||||
@ -784,8 +856,8 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
|
||||
out_dims2 = []
|
||||
|
||||
@pe._memoize
|
||||
def batched_fwd_jaxpr_thunk():
|
||||
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
|
||||
def batched_fwd_jaxpr_thunk(*zeros):
|
||||
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers
|
||||
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
|
||||
fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
|
||||
main_type)
|
||||
@ -794,17 +866,20 @@ def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
|
||||
|
||||
fwd_args_batched = [0 if b else not_mapped for b in args_batched]
|
||||
fwd_out_dims = lambda: out_dims2[0]
|
||||
batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
|
||||
fwd_args_batched, main_type, spmd_axis_name)
|
||||
batched_bwd = batching.batch_custom_vjp_bwd(
|
||||
bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
|
||||
spmd_axis_name)
|
||||
|
||||
batched_outs = custom_vjp_call_jaxpr_p.bind(
|
||||
*args, fun_jaxpr=batched_fun_jaxpr,
|
||||
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
|
||||
out_trees=out_trees, num_consts=num_consts)
|
||||
num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
|
||||
out_dims = out_dims2[0] if out_dims2 else out_dims1
|
||||
return batched_outs, out_dims
|
||||
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
|
||||
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None)
|
||||
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
|
||||
_custom_vjp_call_jaxpr_vmap
|
||||
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
|
||||
_custom_vjp_call_jaxpr_vmap, None)
|
||||
|
||||
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
|
||||
|
||||
|
@ -27,9 +27,9 @@ from jax.tree_util import (tree_flatten, tree_unflatten,
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src.ad_util import (
|
||||
add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
|
||||
zeros_like_p, Zero, replace_internal_symbolic_zeros,
|
||||
replace_rule_output_symbolic_zeros)
|
||||
add_jaxvals, add_jaxvals_p, replace_internal_symbolic_zeros,
|
||||
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval,
|
||||
zeros_like_jaxval, zeros_like_p)
|
||||
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
|
||||
raise_to_shaped)
|
||||
@ -387,16 +387,21 @@ class JVPTrace(Trace):
|
||||
def post_process_custom_jvp_call(self, out_tracers, _):
|
||||
raise CustomJVPException()
|
||||
|
||||
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees):
|
||||
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees,
|
||||
symbolic_zeros):
|
||||
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
tangents_in = map(instantiate_zeros, tangents_in)
|
||||
res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
|
||||
out_tree, res_tree = out_trees()
|
||||
fwd_in = [(core.full_lower(p), type(t) is not Zero)
|
||||
for p, t in zip(primals_in, tangents_in)]
|
||||
fwd_in = [x for pair in fwd_in for x in pair] # flatten
|
||||
res_and_primals_out = fwd.call_wrapped(*fwd_in)
|
||||
_, res_tree = out_trees()
|
||||
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
|
||||
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
|
||||
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
|
||||
tangents_in = map(instantiate_zeros, tangents_in)
|
||||
tangents_out = custom_lin_p.bind(
|
||||
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
|
||||
out_avals=avals_out)
|
||||
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
|
||||
tangents_out = map(recast_to_float0, primals_out, tangents_out)
|
||||
return map(partial(JVPTracer, self), primals_out, tangents_out)
|
||||
|
||||
@ -745,10 +750,15 @@ def raise_custom_vjp_error_on_jvp(*_, **__):
|
||||
"function.")
|
||||
custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)
|
||||
|
||||
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
|
||||
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals,
|
||||
symbolic_zeros):
|
||||
res, _ = split_list(invals, [num_res])
|
||||
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
|
||||
if symbolic_zeros:
|
||||
cts_out = map(replace_internal_symbolic_zeros, cts_out)
|
||||
else:
|
||||
cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
|
||||
cts_in = bwd(*res, *cts_out)
|
||||
cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
|
||||
return [None] * num_res + list(cts_in)
|
||||
primitive_transposes[custom_lin_p] = _custom_lin_transpose
|
||||
|
||||
|
@ -25,17 +25,18 @@ import jax
|
||||
from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
|
||||
from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
register_pytree_node)
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, Zero, SymbolicZero,
|
||||
replace_rule_output_symbolic_zeros, instantiate)
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.core import raise_to_shaped, Trace, Tracer, AxisName
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
register_pytree_node)
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
|
||||
canonicalize_axis, moveaxis, as_hashable_function,
|
||||
curry, memoize, weakref_lru_cache)
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
|
||||
Array = Any
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -478,16 +479,19 @@ class BatchTrace(Trace):
|
||||
return map(partial(BatchTracer, trace), vals, dims, srcs)
|
||||
return vals, todo
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees): # pytype: disable=signature-mismatch
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees,
|
||||
symbolic_zeros): # pytype: disable=signature-mismatch
|
||||
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
|
||||
if d is not not_mapped}
|
||||
fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]]
|
||||
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
|
||||
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
|
||||
fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims)
|
||||
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
|
||||
out_dims2, in_dims, self.main.trace_type,
|
||||
self.spmd_axis_name)
|
||||
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
|
||||
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros)
|
||||
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
|
||||
if not fst:
|
||||
_, res_tree = out_trees()
|
||||
@ -784,8 +788,14 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
|
||||
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
|
||||
main_type, spmd_axis_name):
|
||||
def new_bwd(*args):
|
||||
in_dims_ = in_dims() if callable(in_dims) else in_dims
|
||||
args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval))
|
||||
if type(x) is SymbolicZero else x
|
||||
for x, dim in zip(args, in_dims_)]
|
||||
in_dims_ = [None if type(x) is SymbolicZero else d
|
||||
for x, d in zip(args, in_dims_)]
|
||||
bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd))
|
||||
bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type,
|
||||
bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type,
|
||||
spmd_axis_name)
|
||||
bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
|
||||
out_dim_dests)
|
||||
@ -802,7 +812,7 @@ def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in
|
||||
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
|
||||
# Just like `matchaxis`, but handles symbolic zeros using ad_util.py
|
||||
# TODO(mattjj): dedup with matchaxis
|
||||
if isinstance(x, Zero):
|
||||
if isinstance(x, (Zero, SymbolicZero)):
|
||||
if src == dst:
|
||||
return x
|
||||
elif type(src) == type(dst) == int:
|
||||
|
@ -497,14 +497,16 @@ class JaxprTrace(Trace['JaxprTracer']):
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return out_tracers
|
||||
|
||||
def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees):
|
||||
def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees,
|
||||
symbolic_zeros):
|
||||
# TODO(mattjj): after old remat is deleted, make this method trivial.
|
||||
# Because we instantiate all tracers, in_knowns is all False.
|
||||
tracers = map(self.instantiate_const_abstracted, tracers)
|
||||
in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers])
|
||||
f = trace_to_subjaxpr_nounits(f, self.main, True)
|
||||
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
|
||||
out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees)
|
||||
out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros)
|
||||
out_knowns, out_avals, jaxpr, env = aux()
|
||||
out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
||||
res_tracers = map(self.new_instantiated_const, res)
|
||||
@ -514,8 +516,9 @@ class JaxprTrace(Trace['JaxprTracer']):
|
||||
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
|
||||
|
||||
@_memoize
|
||||
def fwd_jaxpr_thunk():
|
||||
fwd_ = trace_to_subjaxpr_nounits(fwd, self.main, True)
|
||||
def fwd_jaxpr_thunk(*zeros):
|
||||
fwd_ = _interleave_fun(fwd, zeros)
|
||||
fwd_ = trace_to_subjaxpr_nounits(fwd_, self.main, True)
|
||||
fwd_, aux = partial_eval_wrapper_nounits(
|
||||
fwd_, tuple(in_knowns), tuple(in_avals))
|
||||
with core.new_sublevel():
|
||||
@ -532,7 +535,8 @@ class JaxprTrace(Trace['JaxprTracer']):
|
||||
dict(fun_jaxpr=closed_jaxpr,
|
||||
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
|
||||
num_consts=len(res) + len(env),
|
||||
bwd=bwd, out_trees=out_trees),
|
||||
bwd=bwd, out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros),
|
||||
jaxpr.effects, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
@ -1959,23 +1963,29 @@ class DynamicJaxprTrace(core.Trace):
|
||||
def post_process_custom_jvp_call(self, out_tracers, _):
|
||||
assert False # unreachable
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
|
||||
symbolic_zeros):
|
||||
in_avals = [t.aval for t in tracers]
|
||||
with core.new_sublevel():
|
||||
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
|
||||
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
|
||||
|
||||
main_ = ref(self.main)
|
||||
fwd_jaxpr_thunk = _memoize(
|
||||
lambda: trace_to_subjaxpr_dynamic(fwd, main_(), in_avals)[::2])
|
||||
@_memoize
|
||||
def fwd_jaxpr_from_zeros(*zeros):
|
||||
fwd_ = _interleave_fun(fwd, zeros)
|
||||
return trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals)[::2]
|
||||
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
invars = map(self.getvar, tracers)
|
||||
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
||||
outvars = map(self.makevar, out_tracers)
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
|
||||
dict(fun_jaxpr=closed_fun_jaxpr,
|
||||
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
|
||||
fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
|
||||
num_consts=len(consts),
|
||||
bwd=bwd, out_trees=out_trees),
|
||||
bwd=bwd, out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros),
|
||||
fun_jaxpr.effects,
|
||||
source_info_util.current())
|
||||
self.frame.add_eqn(eqn)
|
||||
@ -2024,18 +2034,23 @@ class DynamicJaxprTrace(core.Trace):
|
||||
|
||||
custom_staging_rules: Dict[Primitive, Callable] = {}
|
||||
|
||||
def _memoize(thunk):
|
||||
cell = []
|
||||
# Transformation that calls the wrapped function with each positional argument
# followed by the corresponding element of ``every_others``, i.e. with
# (args[0], every_others[0], args[1], every_others[1], ...).
@lu.transformation
def _interleave_fun(every_others, *args, **kwargs):
  # ``zip`` pairs the two sequences positionally and stops at the shorter one.
  args_ = [x for pair in zip(args, every_others) for x in pair]
  # Hand the interleaved arguments to the wrapped function and pass its result
  # through unchanged.
  yield (yield (args_, kwargs))
|
||||
|
||||
def _memoize(fn):
|
||||
cells = {}
|
||||
saved_state = [core.thread_local_state.trace_state.copy()]
|
||||
def memoized():
|
||||
if not cell:
|
||||
def memoized(*args):
|
||||
if args not in cells:
|
||||
prev_state = core.thread_local_state.trace_state
|
||||
core.thread_local_state.trace_state = saved_state.pop()
|
||||
core.thread_local_state.trace_state = saved_state[0]
|
||||
try:
|
||||
cell.append(thunk())
|
||||
cells[args] = fn(*args)
|
||||
finally:
|
||||
core.thread_local_state.trace_state = prev_state
|
||||
return cell[0]
|
||||
return cells[args]
|
||||
return memoized
|
||||
|
||||
# TODO(mattjj): remove this DebugInfo and helper functions, replace with
|
||||
|
@ -918,10 +918,11 @@ class MapTrace(core.Trace):
|
||||
return self.process_primitive(fake_primitive, tracers, {})
|
||||
|
||||
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
|
||||
out_trees):
|
||||
out_trees, symbolic_zeros):
|
||||
bind = HashableFunction(
|
||||
lambda *args, **kwargs: primitive.bind(fun, fwd, bwd, *args,
|
||||
out_trees=out_trees, **kwargs),
|
||||
lambda *args, **kwargs: primitive.bind(
|
||||
fun, fwd, bwd, *args, out_trees=out_trees,
|
||||
symbolic_zeros=symbolic_zeros, **kwargs),
|
||||
(primitive, fun, fwd, bwd))
|
||||
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
|
||||
return self.process_primitive(fake_primitive, tracers, {})
|
||||
|
@ -27,6 +27,8 @@ from jax._src.custom_derivatives import (
|
||||
custom_vjp as custom_vjp,
|
||||
custom_vjp_call_p as custom_vjp_call_p,
|
||||
custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p,
|
||||
custom_vjp_primal_tree_values as custom_vjp_primal_tree_values,
|
||||
CustomVJPPrimal as CustomVJPPrimal,
|
||||
linear_call as linear_call,
|
||||
)
|
||||
|
||||
|
@ -1254,11 +1254,12 @@ class TensorFlowTrace(core.Trace):
|
||||
def post_process_custom_jvp_call(self, out_tracers, _):
|
||||
assert False # unreachable assuming jax2tf runs with clean trace state
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
|
||||
symbolic_zeros):
|
||||
# Drop the custom differentiation rule and act like a call primitive. This
|
||||
# behavior is desirable because jax2tf stages code out of the JAX system, so
|
||||
# there are no more JAX differentiation transformations to be applied.
|
||||
del fwd, bwd, out_trees # Unused.
|
||||
del fwd, bwd, out_trees, symbolic_zeros # Unused.
|
||||
return self.process_call(core.call_p, fun, tracers, {})
|
||||
|
||||
def post_process_custom_vjp_call(self, out_tracers, _):
|
||||
|
@ -7415,10 +7415,10 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
primal_outs, tangent_outs = run(primal_ins, tangent_ins)
|
||||
primal_out1, primal_out2 = primal_outs
|
||||
tangent_out1, tangent_out2 = tangent_outs
|
||||
scalar_dtype = jax.Array if maybe_jit or maybe_vmap else float
|
||||
self.assertIsInstance(primal_out1, scalar_dtype)
|
||||
scalar_type = jax.Array if maybe_jit or maybe_vmap else float
|
||||
self.assertIsInstance(primal_out1, scalar_type)
|
||||
self.assertAllClose(primal_out1, 5.)
|
||||
self.assertIsInstance(tangent_out1, scalar_dtype)
|
||||
self.assertIsInstance(tangent_out1, scalar_type)
|
||||
self.assertAllClose(tangent_out1, 91.)
|
||||
self.assertIsInstance(primal_out2, jax.Array)
|
||||
self.assertArraysAllClose(primal_out2, jnp.array([7., 9.]))
|
||||
@ -8496,6 +8496,245 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
f_vjp(jnp.array([3.]))
|
||||
f_vjp(jnp.array([3.])) # doesn't crash
|
||||
|
||||
def test_symbolic_zero_custom_vjp_basic(self):
  # Exercises ``defvjp(..., symbolic_zeros=True)``: the forward rules check
  # the ``perturbed`` flag of each input record, and the backward rules check
  # which output cotangents arrive as symbolic zeros.
  ZERO = custom_derivatives_public.SymbolicZero

  @jax.custom_vjp
  def f(x, y, z):
    return x, x

  # Forward rule expecting only ``x`` to be perturbed.
  def fwd(x, y, z):
    self.assertTrue(x.perturbed)
    self.assertFalse(y.perturbed)
    self.assertFalse(z.perturbed)
    return (x.value, x.value), None

  # Forward rule expecting every input to be perturbed.
  def fwd_all(x, y, z):
    self.assertTrue(x.perturbed)
    self.assertTrue(y.perturbed)
    self.assertTrue(z.perturbed)
    return (x.value, x.value), None

  # Backward rule expecting neither output cotangent to be a symbolic zero.
  def bwd_all(_, g):
    x1, x2 = g
    self.assertFalse(type(x1) is ZERO)
    self.assertFalse(type(x2) is ZERO)
    return x1, x1, x2

  # Backward rule expecting only the second cotangent to be a symbolic zero.
  def bwd_fst(_, g):
    x1, x2 = g
    self.assertFalse(type(x1) is ZERO)
    self.assertIs(type(x2), ZERO)
    return x1, x1, x2

  # Backward rule expecting only the first cotangent to be a symbolic zero.
  def bwd_snd(_, g):
    x1, x2 = g
    self.assertIs(type(x1), ZERO)
    self.assertFalse(type(x2) is ZERO)
    return x1, x1, x2

  x, y, z = 4., 5., 6.
  i = np.array(7, np.int32)
  zero = np.array(0.)

  # ``fwd`` asserts only ``x`` is perturbed in each of these calls.
  f.defvjp(fwd, bwd_all, symbolic_zeros=True)
  h = jax.jit(f)
  jax.jacrev(h)(x, y, z)
  jax.jacrev(lambda x: h(x, y, z))(x)
  jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i)

  # Differentiate through the first output only; the gradient for ``z`` comes
  # from the second output's cotangent and must compare equal to zero.
  f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True)
  fst_f = lambda *xs: f(*xs)[0]
  _, vjp = jax.vjp(fst_f, x, y, z)
  _, _, gz = vjp(x)
  self.assertArraysAllClose(gz, zero)

  # Differentiate through the second output only; the gradients for ``x`` and
  # ``y`` come from the first output's cotangent and must compare equal to
  # zero.
  f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True)
  snd_f = lambda *xs: f(*xs)[1]
  _, vjp = jax.vjp(snd_f, x, y, z)
  gx, gy, _ = vjp(x)
  self.assertArraysAllClose(gx, zero)
  self.assertArraysAllClose(gy, zero)

  # Same as above, but with ``y`` and ``z`` closed over so that only ``x`` is
  # perturbed in the forward rule.
  f.defvjp(fwd, bwd_snd, symbolic_zeros=True)
  _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x)
  gx, = vjp(x)
  self.assertArraysAllClose(gx, zero)
|
||||
|
||||
@parameterized.named_parameters(
    ('jit_vmap', True, True),
    ('jit', True, False),
    ('vmap', False, True),
    ('', False, False),
)
def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap):
  """Check symbolic zeros in `custom_vjp` rules, optionally under jit/vmap.

  Args:
    maybe_jit: if true, the vjp-and-pullback driver is wrapped in `jax.jit`.
    maybe_vmap: if true, the custom_vjp function is called through a
      size-1 `jax.vmap` wrapper.
  """
  # below:
  # * static_scalar will be static in and out
  # * static_array will be static in, but dynamic out
  # * dyn_scalar and dyn_array will be dynamic in and out

  ZERO = custom_derivatives_public.SymbolicZero

  def _pack(x):
    # Add a leading axis of size 1 so the value can be mapped over.
    return lax.broadcast(x, (1,))

  def _unpack(x):
    # Inverse of `_pack`: the mapped axis holds exactly one element.
    (only,) = x
    return only

  def _vmap(fun):
    def _fun(*args):
      packed = tree_util.tree_map(_pack, args)
      mapped_out = jax.vmap(fun)(*packed)
      return tree_util.tree_map(_unpack, mapped_out)
    return _fun

  @jax.custom_vjp
  def f(static_scalar, static_array, dyn_scalar, dyn_array):
    out1 = static_scalar + dyn_scalar
    out2 = static_array + dyn_array
    return static_scalar, static_array, out1, out2

  def fwd(*args):
    xs = [arg.value for arg in args]
    pert = [arg.perturbed for arg in args]
    # The two static inputs are unperturbed; the two dynamic ones are.
    self.assertFalse(pert[0])
    self.assertFalse(pert[1])
    self.assertTrue(pert[2])
    self.assertTrue(pert[3])
    return f(*xs), xs

  def bwd(res, g):
    static_scalar, *_ = res
    t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g
    # Only the cotangent of the first output arrives as a symbolic zero.
    self.assertIs(type(t_static), ZERO)
    self.assertIsNot(type(t_static_arr), ZERO)
    self.assertIsNot(type(t_dyn_scalar), ZERO)
    self.assertIsNot(type(t_dyn_array), ZERO)
    self.assertEqual(t_static.shape, ())
    self.assertEqual(t_static_arr.shape, (2,))
    ct_static_scalar = static_scalar + 90
    ct_static_array = t_static_arr + 91
    ct_dyn_scalar = t_dyn_scalar + 92
    ct_dyn_array = t_dyn_array + 93
    return ct_static_scalar, ct_static_array, ct_dyn_scalar, ct_dyn_array

  f.defvjp(fwd, bwd, symbolic_zeros=True)

  def g(dyn_scalar, dyn_array):
    f_ = _vmap(f) if maybe_vmap else f
    outs = f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array)
    # Drop the first output so its cotangent is never supplied.
    return outs[1:]

  def run(primal_ins, cotangent_outs):
    primal_outs, vjp = jax.vjp(g, *primal_ins)
    return primal_outs, vjp(cotangent_outs)

  runner = jax.jit(run) if maybe_jit else run

  scalar_type = jax.Array if maybe_jit or maybe_vmap else float
  primal_ins = (4., jnp.array([5., 6.]))
  cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.]))
  primal_outs, cotangent_ins = runner(primal_ins, cotangent_outs)

  primal_out1, primal_out2, primal_out3 = primal_outs
  self.assertIsInstance(primal_out1, jax.Array)
  self.assertAllClose(primal_out1, jnp.array([2., 3.]))
  self.assertIsInstance(primal_out2, scalar_type)
  self.assertAllClose(primal_out2, 5.)
  self.assertIsInstance(primal_out3, jax.Array)
  self.assertAllClose(primal_out3, jnp.array([7., 9.]))

  ct_in1, ct_in2 = cotangent_ins
  self.assertIsInstance(ct_in1, scalar_type)
  self.assertAllClose(ct_in1, 99.)
  self.assertIsInstance(ct_in2, jax.Array)
  self.assertArraysAllClose(ct_in2, jnp.array([101., 102.]))
|
||||
|
||||
def test_symbolic_zero_custom_vjp_vmap_output(self):
  """Check that an unused output's cotangent is a SymbolicZero under vmap.

  The loss only depends on the first output of the vmapped function, so the
  backward rule asserts that the second output's cotangent is symbolic.
  """
  @jax.custom_vjp
  def f(x, y):
    return x, y

  def fwd(x, y):
    self.assertTrue(x.perturbed)
    self.assertFalse(y.perturbed)
    primal_out = f(x.value, y.value)
    return primal_out, None

  def bwd(_, g):
    _ct_x, ct_y = g
    self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero)
    return g

  f.defvjp(fwd, bwd, symbolic_zeros=True)

  def loss(x, y):
    first_out = jax.vmap(f)(x, y)[0]
    return first_out.sum()

  jax.grad(loss)(jnp.ones(3), jnp.ones(3))
|
||||
|
||||
def test_symbolic_zero_custom_vjp_custom_pytree(self):
  """Run `custom_vjp` rules whose primal argument is a custom pytree node.

  With `symbolic_zeros=True` the forward rules read `.perturbed` from the
  leaves and recover plain values via `custom_vjp_primal_tree_values`.
  A final rule pair is registered without `symbolic_zeros`.
  """
  tree_values = custom_derivatives_public.custom_vjp_primal_tree_values

  @tree_util.register_pytree_node_class
  class Box:
    def __init__(self_, strict, val):
      if strict:
        # make sure we aren't getting special arguments that should only
        # come up when symbolic_zeros is True
        self.assertFalse(hasattr(val, 'perturbed'))
      self_.strict = strict
      self_.x = val

    def tree_flatten(self_):
      # Single child `x`; `strict` travels as auxiliary data.
      children = [self_.x]
      return children, self_.strict

    @classmethod
    def tree_unflatten(cls, strict, xs):
      (child,) = xs
      return cls(strict, child)

  # NOTE(review): the Box is built with strict=False and unflattening
  # preserves that flag, so the strict check in __init__ does not appear to
  # run for these calls -- confirm whether that is intended.
  x, y = Box(False, jnp.array(72.)), jnp.array(73.)

  @jax.custom_vjp
  def f(box, y):
    return box.x * y

  # Rule pair for differentiating with respect to the boxed argument.
  def fwd0(box, y):
    self.assertTrue(box.x.perturbed)
    self.assertFalse(y.perturbed)
    box = tree_values(box)
    y = tree_values(y)
    return f(box, y), (box, y)

  def bwd0(res, g):
    saved_box, saved_y = res
    return saved_y * g, saved_box.x * g

  # Rule pair for differentiating with respect to the plain array argument.
  def fwd1(box, y):
    self.assertFalse(box.x.perturbed)
    self.assertTrue(y.perturbed)
    box = tree_values(box)
    y = tree_values(y)
    return f(box, y), (box, y)

  def bwd1(res, g):
    saved_box, saved_y = res
    return saved_y * g, saved_box.x * g

  f.defvjp(fwd0, bwd0, symbolic_zeros=True)
  jax.grad(f, argnums=0)(x, y)
  f.defvjp(fwd1, bwd1, symbolic_zeros=True)
  jax.grad(f, argnums=1)(x, y)

  # Default path: no symbolic_zeros, so the rules see the primals directly.
  def fwd_strict(box, y):
    return f(box, y), (box, y)

  def bwd_strict(res, g):
    saved_box, saved_y = res
    return saved_y * g, saved_box.x * g

  f.defvjp(fwd_strict, bwd_strict)
  jax.grad(f)(x, y)
|
||||
|
||||
def transpose_unary(f, x_example):
|
||||
def transposed(y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user