Copybara import of the project:

--
609f6f3e16d21fed34cc5269c54a0d78ac44a8bc by Matthew Johnson <mattjj@google.com>:

fix custom_jvp/vjp closure issues

PiperOrigin-RevId: 337457689
jax authors 2020-10-16 00:21:04 -07:00
parent d7d94ac9ea
commit 4a20eea828
11 changed files with 1394 additions and 737 deletions

docs/custom_vjp_update.md (normal file, 129 lines)

@ -0,0 +1,129 @@
# `custom_vjp` and `nondiff_argnums` update guide
_mattjj@_
_Oct 14 2020_
This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom
derivative rules for JAX-transformable Python
functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)
notebook.
### What to update
After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments
passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or
containers of `Tracer`s). In practice this means that, to keep code
arbitrarily transformable, `nondiff_argnums` shouldn't be used for
array-valued arguments. Instead, `nondiff_argnums` should be used only for
non-array values, like Python callables, shape tuples, or strings.
Wherever we used to use `nondiff_argnums` for array values, we should just pass
those as regular arguments. In the `bwd` rule, we need to produce values for them,
but we can just produce `None` values to indicate there's no corresponding
gradient value.
For example, here's the **old** way to write `clip_gradient`, which won't work
when `hi` and/or `lo` are `Tracer`s from some JAX transformation.
```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
```
Here's the **new**, awesome way, which supports arbitrary transformations:
```python
import jax
import jax.numpy as jnp

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
```
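To see why the new way matters, here's a short usage sketch (it mirrors the
`test_clip_gradient` regression test further down in this diff; the `clipped`
wrapper name is ours) in which `hi` is itself a traced, differentiated value:
```python
import jax

def clipped(x):
  lo, hi = -0.1, x + 0.1  # hi depends on x, so it's a Tracer under jax.grad
  return clip_gradient(lo, hi, x)

print(jax.grad(clipped)(0.1))  # 0.2: the incoming cotangent 1.0 clipped to [-0.1, 0.2]
```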
If you use the old way instead of the new way, you'll get a loud error in any
case where something might go wrong (namely when there's a `Tracer` passed into
a `nondiff_argnums` argument).
Here's a case where we actually need `nondiff_argnums` with `custom_vjp`:
```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None

def skip_app_bwd(f, _, g):
  return (g,)

skip_app.defvjp(skip_app_fwd, skip_app_bwd)
```
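For reference, here's a minimal usage sketch of `skip_app` (the `double`
callable and the expected outputs are our own illustrations, not part of the
PR). Because the `bwd` rule passes the cotangent straight through, the
gradient with respect to `x` is 1.0 regardless of `f`:
```python
import jax

double = lambda x: 2. * x  # hypothetical callable passed via nondiff_argnums

print(skip_app(double, 3.))                       # 6.0
print(jax.grad(skip_app, argnums=1)(double, 3.))  # 1.0: bwd passes g through unchanged
```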
### Explanation
Passing `Tracer`s into `nondiff_argnums` arguments was always buggy. While there
were some cases which worked correctly, others would lead to complex and
confusing error messages.
The essence of the bug was that `nondiff_argnums` was implemented in a way that
acted very much like lexical closure. But lexical closure over `Tracer`s wasn't
at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing
`nondiff_argnums` that way was a mistake!
**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure
issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp`
and `custom_vjp` functions and rules can close over `Tracer`s to our hearts'
content. For all non-autodiff transformations, things will Just Work. For
autodiff transformations, we'll get a clear error message about why we can't
differentiate with respect to values over which a `custom_jvp` or `custom_vjp`
closes:
> Detected differentiation of a custom_jvp function with respect to a closed-over
value. That isn't supported because the custom JVP rule only specifies how to
differentiate the custom_jvp function with respect to explicit input parameters.
>
> Try passing the closed-over value into the custom_jvp function as an argument,
and adapting the custom_jvp rule.
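For concreteness, here's a minimal sketch of the kind of closure in question
(our own example, modeled on the `test_closed_over_tracers_error_message` and
`test_closure_with_vmap` tests later in this diff): non-autodiff
transformations like `vmap` work, while differentiating with respect to the
closed-over value raises the error above.
```python
import jax
import jax.numpy as jnp

def f(x):
  @jax.custom_jvp
  def g(y):
    return x * y  # closes over x, which is a Tracer under a transformation

  @g.defjvp
  def g_jvp(primals, tangents):
    (y,), (y_dot,) = primals, tangents
    return g(y), x * y_dot

  return g(2.)

jax.vmap(f)(jnp.arange(3.))  # works: non-autodiff transformations Just Work
jax.grad(f)(3.)              # raises the closed-over-value differentiation error
```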
In tightening up and robustifying `custom_jvp` and `custom_vjp` in this way, we
found that allowing `custom_vjp` to accept `Tracer`s in its `nondiff_argnums`
would take a significant amount of bookkeeping: we'd need to rewrite the user's
`fwd` function to return the values as residuals, and rewrite the user's `bwd`
function to accept them as normal residuals (rather than accepting them as
special leading arguments, as happens with `nondiff_argnums`). This seems maybe
manageable, until you think through how we have to handle arbitrary pytrees!
Moreover, that complexity isn't necessary: if user code treats array-like
non-differentiable arguments just like regular arguments and residuals,
everything already works. (Before
[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about
involving integer-valued inputs and outputs in autodiff, but after
[#4039](https://github.com/google/jax/pull/4039) those will just work!)
Unlike `custom_vjp`, it was easy to make `custom_jvp` work with
`nondiff_argnums` arguments that were `Tracer`s. So these updates only need to
happen with `custom_vjp`.
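For example, here's a sketch (ours, simplified from the
`test_nondiff_argnums_vmap_tracer` test later in this diff) of a `custom_jvp`
whose `nondiff_argnums` argument is a `Tracer` under `vmap`, which continues
to work:
```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def scale(n, x):
  return n * x

@scale.defjvp
def scale_jvp(n, primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return scale(n, x), n * x_dot

# n is a vmap Tracer here; custom_jvp still accepts it in nondiff_argnums.
jax.vmap(lambda n: scale(n, 2.))(jnp.arange(3.))  # doesn't crash
```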

File diff suppressed because one or more lines are too long


@ -393,23 +393,24 @@ class Trace:
self.__class__.__name__, self.level, self.sublevel)
def process_call(self, call_primitive, f, tracers, params):
raise NotImplementedError("must override to handle call-like primitives")
msg = (f"{type(self)} must override process_call to handle call-like "
"primitives")
raise NotImplementedError(msg)
def process_map(self, call_primitive, f, tracers, params):
raise NotImplementedError("must override to handle map-like primitives")
msg = (f"{type(self)} must override process_map to handle map-like "
"primitives")
raise NotImplementedError(msg)
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
# As a default implementation, drop the custom differentiation rule. This
# behavior is desirable when staging out of the JAX system, but not when
# there are further differentiation transformations to be applied. Override
# this method to allow differentiation to be performed downstream.
del primitive, jvp # Unused.
return fun.call_wrapped(*tracers)
msg = (f"{type(self)} must override process_custom_jvp_call "
"to handle custom_jvp primitives")
raise NotImplementedError(msg)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
# See comment in the above process_custom_jvp_call method.
del primitive, fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*tracers)
msg = (f"{type(self)} must override process_custom_vjp_call "
"to handle custom_vjp primitives")
raise NotImplementedError(msg)
def escaped_tracer_error(detail=None):
msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
@ -575,6 +576,14 @@ class EvalTrace(Trace):
return primitive.impl(f, *tracers, **params)
process_map = process_call
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
del primitive, jvp # Unused.
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
del primitive, fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*tracers)
class MainTrace:
level: int


@ -16,6 +16,7 @@
from functools import update_wrapper, reduce, partial
import inspect
import operator as op
from typing import Callable, Sequence, Tuple, Any
from . import core
from . import linear_util as lu
@ -45,38 +46,12 @@ def _resolve_kwargs(fun, args, kwargs):
else:
return ba.args
def _add_args(f, extra_args, left):
return _add_args_(f, tuple(map(wrap_hashably, extra_args)), left)
@lu.transformation
def _add_args_(extra_args, left, *args, **kwargs):
extra_args = tuple([arg.val for arg in extra_args])
args = (extra_args + args) if left else (args + extra_args)
yield (yield args, kwargs)
def _memoize(thunk):
cell = []
saved_state = core.thread_local_state.trace_state.copy()
def memoized():
if not cell:
prev_state = core.thread_local_state.trace_state
core.thread_local_state.trace_state = saved_state
try:
cell.append(thunk())
finally:
core.thread_local_state.trace_state = prev_state
return cell[0]
return memoized
def _initial_style_jaxpr(fun, in_avals):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
return jaxpr, consts
def _initial_style_staging() -> bool:
if config.omnistaging_enabled:
return core.thread_local_state.trace_state.trace_stack.dynamic.level != 0 # type: ignore
else:
return core.thread_local_state.trace_state.initial_style
return core.thread_local_state.trace_state.initial_style
def _sum_tangents(_, x, *xs):
return reduce(ad.add_tangents, xs, x)
@ -208,27 +183,39 @@ class custom_jvp:
raise AttributeError(msg.format(self.__name__))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
is_nondiff = [False] * len(args)
for i in self.nondiff_argnums: is_nondiff[i] = True
args = [_stop_gradient(x) if b else x for b, x in zip(is_nondiff, args)]
dyn_argnums = [i for i, b in enumerate(is_nondiff) if not b]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
args = [_stop_gradient(x) if i in self.nondiff_argnums else x
for i, x in enumerate(args)]
diff_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args)
static_args = [args[i] for i in self.nondiff_argnums]
jvp = _add_args(lu.wrap_init(self.jvp), static_args, left=True)
jvp = _add_args(lu.wrap_init(self.jvp), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
jvp = lu.wrap_init(self.jvp)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
if _initial_style_staging():
out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat)
out_tree = out_tree1()
else:
out_flat = custom_jvp_call(flat_fun, flat_jvp, *args_flat)
if config.omnistaging_enabled:
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
else:
if _initial_style_staging():
out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat) # type: ignore
out_tree = out_tree1()
else:
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
return tree_unflatten(out_tree, out_flat)
def _add_args(f, extra_args):
return _add_args_(f, tuple(map(wrap_hashably, extra_args)))
@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
extra_args = tuple([arg.val for arg in extra_args])
all_args = (extra_args + args)
yield (yield all_args, kwargs)
@lu.transformation_with_aux
def _flatten_jvp(in_tree, *args):
primals_in, tangents_in = split_list(args, [len(args) // 2])
@ -264,6 +251,8 @@ def _flatten_jvp(in_tree, *args):
yield primals_out + tangents_out, out_tree
class CustomJVPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive
def bind(self, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
@ -272,51 +261,63 @@ class CustomJVPCallPrimitive(core.CallPrimitive):
jvp, env_trace_todo2 = core.process_env_traces(
jvp, self, top_trace and top_trace.level, ())
tracers = map(top_trace.full_raise, args) # type: ignore
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
with core.maybe_new_sublevel(top_trace):
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if env_trace_todo:
raise core.UnexpectedTracerError
return map(core.full_lower, outs)
return _apply_todos(env_trace_todo, map(core.full_lower, outs))
def impl(self, fun, _, *args):
return fun.call_wrapped(*args)
def post_process(self, trace, out_tracers, params):
return trace.post_process_custom_jvp_call(out_tracers, params)
def _apply_todos(todos, outs):
todos_list = list(todos)
while todos_list:
outs = map(core.full_lower, todos_list.pop()(outs))
return outs
custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
custom_jvp_call = custom_jvp_call_p.bind
def custom_jvp_call_jaxpr(fun, jvp, *args):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
jvp_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
return custom_jvp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk)
def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr, **_):
def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):
del params # other params ignored because we're just executing the primal fun
return core.jaxpr_as_fun(fun_jaxpr)(*args)
def _custom_jvp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
del args, params
return fun_jaxpr.out_avals
custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p
def _custom_jvp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, jvp_jaxpr_thunk):
jvp_jaxpr = jvp_jaxpr_thunk()
tangents = map(ad.instantiate_zeros, tangents)
def _custom_jvp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
if any(type(t) is not Zero for t in consts_dot):
raise ad.CustomJVPException()
jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk() # consts can be tracers!
args_dot = map(ad.instantiate_zeros, args_dot)
# Cast float0 to zeros with the primal dtype because custom jvp rules don't
# currently handle float0s
tangents = ad.replace_float0s(primals, tangents)
outs = core.jaxpr_as_fun(jvp_jaxpr)(*primals, *tangents)
args_dot = map(ad.replace_float0s, args, args_dot)
outs = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *args, *args_dot)
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
tangents_out = ad.recast_to_float0(primals_out, tangents_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp
def _custom_jvp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, jvp_jaxpr_thunk):
def _custom_jvp_call_jaxpr_vmap(
args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int):
size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]
@ -325,21 +326,22 @@ def _custom_jvp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, jvp_jaxpr_thunk):
in_batched = [d is not not_mapped for d in in_dims]
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(fun_jaxpr, size, in_batched, False)
out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = []
out_dims2 = [] # mutable cell updated by batched_jvp_jaxpr_thunk
@_memoize
@pe._memoize
def batched_jvp_jaxpr_thunk():
jvp_jaxpr = jvp_jaxpr_thunk()
jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk()) # consts can be tracers
_, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, in_batched * 2, False)
primals_batched, tangents_batched = split_list(all_batched, [num_out])
out_batched = map(op.or_, primals_batched, tangents_batched)
out_dims2.append([0 if b else not_mapped for b in out_batched])
batched_jvp_jaxpr, _ = batching.batch_jaxpr(jvp_jaxpr, size, in_batched * 2,
out_batched * 2)
return batched_jvp_jaxpr
batched_jvp_jaxpr, _ = batching.batch_jaxpr(
jvp_jaxpr, size, in_batched * 2, out_batched * 2)
return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
batched_outs = custom_jvp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr, jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk)
*args, fun_jaxpr=batched_fun_jaxpr,
jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
out_dims = out_dims2[0] if out_dims2 else out_dims1
return batched_outs, out_dims
batching.primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
@ -350,8 +352,9 @@ xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk):
del jvp_jaxpr_thunk
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
num_consts):
del jvp_jaxpr_thunk, num_consts
return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts)
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
@ -451,14 +454,12 @@ class custom_vjp:
raise AttributeError(msg.format(self.__name__))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
is_nondiff = [False] * len(args)
for i in self.nondiff_argnums: is_nondiff[i] = True
args = [_stop_gradient(x) if b else x for b, x in zip(is_nondiff, args)]
dyn_argnums = [i for i, b in enumerate(is_nondiff) if not b]
for i in self.nondiff_argnums: _check_for_tracers(args[i])
dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
static_args = [args[i] for i in self.nondiff_argnums]
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
bwd = _add_args(lu.wrap_init(self.bwd), static_args, left=True)
bwd = _add_args(lu.wrap_init(self.bwd), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
@ -467,17 +468,39 @@ class custom_vjp:
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
if _initial_style_staging():
out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
out_tree = out_tree()
else:
out_flat = custom_vjp_call(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
if config.omnistaging_enabled:
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
else:
if _initial_style_staging():
out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd, # type: ignore
*args_flat, out_trees=out_trees)
out_tree = out_tree()
else:
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
return tree_unflatten(out_tree, out_flat)
@partial(partial, tree_map)
def _check_for_tracers(x):
if isinstance(x, core.Tracer):
msg = ("Found a JAX Tracer object passed as an argument to a custom_vjp "
"function in a position indicated by nondiff_argnums as "
"non-differentiable. Tracers cannot be passed as non-differentiable "
"arguments to custom_vjp functions; instead, nondiff_argnums should "
"only be used for arguments that can't be or contain JAX tracers, "
"e.g. function-valued arguments. In particular, array-valued "
"arguments should typically not be indicated as nondiff_argnums. "
"\n\n"
"This behavior recently changed in JAX. "
"See https://github.com/google/jax/blob/master/docs/custom_vjp_update.md "
"for more information.")
raise core.UnexpectedTracerError(msg)
@lu.transformation_with_aux
def _flatten_fwd(in_tree, *args):
py_args = tree_unflatten(in_tree, args)
@ -524,31 +547,29 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
class CustomVJPCallPrimitive(core.CallPrimitive):
initial_style: core.Primitive
def bind(self, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
outs = fun.call_wrapped(*args)
else:
tracers = map(top_trace.full_raise, args)
fun, env_trace_todo1 = core.process_env_traces(
fun, self, top_trace and top_trace.level, ())
fwd, env_trace_todo2 = core.process_env_traces(
fwd, self, top_trace and top_trace.level, ())
tracers = map(top_trace.full_raise, args) # type: ignore
with core.maybe_new_sublevel(top_trace):
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
out_trees=out_trees)
return map(core.full_lower, outs)
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
return _apply_todos(env_trace_todo, map(core.full_lower, outs))
def impl(self, fun, fwd, bwd, *args, out_trees):
del fwd, bwd, out_trees
return fun.call_wrapped(*args)
def post_process(self, trace, out_tracers, params):
return trace.post_process_custom_vjp_call(out_tracers, params)
custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')
custom_vjp_call = custom_vjp_call_p.bind
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
fwd_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
return custom_vjp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd,
out_trees=out_trees)
def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
return core.jaxpr_as_fun(fun_jaxpr)(*args)
@ -560,27 +581,35 @@ custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr')
custom_vjp_call_jaxpr_p.multiple_results = True
custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p
def _custom_vjp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, fwd_jaxpr_thunk,
bwd, out_trees):
tangents = map(ad.instantiate_zeros, tangents)
def _custom_vjp_call_jaxpr_jvp(
primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
_, args = split_list(primals, [num_consts])
consts_dot, args_dot = split_list(tangents, [num_consts])
if any(type(t) is not Zero for t in consts_dot):
raise ad.CustomVJPException()
fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk() # consts can be tracers!
out_tree, res_tree = out_trees()
args_dot = map(ad.instantiate_zeros, args_dot)
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
# currently handle float0s
tangents = ad.replace_float0s(primals, tangents)
fwd_jaxpr = fwd_jaxpr_thunk()
out_tree, res_tree = out_trees()
res_and_primals_out = core.jaxpr_as_fun(fwd_jaxpr)(*primals)
args_dot = map(ad.replace_float0s, args, args_dot)
res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
tangents_out = ad.custom_lin_p.bind(
*res, *tangents, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)
tangents_out = ad.recast_to_float0(primals_out, tangents_out)
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
def _custom_vjp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, fwd_jaxpr_thunk,
bwd, out_trees):
def _custom_vjp_call_jaxpr_vmap(
args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]
@ -590,25 +619,26 @@ def _custom_vjp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, fwd_jaxpr_thunk,
out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = []
@_memoize
@pe._memoize
def batched_fwd_jaxpr_thunk():
fwd_jaxpr = fwd_jaxpr_thunk()
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(fwd_jaxpr, size, in_batched, False)
out_dims2.append([0 if b else not_mapped for b in out_batched])
return batched_fwd_jaxpr
return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts
fwd_in_dims = [0 if b else not_mapped for b in in_batched]
fwd_out_dims = lambda: out_dims2[0]
# TODO: Support collectives in custom_vjp?
# TODO(mattjj,apaszke): Support collectives in custom_vjp?
batched_bwd = batching.batch_fun(bwd, fwd_out_dims, fwd_in_dims,
axis_name='__unused_axis_name', sum_match=True)
batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr,
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
out_trees=out_trees)
out_trees=out_trees, num_consts=num_consts)
out_dims = out_dims2[0] if out_dims2 else out_dims1
out_dims = out_dims[:len(batched_outs)] # TODO(mattjj): remove after #4008
if not config.omnistaging_enabled:
out_dims = out_dims[:len(batched_outs)]
return batched_outs, out_dims
batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
@ -620,16 +650,16 @@ batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_jaxpr, custom_jvp_call
global _initial_style_jaxpr, custom_vjp_call_jaxpr, custom_jvp_call_jaxpr
def _initial_style_jaxpr(fun, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
bottom=True, stage_out=False) # type: ignore
assert not any(isinstance(c, core.Tracer) for c in consts)
return core.ClosedJaxpr(jaxpr, consts)
return jaxpr, consts
def bind(self, fun, jvp, *args):
def jvp_bind(self, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
fun, env_trace_todo1 = core.process_env_traces(
@ -646,5 +676,43 @@ def omnistaging_disabler() -> None:
if env_trace_todo:
raise core.UnexpectedTracerError
return map(core.full_lower, outs)
CustomJVPCallPrimitive.bind = bind # type: ignore
custom_jvp_call = custom_jvp_call_p.bind
CustomJVPCallPrimitive.bind = jvp_bind # type: ignore
def jvp_post_process(self, trace, out_tracers, params):
raise core.UnexpectedTracerError
CustomJVPCallPrimitive.post_process = jvp_post_process # type: ignore
def vjp_bind(self, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
outs = fun.call_wrapped(*args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
out_trees=out_trees)
return map(core.full_lower, outs)
CustomVJPCallPrimitive.bind = vjp_bind # type: ignore
def vjp_post_process(self, trace, out_tracers, params):
raise core.UnexpectedTracerError
CustomVJPCallPrimitive.post_process = vjp_post_process # type: ignore
def custom_jvp_call_jaxpr(fun: Callable, jvp: Callable, *args):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals) # consts can be tracers!
closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
jvp_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
return custom_jvp_call_jaxpr_p.bind(
*consts, *args, fun_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk, num_consts=len(consts))
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals) # consts can be tracers!
closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
fwd_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
return custom_vjp_call_jaxpr_p.bind(
*consts, *args, fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,
num_consts=len(consts))


@ -560,6 +560,26 @@ class TensorFlowTrace(core.Trace):
def post_process_map(self, map_primitive, out_tracers, params):
raise NotImplementedError("post_process_map")
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so
# there are no more JAX differentiation transformations to be applied.
del jvp # Unused.
return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_jvp_call(self, out_tracers, params):
assert False # unreachable assuming jax2tf runs with clean trace state
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
# Drop the custom differentiation rule and act like a call primitive. This
# behavior is desirable because jax2tf stages code out of the JAX system, so
# there are no more JAX differentiation transformations to be applied.
del fwd, bwd, out_trees # Unused.
return self.process_call(core.call_p, fun, tracers, {})
def post_process_custom_vjp_call(self, out_tracers, params):
assert False # unreachable assuming jax2tf runs with clean trace state
def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
# Returns the primitive implementation and whether the implementation
# takes abstract values (see definition of tf_impl_with_avals)
@ -1808,7 +1828,8 @@ tf_impl[lax_linalg.triangular_solve_p] = _triangular_solve
def _custom_jvp_call_jaxpr(*args: TfVal,
fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable) -> Sequence[TfVal]:
jvp_jaxpr_thunk: Callable,
num_consts: int) -> Sequence[TfVal]:
# TODO(necula): ensure that there is no AD transformation in scope
return _interpret_jaxpr(fun_jaxpr, *args)


@ -153,8 +153,14 @@ class JetTrace(core.Trace):
return map(partial(JetTracer, trace), primals, series)
return out, todo
def join(self, xt, yt):
assert False # TODO?
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
# TODO(mattjj): don't just ignore custom jvp rules?
del primitive, jvp # Unused.
return fun.call_wrapped(*tracers)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
del primitive, fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*tracers)
class ZeroTerm(object): pass


@ -146,16 +146,17 @@ def unpair_pval(pval):
aval_1, aval_2 = aval
return (aval_1, const_1), (aval_2, const_2)
def replace_float0s(primals, tangents):
return [core.zeros_like_float0(tangent, dtype(primal))
if dtype(tangent) is float0 else tangent
for primal, tangent in zip(primals, tangents)]
def replace_float0s(primal, tangent):
if dtype(tangent) is float0:
return core.zeros_like_float0(tangent, dtype(primal))
else:
return tangent
def recast_to_float0(primals, tangents):
return [Zero(get_aval(primal).at_least_vspace())
if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0
else tangent
for primal, tangent in zip(primals, tangents)]
def recast_to_float0(primal, tangent):
if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0:
return Zero(get_aval(primal).at_least_vspace())
else:
return tangent
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
@ -314,12 +315,15 @@ class JVPTrace(Trace):
tangents_in = map(instantiate_zeros, tangents_in)
# Cast float0 to zeros with the primal dtype because custom jvp rules don't
# currently handle float0s
tangents_in = replace_float0s(primals_in, tangents_in)
tangents_in = map(replace_float0s, primals_in, tangents_in)
outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
tangents_out = recast_to_float0(primals_out, tangents_out)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)
def post_process_custom_jvp_call(self, out_tracers, params):
raise CustomJVPException()
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
tangents_in = map(instantiate_zeros, tangents_in)
@ -330,9 +334,12 @@ class JVPTrace(Trace):
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
avals_out=avals_out)
tangents_out = recast_to_float0(primals_out, tangents_out)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)
def post_process_custom_vjp_call(self, out_tracers, params):
raise CustomVJPException()
def join(self, xt, yt):
xz, yz = type(xt) is Zero, type(yt) is Zero
if xz == yz:
@ -625,8 +632,7 @@ def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
res, _ = split_list(invals, [num_res])
cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
cts_in = bwd.call_wrapped(*res, *cts_out)
cts_in_flat, _ = tree_flatten(cts_in) # already checked tree structure
return [None] * num_res + cts_in_flat
return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose
@ -696,6 +702,28 @@ def defvjp2(prim, *vjps):
defvjp_all(prim, vjpmaker)
class CustomJVPException(Exception):
def __init__(self):
# TODO(mattjj): track source provenance on AD tracers, improve error
msg = ("Detected differentiation of a custom_jvp function with respect to "
"a closed-over value. That isn't supported because the custom JVP "
"rule only specifies how to differentiate the custom_jvp function "
"with respect to explicit input parameters. Try passing the "
"closed-over value into the custom_jvp function as an argument, and "
"adapting the custom_jvp rule.")
super().__init__(msg)
class CustomVJPException(Exception):
def __init__(self):
# TODO(mattjj): track source provenance on AD tracers, improve error
msg = ("Detected differentiation of a custom_vjp function with respect to "
"a closed-over value. That isn't supported because the custom VJP "
"rule only specifies how to differentiate the custom_vjp function "
"with respect to explicit input parameters. Try passing the "
"closed-over value into the custom_vjp function as an argument, and "
"adapting the custom_vjp fwd and bwd rules.")
super().__init__(msg)
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global jvp_jaxpr


@ -215,11 +215,19 @@ class BatchTrace(Trace):
out_dims = out_dims[:len(out_dims) // 2]
return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
def post_process_custom_jvp_call(self, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
main = self.main
def todo(vals):
trace = BatchTrace(main, core.cur_sublevel())
return map(partial(BatchTracer, trace), vals, dims)
return vals, todo
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
# TODO: Support collectives in custom_vjp?
# TODO(mattjj,apaszke): support collectives in custom_vjp?
bwd = batch_fun(bwd, out_dims2, in_dims,
axis_name='__unused_axis_name', sum_match=True)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
@ -370,11 +378,11 @@ def _promote_aval_rank(sz, aval):
else:
return ShapedArray((sz,) + aval.shape, aval.dtype)
def batch_jaxpr(jaxpr, size, batched, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
def batch_jaxpr(closed_jaxpr, size, batched, instantiate):
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, batched_out = batched_traceable(f, size, batched, instantiate)
avals_in = [_promote_aval_rank(size, a) if b else a
for a, b in zip(jaxpr.in_avals, batched)]
for a, b in zip(closed_jaxpr.in_avals, batched)]
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts), batched_out()


@ -31,6 +31,10 @@ from ..config import config
map = safe_map
zip = safe_zip
def _initial_style_jaxpr(fun, in_avals):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
################################################################################
# Reverse call primitive
################################################################################
@ -118,10 +122,9 @@ def _flatten_ivjp(in_tree, out_tree, *args):
def _custom_ivjp(fun, ivjp, args):
in_avals = [raise_to_shaped(get_aval(x)) for x in args]
fun_jaxpr = custom_derivatives._initial_style_jaxpr(fun, in_avals)
fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
try:
ivjp_jaxpr = custom_derivatives._initial_style_jaxpr(
ivjp, in_avals + fun_jaxpr.out_avals * 2)
ivjp_jaxpr = _initial_style_jaxpr(ivjp, in_avals + fun_jaxpr.out_avals * 2)
except RecursionError:
raise ValueError("Calls to {} from its custom ivjp aren't supported yet".format(fun.__name__))
return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr,
@ -310,3 +313,15 @@ def get_primitive_inverse(p):
def definverse(primitive, inverse_rule):
primitive_inverses[primitive] = inverse_rule
return inverse_rule
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_jaxpr, custom_jvp_call
def _initial_style_jaxpr(fun, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
bottom=True, stage_out=False) # type: ignore
assert not any(isinstance(c, core.Tracer) for c in consts)
return core.ClosedJaxpr(jaxpr, consts)


@ -179,7 +179,7 @@ class JaxprTrace(Trace):
else PartialVal.unknown(mapped_aval(pval[0]))
for pval, is_mapped in zip(in_pvals, params['mapped_invars'])]
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
f, in_pvals, partial(primitive.bind, **params))
f, in_pvals, partial(primitive.bind, **params), instantiate=False)
if primitive.map_primitive:
unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
out_pvals = [pval if pval.is_known()
@ -273,10 +273,11 @@ class JaxprTrace(Trace):
post_process_map = post_process_call
def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]]):
app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
instantiate: bool):
"""Partially evaluate f on a sequence of PartialVals."""
in_avals, in_consts = unzip2(pvals)
f = trace_to_subjaxpr(f, self.main, False)
f = trace_to_subjaxpr(f, self.main, instantiate)
f, aux = partial_eval_wrapper(f, tuple(in_avals))
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
@ -284,6 +285,85 @@ class JaxprTrace(Trace):
env_tracers = map(self.full_raise, env)
return jaxpr, out_pvs, consts, env_tracers
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
tracers = map(self.instantiate_const_abstracted, tracers)
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
fun = trace_to_subjaxpr(fun, self.main, True)
fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
out_flat = prim.bind(fun, jvp, *in_consts)
out_avals, jaxpr, env = aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units
env_tracers = map(self.full_raise, env)
out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
const_tracers = map(self.new_instantiated_const, consts)
in_tracers = (*const_tracers, *env_tracers, *tracers)
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
@_memoize
def jvp_jaxpr_thunk():
jvp_ = trace_to_subjaxpr(jvp, self.main, True)
jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units
out_avals, jaxpr, env = aux()
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
return converted_jaxpr, (*consts, *env)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
dict(fun_jaxpr=closed_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts) + len(env)),
source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
def post_process_custom_jvp_call(self, out_tracers, params):
# This path should only be reachable if we expose a partial eval API
# unrelated to autodiff, since we raise an error when differentiation with
# respect to values over which a custom_jvp function closes is detected.
raise NotImplementedError # TODO(mattjj)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
tracers = map(self.instantiate_const_abstracted, tracers)
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
fun = trace_to_subjaxpr(fun, self.main, True)
fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
out_avals, jaxpr, env = aux()
out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units
env_tracers = map(self.full_raise, env)
out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
const_tracers = map(self.new_instantiated_const, consts)
in_tracers = (*const_tracers, *env_tracers, *tracers)
closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())
@_memoize
def fwd_jaxpr_thunk():
fwd_ = trace_to_subjaxpr(fwd, self.main, True)
fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units
out_avals, jaxpr, env = aux()
_, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
return converted_jaxpr, (*consts, *env)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
dict(fun_jaxpr=closed_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(consts) + len(env),
bwd=bwd, out_trees=out_trees),
source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers
def post_process_custom_vjp_call(self, out_tracers, params):
# This path should only be reachable if we expose a partial eval API
# unrelated to autodiff, since we raise an error when differentiation with
# respect to values over which a custom_vjp function closes is detected.
raise NotImplementedError # TODO(mattjj)
@lu.transformation_with_aux
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
@ -551,6 +631,15 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr):
core.skip_checks or core.check_jaxpr(lifted_jaxpr)
return lifted_jaxpr
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
core.skip_checks or core.check_jaxpr(jaxpr)
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
core.skip_checks or core.check_jaxpr(converted_jaxpr)
return converted_jaxpr
def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
@ -658,11 +747,11 @@ def _remat_partial_eval(trace, _, f, tracers, params):
in_pvals = [t.pval for t in instantiated_tracers]
if config.omnistaging_enabled:
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
else:
with core.initial_style_staging(): # type: ignore
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
# Convert consts to inputs, since they may contain Tracer instances.
jaxpr = convert_constvars_jaxpr(jaxpr)
@ -847,7 +936,6 @@ class JaxprStackFrame:
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
# core.skip_checks or core.check_jaxpr(jaxpr)
out_avals = [t.aval for t in out_tracers]
return jaxpr, out_avals, constvals
@ -1000,6 +1088,63 @@ class DynamicJaxprTrace(core.Trace):
def post_process_map(self, map_primitive, out_tracers, params):
assert False # unreachable
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
in_avals = [t.aval for t in tracers]
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
jvp_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.getvar, out_tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
dict(fun_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk,
num_consts=len(consts)),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
def post_process_custom_jvp_call(self, out_tracers, params):
assert False # unreachable
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
in_avals = [t.aval for t in tracers]
fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
fwd_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.getvar, out_tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
dict(fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(consts),
bwd=bwd, out_trees=out_trees),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
def post_process_custom_vjp_call(self, out_tracers, params):
assert False # unreachable
def _memoize(thunk):
cell = []
saved_state = core.thread_local_state.trace_state.copy()
def memoized():
if not cell:
prev_state = core.thread_local_state.trace_state
core.thread_local_state.trace_state = saved_state
try:
cell.append(thunk())
finally:
core.thread_local_state.trace_state = prev_state
return cell[0]
return memoized
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
assert config.omnistaging_enabled


@ -2793,10 +2793,30 @@ class CustomJVPTest(jtu.JaxTestCase):
expected = 2.
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.jit(foo))(3.)
expected = 2.
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jit(api.grad(foo))(3.)
expected = 2.
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.grad(foo))(3.)
expected = 0.
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.grad(api.jit(foo)))(3.)
expected = 0.
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.jit(api.grad(foo)))(3.)
expected = 0.
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jit(api.grad(api.grad(foo)))(3.)
expected = 0.
self.assertAllClose(ans, expected, check_dtypes=False)
def test_initial_style_vmap(self):
@api.custom_jvp
def f(x):
@ -2816,11 +2836,38 @@ class CustomJVPTest(jtu.JaxTestCase):
expected = 3. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.vmap(api.jit(foo))(jnp.ones(3))
expected = 3. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jit(api.vmap(foo))(jnp.ones(3))
expected = 3. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
expected = 2. * jnp.ones(3)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_closed_over_tracers_error_message(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f(x):
@api.custom_jvp
def g(y):
@ -2830,10 +2877,8 @@ class CustomJVPTest(jtu.JaxTestCase):
g.defjvp(g_jvp)
return g(1.)
self.assertRaises(
core.UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,)))
self.assertRaises(
core.UnexpectedTracerError, lambda: api.grad(f)(3.))
self.assertRaises(ad.CustomJVPException, lambda: api.jvp(f, (3.,), (1.,)))
self.assertRaises(ad.CustomJVPException, lambda: api.grad(f)(3.))
def test_nondiff_arg(self):
@partial(api.custom_jvp, nondiff_argnums=(0,))
@ -2869,6 +2914,25 @@ class CustomJVPTest(jtu.JaxTestCase):
expected = (6., 5.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_hiding_jvp_tracer(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f(x):
@partial(api.custom_jvp, nondiff_argnums=(0,))
def g(h, x):
return h(x)
@g.defjvp
def g_jvp(h, primals, tangents):
x, = primals
t, = tangents
return g(h, x), 2. * t
h = lambda y: x + y # capture x
return g(h, x)
with self.assertRaisesRegex(ad.CustomJVPException, "Detected differentiation"):
api.jvp(f, (2.,), (1.,))
def test_vmap_axes(self):
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
@ -3123,6 +3187,77 @@ class CustomJVPTest(jtu.JaxTestCase):
for ans in results:
self.assertAllClose(ans, expected)
def test_nondiff_argnums_vmap_tracer(self):
# https://github.com/google/jax/issues/3964
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
@partial(jax.custom_jvp, nondiff_argnums=(0, 2))
def sample(shape, param, seed):
return jax.random.uniform(key=seed, shape=shape, minval=param)
@sample.defjvp
def sample_jvp(shape, seed, primals, tangents):
param, = primals
dparam, = tangents
dparam = jnp.broadcast_to(dparam, shape)
samples = sample(shape, param, seed)
return samples, samples * dparam # dummy jvp for proof of concept
# check these don't crash
jax.vmap(lambda seed: sample((2,3), 1., seed))(
jax.random.split(jax.random.PRNGKey(1), 10))
jax.jvp(lambda x: sample((2, 3), x, jax.random.PRNGKey(1)),
(1.,), (1.,))
def test_fun_with_nested_calls_2(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def call(f, *args):
f = api.custom_jvp(f)
f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents)))
return f(*args)
def fun_with_nested_calls_2(x):
def bar(y):
def baz(w):
q = call(lambda x: y, x)
q = q + call(lambda: y)
q = q + call(lambda y: w + y, y)
q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q
return q
return api.jit(baz)(x)
return call(bar, x)
# test these don't crash
self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.),
fun_with_nested_calls_2(3.))
api.vmap(fun_with_nested_calls_2)(jnp.arange(3.))
def test_closure_with_vmap(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
# https://github.com/google/jax/issues/3822
alpha = np.float32(2.)
def sample(seed):
@api.custom_jvp
def f(alpha):
return jax.random.gamma(seed, alpha, shape=[])
@f.defjvp
def f_jvp(primal, tangent):
alpha = primal
dalpha = tangent
sample = f(alpha)
partial_alpha = lax.random_gamma_grad(alpha, sample)
return sample, partial_alpha * dalpha
return f(alpha)
api.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3)) # don't crash
def test_float0(self):
@api.custom_jvp
def f(x, y):
@ -3157,6 +3292,53 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertArraysEqual(api.jvp(foo, primals, tangents),
(primals, expected_tangents))
def test_remat(self):
@api.custom_jvp
def f(x):
return jnp.sin(x)
def f_jvp(primals, tangents):
x, = primals
g, = tangents
return f(x), 2 * jnp.cos(x) * g
f.defjvp(f_jvp)
@api.remat
def g(x):
return f(f(x))
ans = g(2.)
expected = np.sin(np.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(g)(2.)
expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_higher_order(self):
@api.custom_jvp
def f(x):
return jnp.sin(x)
def f_jvp(primals, tangents):
x, = primals
g, = tangents
return f(x), 2 * jnp.cos(x) * g
f.defjvp(f_jvp)
def g(x):
return f(f(x))
ans = api.grad(api.grad(api.remat(g)))(2.)
expected = api.grad(api.grad(g))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.remat(api.grad(g)))(2.)
expected = api.grad(api.grad(g))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.grad(api.grad(api.remat(g))))(2.)
expected = api.grad(api.grad(api.grad(g)))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
class CustomVJPTest(jtu.JaxTestCase):
@ -3395,6 +3577,12 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_tracer(self):
# This test is now skipped because we decided not to support this behavior
# anymore (namely, nondiff args can't be tracers), but
# test_closed_over_tracer is a replacement test for analogous behavior that
# we do support
raise unittest.SkipTest("removed support for tracers in nondiff args")
@partial(api.custom_vjp, nondiff_argnums=(0,))
def f(x, y):
return x * y
@ -3416,6 +3604,56 @@ class CustomVJPTest(jtu.JaxTestCase):
expected = jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_closed_over_tracer(self):
# This test is similar to test_nondiff_arg_tracer except it uses lexical
# closure rather than the nondiff_argnums mechanism. We decided to disallow
# tracers in nondiff_argnums to greatly simplify bookkeeping while still
# supporting the cases for which it is necessary.
def outer(x):
@api.custom_vjp
def f(y):
return x * y
def f_fwd(y):
return f(y), jnp.cos(y)
def f_rev(cos_y, g):
return (cos_y * g,)
f.defvjp(f_fwd, f_rev)
return f
@jit
def g(x, y):
return outer(x)(y)
ans = g(2, 3.)
expected = 6.
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(g, 1)(2., 3.)
expected = jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_tracer_error(self):
# This is similar to the old (now skipped) test_nondiff_arg_tracer, except
# we're testing for the error message that that usage pattern now raises.
@partial(api.custom_vjp, nondiff_argnums=(0,))
def f(x, y):
return x * y
def f_fwd(x, y):
return f(x, y), jnp.cos(y)
def f_rev(x, cos_y, g):
return (cos_y * g,)
f.defvjp(f_fwd, f_rev)
@jit
def g(x, y):
return f(x, y)
with self.assertRaisesRegex(core.UnexpectedTracerError, "custom_vjp"):
_ = g(2, 3.)
with self.assertRaisesRegex(core.UnexpectedTracerError, "custom_vjp"):
_ = api.grad(g, 1)(2., 3.)
def test_vmap_axes(self):
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
@ -3482,6 +3720,11 @@ class CustomVJPTest(jtu.JaxTestCase):
jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash
def test_nondiff_argnums_stop_gradient(self):
# This test is now skipped because we decided not to support this behavior
# anymore (namely, nondiff args can't be tracers), but test_clip_gradient is
# a replacement showing behavior we do support.
raise unittest.SkipTest("removed support for tracers in nondiff args")
# https://github.com/google/jax/issues/2784
@partial(api.custom_vjp, nondiff_argnums=(0, 1))
def _clip_gradient(lo, hi, x):
@ -3503,6 +3746,29 @@ class CustomVJPTest(jtu.JaxTestCase):
jax.grad(clip_gradient)(1.) # doesn't crash
def test_clip_gradient(self):
# https://github.com/google/jax/issues/2784
@api.custom_vjp
def _clip_gradient(lo, hi, x):
return x # identity function when not differentiating
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi,)
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi),)
_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
def clip_gradient(x):
lo = -0.1
hi = x + 0.1
return _clip_gradient(lo, hi, x)
g = jax.grad(clip_gradient)(0.1) # doesn't crash
self.assertAllClose(g, jnp.array(0.2))
def test_nestable_vjp(self):
# Verify that https://github.com/google/jax/issues/3667 is resolved.
def f(x):
@ -3536,6 +3802,94 @@ class CustomVJPTest(jtu.JaxTestCase):
y, = z(1.0)(3.0)
self.assertAllClose(y, jnp.array(6.0))
def test_initial_style_vmap_2(self):
# https://github.com/google/jax/issues/4173
x = jnp.ones((10, 3))
# Create the custom function
@api.custom_vjp
def custom_fun(x):
return x.sum()
def forward(x):
return x.sum(), (jnp.ones_like(x),)
def backward(res, g):
return g*res[0],
custom_fun.defvjp(forward, backward)
def train_fun(x):
def summed_fun(x):
return api.vmap(custom_fun)(x).sum()
return api.grad(summed_fun)(x)
def scan_body(carry, inputs):
x = carry
return carry, train_fun(x)
scan_range = jnp.arange(4)
lax.scan(scan_body, x, scan_range) # don't crash
def test_bwd_closes_over_tracer(self):
def f(y):
@jax.custom_vjp
def f(x):
return 2. * jnp.sin(x)
def fwd(x):
return f(x), ()
def bwd(_, g):
return (2. * jnp.cos(y) * g,) # capture!
f.defvjp(fwd, bwd)
return jax.grad(f)(1.)
ans = jax.jit(f)(2.)
self.assertAllClose(ans, 2. * jnp.cos(2.))
ans = jax.vmap(f)(jnp.arange(3.))
self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
ans = jax.grad(f)(4.)
self.assertAllClose(ans, -2. * jnp.sin(4.))
def test_fwd_closes_over_tracer(self):
def f(y):
@jax.custom_vjp
def f(x):
return 2. * jnp.sin(x)
def fwd(x):
return f(x), y
def bwd(y, g):
return (2. * jnp.cos(y) * g,) # capture!
f.defvjp(fwd, bwd)
return jax.grad(f)(1.)
ans = jax.jit(f)(2.)
self.assertAllClose(ans, 2. * jnp.cos(2.))
ans = jax.vmap(f)(jnp.arange(3.))
self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))
ans = jax.grad(f)(4.)
self.assertAllClose(ans, -2. * jnp.sin(4.))
def test_float0(self):
@api.custom_vjp
def f(x, _):
@ -3571,6 +3925,120 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y),
(2., np.zeros(shape=(), dtype=float0)))
def test_remat(self):
@api.custom_vjp
def f(x):
return jnp.sin(x)
def f_fwd(x):
return f(x), jnp.cos(x)
def f_rev(cos_x, g):
return (2 * cos_x * g,)
f.defvjp(f_fwd, f_rev)
@api.remat
def g(x):
return f(f(x))
ans = g(2.)
expected = np.sin(np.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(g)(2.)
expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_higher_order(self):
@api.custom_vjp
def f(x):
return jnp.sin(x)
def f_fwd(x):
return f(x), jnp.cos(x)
def f_rev(cos_x, g):
return (2 * cos_x * g,)
f.defvjp(f_fwd, f_rev)
def g(x):
return f(f(x))
ans = api.grad(api.grad(api.remat(g)))(2.)
expected = api.grad(api.grad(g))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.remat(api.grad(g)))(2.)
expected = api.grad(api.grad(g))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(api.grad(api.grad(api.remat(g))))(2.)
expected = api.grad(api.grad(api.grad(g)))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_bwd_nones(self):
@api.custom_vjp
def f(x, y):
return x * jnp.sin(y)
def f_fwd(x, y):
return f(x, y), jnp.cos(y)
def f_rev(cos, g):
return (None, 2 * cos * g)
f.defvjp(f_fwd, f_rev)
ans = api.grad(lambda x: f(x, x))(3.)
expected = 2 * jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_bwd_nones_vmap(self):
@api.custom_vjp
def f(x, y):
return x * jnp.sin(y)
def f_fwd(x, y):
return f(x, y), jnp.cos(y)
def f_rev(cos, g):
return (None, 2 * cos * g)
f.defvjp(f_fwd, f_rev)
ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.))
expected = 2 * jnp.cos(jnp.arange(3.))
self.assertAllClose(ans, expected, check_dtypes=False)
def test_bwd_nones_pytree(self):
@api.custom_vjp
def f(xs, y):
x1, x2 = xs
return x1 * x2 * jnp.sin(y)
def f_fwd(xs, y):
return f(xs, y), jnp.cos(y)
def f_rev(cos, g):
return (None, 2 * cos * g)
f.defvjp(f_fwd, f_rev)
ans = api.grad(lambda x: f((x, x), x))(3.)
expected = 2 * jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_vjp_closure_4521(self):
# https://github.com/google/jax/issues/4521
@api.custom_vjp
def g(x, y):
return None
def g_fwd(x, y):
return None, y
def g_bwd(residuals, z_bar):
assert False
g.defvjp(g_fwd, g_bwd)
def f(xs, y):
v_g = api.vmap(g, in_axes=(0, None), out_axes=None)
v_g(xs, y)
def scan_body(xs, _):
y = jnp.zeros(1)
_, vjp_f = api.vjp(f, xs, y)
vjp_f(None)
return xs, None
lax.scan(scan_body, jnp.ones(5), None, 100) # doesn't crash
def test_float0_bwd_none(self):
@api.custom_vjp
def f(i, x):
@ -3791,35 +4259,6 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
self.assertAllClose(grad_ans, 3. * 4. + np.cos(np.sin(3. * 4)),
check_dtypes=False)
# TODO
# def test_defjvp_closure_error(self):
# def foo(x):
# @api.custom_transforms
# def bar(y):
# return x * y
# api.defjvp(bar, lambda y_dot, ans, y: x * y)
# return bar(x)
# jtu.check_raises(
# lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
# "Detected differentiation with respect to closed-over values with "
# "custom JVP rule, which isn't supported.")
# TODO
# def test_defvjp_closure_error(self):
# def foo(x):
# @api.custom_transforms
# def bar(y):
# return x * y
# api.defvjp(bar, lambda g, ans, y: x * y)
# return bar(x)
# jtu.check_raises(
# lambda: grad(foo)(1.,), ValueError,
# "Detected differentiation w.r.t. variables from outside "
# "the scope of <jax.custom_transforms function bar>, but defvjp and "
# "defvjp_all only support differentiation w.r.t. positional arguments.")
def test_custom_transforms_eval_with_pytrees(self):
@api.custom_transforms
def f(x):
@ -3930,8 +4369,6 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
class BufferDonationTest(jtu.JaxTestCase):
# === pmap ===
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
def test_pmap_donate_argnums_invalidates_input(self):
move = api.pmap(lambda x: x + x - x, donate_argnums=0)