Copybara import of the project:

-- 609f6f3e16d21fed34cc5269c54a0d78ac44a8bc by Matthew Johnson <mattjj@google.com>:

   fix custom_jvp/vjp closure issues

PiperOrigin-RevId: 337457689

This commit is contained in:
parent d7d94ac9ea
commit 4a20eea828

docs/custom_vjp_update.md (new file, 129 lines)
@@ -0,0 +1,129 @@
# `custom_vjp` and `nondiff_argnums` update guide

_mattjj@_

_Oct 14 2020_

This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom
derivative rules for JAX-transformable Python
functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)
notebook.

### What to update

After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments
passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or
containers of `Tracer`s), which basically means that, to allow for
arbitrarily-transformable code, `nondiff_argnums` shouldn't be used for
array-valued arguments. Instead, `nondiff_argnums` should be used only for
non-array values, like Python callables, shape tuples, or strings.

Wherever we used to use `nondiff_argnums` for array values, we should just pass
those as regular arguments. In the `bwd` rule, we need to produce values for
them, but we can just produce `None` values to indicate there's no
corresponding gradient value.

For example, here's the **old** way to write `clip_gradient`, which won't work
when `hi` and/or `lo` are `Tracer`s from some JAX transformation.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
```

Here's the **new**, awesome way, which supports arbitrary transformations:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
```
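
As a quick check (a sketch assuming the definitions above, with the expected
values worked out by hand from the rules rather than taken from this commit):

```python
# Gradient flows through x; the bwd rule returns None cotangents for lo and hi.
print(jax.grad(clip_gradient, argnums=2)(-0.5, 0.5, 2.0))  # clip(1.0, -0.5, 0.5) == 0.5

# lo and hi are ordinary arguments now, so they can even be batched with vmap:
los = jnp.linspace(-1.0, -0.1, 4)
print(jax.vmap(lambda lo: clip_gradient(lo, 1.0, 2.0))(los))  # [2. 2. 2. 2.]
```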

If you use the old way instead of the new way, you'll get a loud error in any
case where something might go wrong (namely when there's a `Tracer` passed into
a `nondiff_argnums` argument).

Here's a case where we actually need `nondiff_argnums` with `custom_vjp`:

```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None

def skip_app_bwd(f, _, g):
  return (g,)

skip_app.defvjp(skip_app_fwd, skip_app_bwd)
```
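
As a usage sketch (assuming the definitions above): here `f` belongs in
`nondiff_argnums` because it's a Python callable, not an array value, so it
could never be a `Tracer` anyway.

```python
import jax.numpy as jnp

y, f_vjp = jax.vjp(lambda x: skip_app(jnp.sin, x), 1.0)
print(f_vjp(1.0))  # (1.0,): the bwd rule skips f and passes g straight through
```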

### Explanation

Passing `Tracer`s into `nondiff_argnums` arguments was always buggy. While
there were some cases that worked correctly, others would lead to complex and
confusing error messages.

The essence of the bug was that `nondiff_argnums` was implemented in a way that
acted very much like lexical closure. But lexical closure over `Tracer`s wasn't
at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing
`nondiff_argnums` that way was a mistake!

**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure
issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp`
and `custom_vjp` functions and rules can close over `Tracer`s to our hearts'
content. For all non-autodiff transformations, things will Just Work. For
autodiff transformations, we'll get a clear error message about why we can't
differentiate with respect to values over which a `custom_jvp` or `custom_vjp`
closes:

> Detected differentiation of a custom_jvp function with respect to a closed-over
> value. That isn't supported because the custom JVP rule only specifies how to
> differentiate the custom_jvp function with respect to explicit input parameters.
>
> Try passing the closed-over value into the custom_jvp function as an argument,
> and adapting the custom_jvp rule.
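
For instance, here's a minimal sketch, adapted from the tests added in this
commit, that triggers the error above by closing over a value being
differentiated:

```python
import jax

def f(x):
  @jax.custom_jvp
  def g(y):
    return x * y  # closes over x, which is a Tracer under jax.grad

  @g.defjvp
  def g_jvp(primals, tangents):
    y, = primals
    t, = tangents
    return g(y), 2. * t

  return g(1.)

jax.grad(f)(3.)  # raises the closed-over-value error quoted above
```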

In tightening up and robustifying `custom_jvp` and `custom_vjp` in this way, we
found that allowing `custom_vjp` to accept `Tracer`s in its `nondiff_argnums`
would take a significant amount of bookkeeping: we'd need to rewrite the user's
`fwd` function to return the values as residuals, and rewrite the user's `bwd`
function to accept them as normal residuals (rather than accepting them as
special leading arguments, as happens with `nondiff_argnums`). This seems maybe
manageable, until you think through how we'd have to handle arbitrary pytrees!
Moreover, that complexity isn't necessary: if user code treats array-like
non-differentiable arguments just like regular arguments and residuals,
everything already works. (Before
[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about
involving integer-valued inputs and outputs in autodiff, but after
[#4039](https://github.com/google/jax/pull/4039) those just work!)

Unlike `custom_vjp`, it was easy to make `custom_jvp` work with
`nondiff_argnums` arguments that were `Tracer`s. So these updates only need to
happen with `custom_vjp`.
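
To illustrate the `custom_jvp` side, here's a hedged sketch (`floor_at` is a
made-up example, not from this commit): a `Tracer` can land in a `custom_jvp`
`nondiff_argnums` slot, e.g. under `vmap`, and things still work, though for
array values plain arguments remain the recommended style.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def floor_at(lo, x):
  return jnp.maximum(lo, x)

@floor_at.defjvp
def floor_at_jvp(lo, primals, tangents):
  x, = primals
  t, = tangents
  return floor_at(lo, x), jnp.where(x >= lo, t, 0.)

los = jnp.linspace(-1., 1., 5)
jax.vmap(lambda lo: floor_at(lo, 0.5))(los)  # lo is a Tracer here; this works
```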

File diff suppressed because one or more lines are too long

jax/core.py (31 lines changed)
@@ -393,23 +393,24 @@ class Trace:
                    self.__class__.__name__, self.level, self.sublevel)

  def process_call(self, call_primitive, f, tracers, params):
    raise NotImplementedError("must override to handle call-like primitives")
    msg = (f"{type(self)} must override process_call to handle call-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_map(self, call_primitive, f, tracers, params):
    raise NotImplementedError("must override to handle map-like primitives")
    msg = (f"{type(self)} must override process_map to handle map-like "
           "primitives")
    raise NotImplementedError(msg)

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    # As a default implementation, drop the custom differentiation rule. This
    # behavior is desirable when staging out of the JAX system, but not when
    # there are further differentiation transformations to be applied. Override
    # this method to allow differentiation to be performed downstream.
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)
    msg = (f"{type(self)} must override process_custom_jvp_call "
           "to handle custom_jvp primitives")
    raise NotImplementedError(msg)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    # See comment in the above process_custom_jvp_call method.
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)
    msg = (f"{type(self)} must override process_custom_vjp_call "
           "to handle custom_vjp primitives")
    raise NotImplementedError(msg)

def escaped_tracer_error(detail=None):
  msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
@@ -575,6 +576,14 @@ class EvalTrace(Trace):
    return primitive.impl(f, *tracers, **params)
  process_map = process_call

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)


class MainTrace:
  level: int
@@ -16,6 +16,7 @@
from functools import update_wrapper, reduce, partial
import inspect
import operator as op
from typing import Callable, Sequence, Tuple, Any

from . import core
from . import linear_util as lu
@@ -45,38 +46,12 @@ def _resolve_kwargs(fun, args, kwargs):
  else:
    return ba.args

def _add_args(f, extra_args, left):
  return _add_args_(f, tuple(map(wrap_hashably, extra_args)), left)

@lu.transformation
def _add_args_(extra_args, left, *args, **kwargs):
  extra_args = tuple([arg.val for arg in extra_args])
  args = (extra_args + args) if left else (args + extra_args)
  yield (yield args, kwargs)

def _memoize(thunk):
  cell = []
  saved_state = core.thread_local_state.trace_state.copy()
  def memoized():
    if not cell:
      prev_state = core.thread_local_state.trace_state
      core.thread_local_state.trace_state = saved_state
      try:
        cell.append(thunk())
      finally:
        core.thread_local_state.trace_state = prev_state
    return cell[0]
  return memoized

def _initial_style_jaxpr(fun, in_avals):
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  return core.ClosedJaxpr(jaxpr, consts)
  return jaxpr, consts

def _initial_style_staging() -> bool:
  if config.omnistaging_enabled:
    return core.thread_local_state.trace_state.trace_stack.dynamic.level != 0  # type: ignore
  else:
    return core.thread_local_state.trace_state.initial_style
  return core.thread_local_state.trace_state.initial_style

def _sum_tangents(_, x, *xs):
  return reduce(ad.add_tangents, xs, x)
@@ -208,27 +183,39 @@ class custom_jvp:
      raise AttributeError(msg.format(self.__name__))
    args = _resolve_kwargs(self.fun, args, kwargs)
    if self.nondiff_argnums:
      is_nondiff = [False] * len(args)
      for i in self.nondiff_argnums: is_nondiff[i] = True
      args = [_stop_gradient(x) if b else x for b, x in zip(is_nondiff, args)]
      dyn_argnums = [i for i, b in enumerate(is_nondiff) if not b]
      f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
      args = [_stop_gradient(x) if i in self.nondiff_argnums else x
              for i, x in enumerate(args)]
      diff_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
      f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args)
      static_args = [args[i] for i in self.nondiff_argnums]
      jvp = _add_args(lu.wrap_init(self.jvp), static_args, left=True)
      jvp = _add_args(lu.wrap_init(self.jvp), static_args)
    else:
      f_, dyn_args = lu.wrap_init(self.fun), args
      jvp = lu.wrap_init(self.jvp)
    args_flat, in_tree = tree_flatten(dyn_args)
    flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
    flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
    if _initial_style_staging():
      out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat)
      out_tree = out_tree1()
    else:
      out_flat = custom_jvp_call(flat_fun, flat_jvp, *args_flat)
    if config.omnistaging_enabled:
      out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
      _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
    else:
      if _initial_style_staging():
        out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat)  # type: ignore
        out_tree = out_tree1()
      else:
        out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
        _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
    return tree_unflatten(out_tree, out_flat)

def _add_args(f, extra_args):
  return _add_args_(f, tuple(map(wrap_hashably, extra_args)))

@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
  extra_args = tuple([arg.val for arg in extra_args])
  all_args = (extra_args + args)
  yield (yield all_args, kwargs)

@lu.transformation_with_aux
def _flatten_jvp(in_tree, *args):
  primals_in, tangents_in = split_list(args, [len(args) // 2])
@@ -264,6 +251,8 @@ def _flatten_jvp(in_tree, *args):
  yield primals_out + tangents_out, out_tree

class CustomJVPCallPrimitive(core.CallPrimitive):
  initial_style: core.Primitive

  def bind(self, fun, jvp, *args):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
@@ -272,51 +261,63 @@ class CustomJVPCallPrimitive(core.CallPrimitive):
    jvp, env_trace_todo2 = core.process_env_traces(
        jvp, self, top_trace and top_trace.level, ())
    tracers = map(top_trace.full_raise, args)  # type: ignore
    outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)  # type: ignore
    with core.maybe_new_sublevel(top_trace):
      outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)  # type: ignore
    _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    if env_trace_todo:
      raise core.UnexpectedTracerError
    return map(core.full_lower, outs)
    return _apply_todos(env_trace_todo, map(core.full_lower, outs))

  def impl(self, fun, _, *args):
    return fun.call_wrapped(*args)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_custom_jvp_call(out_tracers, params)

def _apply_todos(todos, outs):
  todos_list = list(todos)
  while todos_list:
    outs = map(core.full_lower, todos_list.pop()(outs))
  return outs

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')
custom_jvp_call = custom_jvp_call_p.bind


def custom_jvp_call_jaxpr(fun, jvp, *args):
  in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  jvp_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
  return custom_jvp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
                                      jvp_jaxpr_thunk=jvp_jaxpr_thunk)

def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr, **_):
def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params):
  del params  # other params ignored because we're just executing the primal fun
  return core.jaxpr_as_fun(fun_jaxpr)(*args)

def _custom_jvp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params):
  del args, params
  return fun_jaxpr.out_avals

custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr')
custom_jvp_call_jaxpr_p.multiple_results = True
custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
CustomJVPCallPrimitive.initial_style = custom_jvp_call_jaxpr_p

def _custom_jvp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, jvp_jaxpr_thunk):
  jvp_jaxpr = jvp_jaxpr_thunk()
  tangents = map(ad.instantiate_zeros, tangents)
def _custom_jvp_call_jaxpr_jvp(
    primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  _, args = split_list(primals, [num_consts])
  consts_dot, args_dot = split_list(tangents, [num_consts])
  if any(type(t) is not Zero for t in consts_dot):
    raise ad.CustomJVPException()
  jvp_jaxpr, jvp_consts = jvp_jaxpr_thunk()  # consts can be tracers!
  args_dot = map(ad.instantiate_zeros, args_dot)
  # Cast float0 to zeros with the primal dtype because custom jvp rules don't
  # currently handle float0s
  tangents = ad.replace_float0s(primals, tangents)
  outs = core.jaxpr_as_fun(jvp_jaxpr)(*primals, *tangents)
  args_dot = map(ad.replace_float0s, args, args_dot)
  outs = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *args, *args_dot)
  primals_out, tangents_out = split_list(outs, [len(outs) // 2])
  tangents_out = ad.recast_to_float0(primals_out, tangents_out)
  tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
  return primals_out, tangents_out

ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp

def _custom_jvp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, jvp_jaxpr_thunk):
def _custom_jvp_call_jaxpr_vmap(
    args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
    jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    num_consts: int):
  size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]
@@ -325,21 +326,22 @@ def _custom_jvp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, jvp_jaxpr_thunk):
  in_batched = [d is not not_mapped for d in in_dims]
  batched_fun_jaxpr, out_batched = batching.batch_jaxpr(fun_jaxpr, size, in_batched, False)
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []
  out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk

  @_memoize
  @pe._memoize
  def batched_jvp_jaxpr_thunk():
    jvp_jaxpr = jvp_jaxpr_thunk()
    jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, in_batched * 2, False)
    primals_batched, tangents_batched = split_list(all_batched, [num_out])
    out_batched = map(op.or_, primals_batched, tangents_batched)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(jvp_jaxpr, size, in_batched * 2,
                                                out_batched * 2)
    return batched_jvp_jaxpr
    batched_jvp_jaxpr, _ = batching.batch_jaxpr(
        jvp_jaxpr, size, in_batched * 2, out_batched * 2)
    return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts

  batched_outs = custom_jvp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr, jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk)
      *args, fun_jaxpr=batched_fun_jaxpr,
      jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  return batched_outs, out_dims
batching.primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
@@ -350,8 +352,9 @@ xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk):
  del jvp_jaxpr_thunk
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
                                     num_consts):
  del jvp_jaxpr_thunk, num_consts
  return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts)
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
@@ -451,14 +454,12 @@ class custom_vjp:
      raise AttributeError(msg.format(self.__name__))
    args = _resolve_kwargs(self.fun, args, kwargs)
    if self.nondiff_argnums:
      is_nondiff = [False] * len(args)
      for i in self.nondiff_argnums: is_nondiff[i] = True
      args = [_stop_gradient(x) if b else x for b, x in zip(is_nondiff, args)]
      dyn_argnums = [i for i, b in enumerate(is_nondiff) if not b]
      for i in self.nondiff_argnums: _check_for_tracers(args[i])
      dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
      f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
      static_args = [args[i] for i in self.nondiff_argnums]
      fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
      bwd = _add_args(lu.wrap_init(self.bwd), static_args, left=True)
      bwd = _add_args(lu.wrap_init(self.bwd), static_args)
    else:
      f_, dyn_args = lu.wrap_init(self.fun), args
      fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
@@ -467,17 +468,39 @@ class custom_vjp:
    flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
    flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
    flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
    if _initial_style_staging():
      out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd,
                                       *args_flat, out_trees=out_trees)
      out_tree = out_tree()
    else:
      out_flat = custom_vjp_call(flat_fun, flat_fwd, flat_bwd,
                                 *args_flat, out_trees=out_trees)
    if config.omnistaging_enabled:
      out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
                                        out_trees=out_trees)
      fst, aux = lu.merge_linear_aux(out_tree, out_trees)
      out_tree = aux if fst else aux[0]
    else:
      if _initial_style_staging():
        out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd,  # type: ignore
                                         *args_flat, out_trees=out_trees)
        out_tree = out_tree()
      else:
        out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                          *args_flat, out_trees=out_trees)
        fst, aux = lu.merge_linear_aux(out_tree, out_trees)
        out_tree = aux if fst else aux[0]
    return tree_unflatten(out_tree, out_flat)

@partial(partial, tree_map)
def _check_for_tracers(x):
  if isinstance(x, core.Tracer):
    msg = ("Found a JAX Tracer object passed as an argument to a custom_vjp "
           "function in a position indicated by nondiff_argnums as "
           "non-differentiable. Tracers cannot be passed as non-differentiable "
           "arguments to custom_vjp functions; instead, nondiff_argnums should "
           "only be used for arguments that can't be or contain JAX tracers, "
           "e.g. function-valued arguments. In particular, array-valued "
           "arguments should typically not be indicated as nondiff_argnums. "
           "\n\n"
           "This behavior recently changed in JAX. "
           "See https://github.com/google/jax/blob/master/docs/custom_vjp_update.md "
           "for more information.")
    raise core.UnexpectedTracerError(msg)

@lu.transformation_with_aux
def _flatten_fwd(in_tree, *args):
  py_args = tree_unflatten(in_tree, args)
@@ -524,31 +547,29 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):


class CustomVJPCallPrimitive(core.CallPrimitive):
  initial_style: core.Primitive

  def bind(self, fun, fwd, bwd, *args, out_trees):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    if top_trace is None:
      outs = fun.call_wrapped(*args)
    else:
      tracers = map(top_trace.full_raise, args)
    fun, env_trace_todo1 = core.process_env_traces(
        fun, self, top_trace and top_trace.level, ())
    fwd, env_trace_todo2 = core.process_env_traces(
        fwd, self, top_trace and top_trace.level, ())
    tracers = map(top_trace.full_raise, args)  # type: ignore
    with core.maybe_new_sublevel(top_trace):
      outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
                                               out_trees=out_trees)
    return map(core.full_lower, outs)
    _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    return _apply_todos(env_trace_todo, map(core.full_lower, outs))

  def impl(self, fun, fwd, bwd, *args, out_trees):
    del fwd, bwd, out_trees
    return fun.call_wrapped(*args)

  def post_process(self, trace, out_tracers, params):
    return trace.post_process_custom_vjp_call(out_tracers, params)
custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call')
custom_vjp_call = custom_vjp_call_p.bind

def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
  in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  fwd_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
  return custom_vjp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr,
                                      fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd,
                                      out_trees=out_trees)

def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
  return core.jaxpr_as_fun(fun_jaxpr)(*args)
@@ -560,27 +581,35 @@ custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr')
custom_vjp_call_jaxpr_p.multiple_results = True
custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p

def _custom_vjp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, fwd_jaxpr_thunk,
                               bwd, out_trees):
  tangents = map(ad.instantiate_zeros, tangents)
def _custom_vjp_call_jaxpr_jvp(
    primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
  _, args = split_list(primals, [num_consts])
  consts_dot, args_dot = split_list(tangents, [num_consts])
  if any(type(t) is not Zero for t in consts_dot):
    raise ad.CustomVJPException()
  fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()  # consts can be tracers!
  out_tree, res_tree = out_trees()
  args_dot = map(ad.instantiate_zeros, args_dot)
  # Cast float0 to zeros with the primal dtype because custom vjp rules don't
  # currently handle float0s
  tangents = ad.replace_float0s(primals, tangents)
  fwd_jaxpr = fwd_jaxpr_thunk()
  out_tree, res_tree = out_trees()
  res_and_primals_out = core.jaxpr_as_fun(fwd_jaxpr)(*primals)
  args_dot = map(ad.replace_float0s, args, args_dot)
  res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
  avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
  tangents_out = ad.custom_lin_p.bind(
      *res, *tangents, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)
  tangents_out = ad.recast_to_float0(primals_out, tangents_out)
      *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out)
  tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
  return primals_out, tangents_out

ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

def _custom_vjp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, fwd_jaxpr_thunk,
                                bwd, out_trees):
def _custom_vjp_call_jaxpr_vmap(
    args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
    fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
  size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
  args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
          else x for x, d in zip(args, in_dims)]
@@ -590,25 +619,26 @@ def _custom_vjp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, fwd_jaxpr_thunk,
  out_dims1 = [0 if b else not_mapped for b in out_batched]
  out_dims2 = []

  @_memoize
  @pe._memoize
  def batched_fwd_jaxpr_thunk():
    fwd_jaxpr = fwd_jaxpr_thunk()
    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk())  # consts can be tracers
    batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(fwd_jaxpr, size, in_batched, False)
    out_dims2.append([0 if b else not_mapped for b in out_batched])
    return batched_fwd_jaxpr
    return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

  fwd_in_dims = [0 if b else not_mapped for b in in_batched]
  fwd_out_dims = lambda: out_dims2[0]
  # TODO: Support collectives in custom_vjp?
  # TODO(mattjj,apaszke): Support collectives in custom_vjp?
  batched_bwd = batching.batch_fun(bwd, fwd_out_dims, fwd_in_dims,
                                   axis_name='__unused_axis_name', sum_match=True)

  batched_outs = custom_vjp_call_jaxpr_p.bind(
      *args, fun_jaxpr=batched_fun_jaxpr,
      fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
      out_trees=out_trees)
      out_trees=out_trees, num_consts=num_consts)
  out_dims = out_dims2[0] if out_dims2 else out_dims1
  out_dims = out_dims[:len(batched_outs)]  # TODO(mattjj): remove after #4008
  if not config.omnistaging_enabled:
    out_dims = out_dims[:len(batched_outs)]
  return batched_outs, out_dims
batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
@@ -620,16 +650,16 @@ batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp

@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global _initial_style_jaxpr, custom_jvp_call
  global _initial_style_jaxpr, custom_vjp_call_jaxpr, custom_jvp_call_jaxpr

  def _initial_style_jaxpr(fun, in_avals):
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
                                         bottom=True, stage_out=False)  # type: ignore
    assert not any(isinstance(c, core.Tracer) for c in consts)
    return core.ClosedJaxpr(jaxpr, consts)
    return jaxpr, consts

  def bind(self, fun, jvp, *args):
  def jvp_bind(self, fun, jvp, *args):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    fun, env_trace_todo1 = core.process_env_traces(
@@ -646,5 +676,43 @@ def omnistaging_disabler() -> None:
    if env_trace_todo:
      raise core.UnexpectedTracerError
    return map(core.full_lower, outs)
  CustomJVPCallPrimitive.bind = bind  # type: ignore
  custom_jvp_call = custom_jvp_call_p.bind
  CustomJVPCallPrimitive.bind = jvp_bind  # type: ignore

  def jvp_post_process(self, trace, out_tracers, params):
    raise core.UnexpectedTracerError
  CustomJVPCallPrimitive.post_process = jvp_post_process  # type: ignore

  def vjp_bind(self, fun, fwd, bwd, *args, out_trees):
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    if top_trace is None:
      outs = fun.call_wrapped(*args)
    else:
      tracers = map(top_trace.full_raise, args)
      outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
                                               out_trees=out_trees)
    return map(core.full_lower, outs)
  CustomVJPCallPrimitive.bind = vjp_bind  # type: ignore

  def vjp_post_process(self, trace, out_tracers, params):
    raise core.UnexpectedTracerError
  CustomVJPCallPrimitive.post_process = vjp_post_process  # type: ignore

  def custom_jvp_call_jaxpr(fun: Callable, jvp: Callable, *args):
    in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
    fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
    jvp_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
    return custom_jvp_call_jaxpr_p.bind(
        *consts, *args, fun_jaxpr=closed_fun_jaxpr,
        jvp_jaxpr_thunk=jvp_jaxpr_thunk, num_consts=len(consts))

  def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
    in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
    fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
    return custom_vjp_call_jaxpr_p.bind(
        *consts, *args, fun_jaxpr=closed_fun_jaxpr,
        fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,
        num_consts=len(consts))
@@ -560,6 +560,26 @@ class TensorFlowTrace(core.Trace):
  def post_process_map(self, map_primitive, out_tracers, params):
    raise NotImplementedError("post_process_map")

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    # Drop the custom differentiation rule and act like a call primitive. This
    # behavior is desirable because jax2tf stages code out of the JAX system, so
    # there are no more JAX differentiation transformations to be applied.
    del jvp  # Unused.
    return self.process_call(core.call_p, fun, tracers, {})

  def post_process_custom_jvp_call(self, out_tracers, params):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    # Drop the custom differentiation rule and act like a call primitive. This
    # behavior is desirable because jax2tf stages code out of the JAX system, so
    # there are no more JAX differentiation transformations to be applied.
    del fwd, bwd, out_trees  # Unused.
    return self.process_call(core.call_p, fun, tracers, {})

  def post_process_custom_vjp_call(self, out_tracers, params):
    assert False  # unreachable assuming jax2tf runs with clean trace state

  def get_primitive_impl(self, p: core.Primitive) -> Tuple[Callable, bool]:
    # Returns the primitive implementation and whether the implementation
    # takes abstract values (see definition of tf_impl_with_avals)
@@ -1808,7 +1828,8 @@ tf_impl[lax_linalg.triangular_solve_p] = _triangular_solve

def _custom_jvp_call_jaxpr(*args: TfVal,
                           fun_jaxpr: core.ClosedJaxpr,
                           jvp_jaxpr_thunk: Callable) -> Sequence[TfVal]:
                           jvp_jaxpr_thunk: Callable,
                           num_consts: int) -> Sequence[TfVal]:
  # TODO(necula): ensure that there is no AD transformation in scope
  return _interpret_jaxpr(fun_jaxpr, *args)
@@ -153,8 +153,14 @@ class JetTrace(core.Trace):
      return map(partial(JetTracer, trace), primals, series)
    return out, todo

  def join(self, xt, yt):
    assert False  # TODO?
  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    # TODO(mattjj): don't just ignore custom jvp rules?
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)


class ZeroTerm(object): pass
@@ -146,16 +146,17 @@ def unpair_pval(pval):
  aval_1, aval_2 = aval
  return (aval_1, const_1), (aval_2, const_2)

def replace_float0s(primals, tangents):
  return [core.zeros_like_float0(tangent, dtype(primal))
          if dtype(tangent) is float0 else tangent
          for primal, tangent in zip(primals, tangents)]
def replace_float0s(primal, tangent):
  if dtype(tangent) is float0:
    return core.zeros_like_float0(tangent, dtype(primal))
  else:
    return tangent

def recast_to_float0(primals, tangents):
  return [Zero(get_aval(primal).at_least_vspace())
          if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0
          else tangent
          for primal, tangent in zip(primals, tangents)]
def recast_to_float0(primal, tangent):
  if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0:
    return Zero(get_aval(primal).at_least_vspace())
  else:
    return tangent

# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
@@ -314,12 +315,15 @@ class JVPTrace(Trace):
    tangents_in = map(instantiate_zeros, tangents_in)
    # Cast float0 to zeros with the primal dtype because custom jvp rules don't
    # currently handle float0s
    tangents_in = replace_float0s(primals_in, tangents_in)
    tangents_in = map(replace_float0s, primals_in, tangents_in)
    outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    tangents_out = recast_to_float0(primals_out, tangents_out)
    tangents_out = map(recast_to_float0, primals_out, tangents_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def post_process_custom_jvp_call(self, out_tracers, params):
    raise CustomJVPException()

  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    tangents_in = map(instantiate_zeros, tangents_in)
@@ -330,9 +334,12 @@ class JVPTrace(Trace):
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
        avals_out=avals_out)
    tangents_out = recast_to_float0(primals_out, tangents_out)
    tangents_out = map(recast_to_float0, primals_out, tangents_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def post_process_custom_vjp_call(self, out_tracers, params):
    raise CustomVJPException()

  def join(self, xt, yt):
    xz, yz = type(xt) is Zero, type(yt) is Zero
    if xz == yz:
@@ -625,8 +632,7 @@ def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
  res, _ = split_list(invals, [num_res])
  cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
  cts_in = bwd.call_wrapped(*res, *cts_out)
  cts_in_flat, _ = tree_flatten(cts_in)  # already checked tree structure
  return [None] * num_res + cts_in_flat
  return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose
@@ -696,6 +702,28 @@ def defvjp2(prim, *vjps):
  defvjp_all(prim, vjpmaker)


class CustomJVPException(Exception):
  def __init__(self):
    # TODO(mattjj): track source provenance on AD tracers, improve error
    msg = ("Detected differentiation of a custom_jvp function with respect to "
           "a closed-over value. That isn't supported because the custom JVP "
           "rule only specifies how to differentiate the custom_jvp function "
           "with respect to explicit input parameters. Try passing the "
           "closed-over value into the custom_jvp function as an argument, and "
           "adapting the custom_jvp rule.")
    super().__init__(msg)

class CustomVJPException(Exception):
  def __init__(self):
    # TODO(mattjj): track source provenance on AD tracers, improve error
    msg = ("Detected differentiation of a custom_vjp function with respect to "
           "a closed-over value. That isn't supported because the custom VJP "
           "rule only specifies how to differentiate the custom_vjp function "
           "with respect to explicit input parameters. Try passing the "
           "closed-over value into the custom_vjp function as an argument, and "
           "adapting the custom_vjp fwd and bwd rules.")
    super().__init__(msg)

@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global jvp_jaxpr
@@ -215,11 +215,19 @@ class BatchTrace(Trace):
      out_dims = out_dims[:len(out_dims) // 2]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_jvp_call(self, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = BatchTrace(main, core.cur_sublevel())
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    # TODO: Support collectives in custom_vjp?
    # TODO(mattjj,apaszke): support collectives in custom_vjp?
    bwd = batch_fun(bwd, out_dims2, in_dims,
                    axis_name='__unused_axis_name', sum_match=True)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
@@ -370,11 +378,11 @@ def _promote_aval_rank(sz, aval):
  else:
    return ShapedArray((sz,) + aval.shape, aval.dtype)

def batch_jaxpr(jaxpr, size, batched, instantiate):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
def batch_jaxpr(closed_jaxpr, size, batched, instantiate):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, batched_out = batched_traceable(f, size, batched, instantiate)
  avals_in = [_promote_aval_rank(size, a) if b else a
              for a, b in zip(jaxpr.in_avals, batched)]
              for a, b in zip(closed_jaxpr.in_avals, batched)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), batched_out()
@@ -31,6 +31,10 @@ from ..config import config
map = safe_map
zip = safe_zip

def _initial_style_jaxpr(fun, in_avals):
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  return core.ClosedJaxpr(jaxpr, consts)

################################################################################
# Reverse call primitive
################################################################################
@@ -118,10 +122,9 @@ def _flatten_ivjp(in_tree, out_tree, *args):

def _custom_ivjp(fun, ivjp, args):
  in_avals = [raise_to_shaped(get_aval(x)) for x in args]
  fun_jaxpr = custom_derivatives._initial_style_jaxpr(fun, in_avals)
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  try:
    ivjp_jaxpr = custom_derivatives._initial_style_jaxpr(
        ivjp, in_avals + fun_jaxpr.out_avals * 2)
    ivjp_jaxpr = _initial_style_jaxpr(ivjp, in_avals + fun_jaxpr.out_avals * 2)
  except RecursionError:
    raise ValueError("Calls to {} from its custom ivjp aren't supported yet".format(fun.__name__))
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr,
@@ -310,3 +313,15 @@ def get_primitive_inverse(p):
def definverse(primitive, inverse_rule):
  primitive_inverses[primitive] = inverse_rule
  return inverse_rule


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global _initial_style_jaxpr, custom_jvp_call

  def _initial_style_jaxpr(fun, in_avals):
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
                                         bottom=True, stage_out=False)  # type: ignore
    assert not any(isinstance(c, core.Tracer) for c in consts)
    return core.ClosedJaxpr(jaxpr, consts)
@@ -179,7 +179,7 @@ class JaxprTrace(Trace):
                else PartialVal.unknown(mapped_aval(pval[0]))
                for pval, is_mapped in zip(in_pvals, params['mapped_invars'])]
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
        f, in_pvals, partial(primitive.bind, **params))
        f, in_pvals, partial(primitive.bind, **params), instantiate=False)
    if primitive.map_primitive:
      unmapped_aval = partial(core.unmapped_aval, params['axis_size'])
      out_pvals = [pval if pval.is_known()
@@ -273,10 +273,11 @@ class JaxprTrace(Trace):
  post_process_map = post_process_call

  def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal],
                   app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]]):
                   app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]],
                   instantiate: bool):
    """Partially evaluate f on a sequence of PartialVals."""
    in_avals, in_consts = unzip2(pvals)
    f = trace_to_subjaxpr(f, self.main, False)
    f = trace_to_subjaxpr(f, self.main, instantiate)
    f, aux = partial_eval_wrapper(f, tuple(in_avals))
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
@@ -284,6 +285,85 @@ class JaxprTrace(Trace):
    env_tracers = map(self.full_raise, env)
    return jaxpr, out_pvs, consts, env_tracers

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, jvp, *in_consts)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def jvp_jaxpr_thunk():
      jvp_ = trace_to_subjaxpr(jvp, self.main, True)
      jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2)
      out_flat = jvp_.call_wrapped(*(in_consts * 2))  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                              num_consts=len(consts) + len(env)),
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, params):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_jvp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    tracers = map(self.instantiate_const_abstracted, tracers)
    in_avals, in_consts = unzip2(t.pval for t in tracers)  # in_consts are units
    fun = trace_to_subjaxpr(fun, self.main, True)
    fun, aux = partial_eval_wrapper(fun, tuple(in_avals))
    out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees)
    out_avals, jaxpr, env = aux()
    out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
    out_pvals = map(PartialVal, zip(out_avals, out_consts))  # out_consts are units
    env_tracers = map(self.full_raise, env)
    out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals]
    const_tracers = map(self.new_instantiated_const, consts)
    in_tracers = (*const_tracers, *env_tracers, *tracers)
    closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ())

    @_memoize
    def fwd_jaxpr_thunk():
      fwd_ = trace_to_subjaxpr(fwd, self.main, True)
      fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals))
      out_flat = fwd_.call_wrapped(*in_consts)  # in_consts are units
      out_avals, jaxpr, env = aux()
      _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
      converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env))
      return converted_jaxpr, (*consts, *env)

    eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style,
                         dict(fun_jaxpr=closed_jaxpr,
                              fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                              num_consts=len(consts) + len(env),
                              bwd=bwd, out_trees=out_trees),
                         source_info_util.current())
    for t in out_tracers: t.recipe = eqn
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, params):
    # This path should only be reachable if we expose a partial eval API
    # unrelated to autodiff, since we raise an error when differentiation with
    # respect to values over which a custom_vjp function closes is detected.
    raise NotImplementedError  # TODO(mattjj)


@lu.transformation_with_aux
def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts):
@@ -551,6 +631,15 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr):
  core.skip_checks or core.check_jaxpr(lifted_jaxpr)
  return lifted_jaxpr

def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int):
  core.skip_checks or core.check_jaxpr(jaxpr)
  env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
  converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
                          invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
  core.skip_checks or core.check_jaxpr(converted_jaxpr)
  return converted_jaxpr


def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
  return (abstract_unit, aval) if unknown else (aval, abstract_unit)
@@ -658,11 +747,11 @@ def _remat_partial_eval(trace, _, f, tracers, params):
  in_pvals = [t.pval for t in instantiated_tracers]
  if config.omnistaging_enabled:
    jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
        f, in_pvals, partial(remat_call_p.bind, **params))
        f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
  else:
    with core.initial_style_staging():  # type: ignore
      jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
          f, in_pvals, partial(remat_call_p.bind, **params))
          f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)

  # Convert consts to inputs, since they may contain Tracer instances.
  jaxpr = convert_constvars_jaxpr(jaxpr)
@@ -847,7 +936,6 @@ class JaxprStackFrame:
    constvars, constvals = unzip2(self.constvar_to_val.items())
    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    # core.skip_checks or core.check_jaxpr(jaxpr)
    out_avals = [t.aval for t in out_tracers]
    return jaxpr, out_avals, constvals
@@ -1000,6 +1088,63 @@ class DynamicJaxprTrace(core.Trace):
  def post_process_map(self, map_primitive, out_tracers, params):
    assert False  # unreachable

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    jvp_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
    out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.getvar, out_tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                             num_consts=len(consts)),
                        source_info_util.current())
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, params):
    assert False  # unreachable

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
    out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.getvar, out_tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                             num_consts=len(consts),
                             bwd=bwd, out_trees=out_trees),
                        source_info_util.current())
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, params):
    assert False  # unreachable

def _memoize(thunk):
  cell = []
  saved_state = core.thread_local_state.trace_state.copy()
  def memoized():
    if not cell:
      prev_state = core.thread_local_state.trace_state
      core.thread_local_state.trace_state = saved_state
      try:
        cell.append(thunk())
      finally:
        core.thread_local_state.trace_state = prev_state
    return cell[0]
  return memoized


def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
  assert config.omnistaging_enabled
@@ -2793,10 +2793,30 @@ class CustomJVPTest(jtu.JaxTestCase):
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.jit(foo))(3.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(foo))(3.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(foo))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(api.jit(foo)))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.jit(api.grad(foo)))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(api.grad(foo)))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_initial_style_vmap(self):
    @api.custom_jvp
    def f(x):
@@ -2816,11 +2836,38 @@ class CustomJVPTest(jtu.JaxTestCase):
    expected = 3. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.vmap(api.jit(foo))(jnp.ones(3))
    expected = 3. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.vmap(foo))(jnp.ones(3))
    expected = 3. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_closed_over_tracers_error_message(self):
    if not config.omnistaging_enabled:
      raise unittest.SkipTest("test only works with omnistaging")

    def f(x):
      @api.custom_jvp
      def g(y):
@@ -2830,10 +2877,8 @@ class CustomJVPTest(jtu.JaxTestCase):
      g.defjvp(g_jvp)
      return g(1.)

    self.assertRaises(
        core.UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,)))
    self.assertRaises(
        core.UnexpectedTracerError, lambda: api.grad(f)(3.))
    self.assertRaises(ad.CustomJVPException, lambda: api.jvp(f, (3.,), (1.,)))
    self.assertRaises(ad.CustomJVPException, lambda: api.grad(f)(3.))

  def test_nondiff_arg(self):
    @partial(api.custom_jvp, nondiff_argnums=(0,))
@@ -2869,6 +2914,25 @@ class CustomJVPTest(jtu.JaxTestCase):
    expected = (6., 5.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_hiding_jvp_tracer(self):
    if not config.omnistaging_enabled:
      raise unittest.SkipTest("test only works with omnistaging")

    def f(x):
      @partial(api.custom_jvp, nondiff_argnums=(0,))
      def g(h, x):
        return h(x)
      @g.defjvp
      def g_jvp(h, primals, tangents):
        x, = primals
        t, = tangents
        return g(h, x), 2. * t
      h = lambda y: x + y  # capture x
      return g(h, x)

    with self.assertRaisesRegex(ad.CustomJVPException, "Detected differentiation"):
      api.jvp(f, (2.,), (1.,))

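The test above pins down the failure mode: the callable in `nondiff_argnums` smuggles a JVP tracer (`x`) past the custom rule. Below is a sketch of one supported rewrite; it is my own illustration, not part of the diff. The captured value becomes a regular differentiable argument, so the rule can account for its tangent:

```python
import jax

def f(x):
  @jax.custom_jvp
  def g(x, y):           # x is now an ordinary argument, not hidden in h
    return x + y

  @g.defjvp
  def g_jvp(primals, tangents):
    (x, y), (tx, ty) = primals, tangents
    return g(x, y), tx + ty  # the true JVP of x + y, now expressible

  return g(x, x)

print(jax.jvp(f, (2.,), (1.,)))  # works: no tracer hidden in a nondiff arg
```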
  def test_vmap_axes(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

@@ -3123,6 +3187,77 @@ class CustomJVPTest(jtu.JaxTestCase):
    for ans in results:
      self.assertAllClose(ans, expected)

  def test_nondiff_argnums_vmap_tracer(self):
    # https://github.com/google/jax/issues/3964
    if not config.omnistaging_enabled:
      raise unittest.SkipTest("test only works with omnistaging")

    @partial(jax.custom_jvp, nondiff_argnums=(0, 2))
    def sample(shape, param, seed):
      return jax.random.uniform(key=seed, shape=shape, minval=param)

    @sample.defjvp
    def sample_jvp(shape, seed, primals, tangents):
      param, = primals
      dparam, = tangents
      dparam = jnp.broadcast_to(dparam, shape)
      samples = sample(shape, param, seed)
      return samples, samples * dparam  # dummy jvp for proof of concept

    # check these don't crash
    jax.vmap(lambda seed: sample((2,3), 1., seed))(
        jax.random.split(jax.random.PRNGKey(1), 10))
    jax.jvp(lambda x: sample((2, 3), x, jax.random.PRNGKey(1)),
            (1.,), (1.,))

  def test_fun_with_nested_calls_2(self):
    if not config.omnistaging_enabled:
      raise unittest.SkipTest("test only works with omnistaging")

    def call(f, *args):
      f = api.custom_jvp(f)
      f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents)))
      return f(*args)

    def fun_with_nested_calls_2(x):
      def bar(y):
        def baz(w):
          q = call(lambda x: y, x)
          q = q + call(lambda: y)
          q = q + call(lambda y: w + y, y)
          q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q
          return q
        return api.jit(baz)(x)
      return call(bar, x)

    # test these don't crash
    self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.),
                        fun_with_nested_calls_2(3.))
    api.vmap(fun_with_nested_calls_2)(jnp.arange(3.))

  def test_closure_with_vmap(self):
    if not config.omnistaging_enabled:
      raise unittest.SkipTest("test only works with omnistaging")

    # https://github.com/google/jax/issues/3822
    alpha = np.float32(2.)

    def sample(seed):
      @api.custom_jvp
      def f(alpha):
        return jax.random.gamma(seed, alpha, shape=[])

      @f.defjvp
      def f_jvp(primal, tangent):
        alpha = primal
        dalpha = tangent
        sample = f(alpha)
        partial_alpha = lax.random_gamma_grad(alpha, sample)
        return sample, partial_alpha * dalpha
      return f(alpha)

    api.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3))  # don't crash

  def test_float0(self):
    @api.custom_jvp
    def f(x, y):

@@ -3157,6 +3292,53 @@ class CustomJVPTest(jtu.JaxTestCase):
    self.assertArraysEqual(api.jvp(foo, primals, tangents),
                           (primals, expected_tangents))

  def test_remat(self):
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    @api.remat
    def g(x):
      return f(f(x))

    ans = g(2.)
    expected = np.sin(np.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g)(2.)
    expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_higher_order(self):
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    def g(x):
      return f(f(x))

    ans = api.grad(api.grad(api.remat(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.remat(api.grad(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(api.grad(api.remat(g))))(2.)
    expected = api.grad(api.grad(api.grad(g)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)


class CustomVJPTest(jtu.JaxTestCase):

@@ -3395,6 +3577,12 @@ class CustomVJPTest(jtu.JaxTestCase):
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_tracer(self):
    # This test is now skipped because we decided not to support this behavior
    # anymore (namely, nondiff args can't be tracers), but
    # test_closed_over_tracer is a replacement test for analogous behavior that
    # we do support.
    raise unittest.SkipTest("removed support for tracers in nondiff args")

    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
@@ -3416,6 +3604,56 @@ class CustomVJPTest(jtu.JaxTestCase):
    expected = jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_closed_over_tracer(self):
    # This test is similar to test_nondiff_arg_tracer except it uses lexical
    # closure rather than the nondiff_argnums mechanism. We decided to disallow
    # tracers in nondiff_argnums to greatly simplify bookkeeping while still
    # supporting the cases for which it is necessary.
    def outer(x):
      @api.custom_vjp
      def f(y):
        return x * y
      def f_fwd(y):
        return f(y), jnp.cos(y)
      def f_rev(cos_y, g):
        return (cos_y * g,)
      f.defvjp(f_fwd, f_rev)
      return f

    @jit
    def g(x, y):
      return outer(x)(y)

    ans = g(2, 3.)
    expected = 6.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g, 1)(2., 3.)
    expected = jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_tracer_error(self):
    # This is similar to the old (now skipped) test_nondiff_arg_tracer, except
    # we're testing for the error message that that usage pattern now raises.

    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(x, cos_y, g):
      return (cos_y * g,)
    f.defvjp(f_fwd, f_rev)

    @jit
    def g(x, y):
      return f(x, y)

    with self.assertRaisesRegex(core.UnexpectedTracerError, "custom_vjp"):
      _ = g(2, 3.)
    with self.assertRaisesRegex(core.UnexpectedTracerError, "custom_vjp"):
      _ = api.grad(g, 1)(2., 3.)

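For reference, a sketch (my own illustration, not part of the diff) of how the failing pattern above is rewritten under the new rules: drop `nondiff_argnums`, make `x` a regular argument, and return a `None` cotangent for it in the bwd rule, following the symbolic-zero `None` convention the tests below exercise.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return x * y

def f_fwd(x, y):
  return f(x, y), jnp.cos(y)

def f_rev(cos_y, g):
  return (None, cos_y * g)  # None: no custom cotangent defined for x

f.defvjp(f_fwd, f_rev)

g = jax.jit(f)
print(g(2., 3.))               # 6.0, works under jit
print(jax.grad(f, 1)(2., 3.))  # cos(3.), matching the original rule for y
```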
  def test_vmap_axes(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

@@ -3482,6 +3720,11 @@ class CustomVJPTest(jtu.JaxTestCase):
    jax.grad(g, argnums=(1,))(F(2.0), 0.)  # doesn't crash

  def test_nondiff_argnums_stop_gradient(self):
    # This test is now skipped because we decided not to support this behavior
    # anymore (namely, nondiff args can't be tracers), but test_clip_gradient is
    # a replacement showing behavior we do support.
    raise unittest.SkipTest("removed support for tracers in nondiff args")

    # https://github.com/google/jax/issues/2784
    @partial(api.custom_vjp, nondiff_argnums=(0, 1))
    def _clip_gradient(lo, hi, x):

@@ -3503,6 +3746,29 @@ class CustomVJPTest(jtu.JaxTestCase):

    jax.grad(clip_gradient)(1.)  # doesn't crash

  def test_clip_gradient(self):
    # https://github.com/google/jax/issues/2784
    @api.custom_vjp
    def _clip_gradient(lo, hi, x):
      return x  # identity function when not differentiating

    def clip_gradient_fwd(lo, hi, x):
      return x, (lo, hi,)

    def clip_gradient_bwd(res, g):
      lo, hi = res
      return (None, None, jnp.clip(g, lo, hi),)

    _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

    def clip_gradient(x):
      lo = -0.1
      hi = x + 0.1
      return _clip_gradient(lo, hi, x)

    g = jax.grad(clip_gradient)(0.1)  # doesn't crash
    self.assertAllClose(g, jnp.array(0.2))

  def test_nestable_vjp(self):
    # Verify that https://github.com/google/jax/issues/3667 is resolved.
    def f(x):

@@ -3536,6 +3802,94 @@ class CustomVJPTest(jtu.JaxTestCase):
    y, = z(1.0)(3.0)
    self.assertAllClose(y, jnp.array(6.0))

  def test_initial_style_vmap_2(self):
    # https://github.com/google/jax/issues/4173
    x = jnp.ones((10, 3))

    # Create the custom function
    @api.custom_vjp
    def custom_fun(x):
      return x.sum()
    def forward(x):
      return x.sum(), (jnp.ones_like(x),)
    def backward(res, g):
      return g * res[0],
    custom_fun.defvjp(forward, backward)

    def train_fun(x):
      def summed_fun(x):
        return api.vmap(custom_fun)(x).sum()
      return api.grad(summed_fun)(x)

    def scan_body(carry, inputs):
      x = carry
      return carry, train_fun(x)

    scan_range = jnp.arange(4)
    lax.scan(scan_body, x, scan_range)  # don't crash

  def test_bwd_closes_over_tracer(self):
    def f(y):
      @jax.custom_vjp
      def f(x):
        return 2. * jnp.sin(x)

      def fwd(x):
        return f(x), ()

      def bwd(_, g):
        return (2. * jnp.cos(y) * g,)  # capture!

      f.defvjp(fwd, bwd)

      return jax.grad(f)(1.)

    ans = jax.jit(f)(2.)
    self.assertAllClose(ans, 2. * jnp.cos(2.))

    ans = jax.vmap(f)(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.grad(f)(4.)
    self.assertAllClose(ans, -2. * jnp.sin(4.))

  def test_fwd_closes_over_tracer(self):
    def f(y):
      @jax.custom_vjp
      def f(x):
        return 2. * jnp.sin(x)

      def fwd(x):
        return f(x), y

      def bwd(y, g):
        return (2. * jnp.cos(y) * g,)  # capture!

      f.defvjp(fwd, bwd)

      return jax.grad(f)(1.)

    ans = jax.jit(f)(2.)
    self.assertAllClose(ans, 2. * jnp.cos(2.))

    ans = jax.vmap(f)(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.grad(f)(4.)
    self.assertAllClose(ans, -2. * jnp.sin(4.))

  def test_float0(self):
    @api.custom_vjp
    def f(x, _):

@@ -3571,6 +3925,120 @@ class CustomVJPTest(jtu.JaxTestCase):
    self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y),
                     (2., np.zeros(shape=(), dtype=float0)))

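As an aside for readers unfamiliar with `float0`: it is the zero-size dtype JAX uses for tangents and cotangents of non-differentiable (e.g. integer) inputs, which is what the assertion above checks. A small illustration, under the assumption of a JAX version exposing `jax.dtypes.float0` and `allow_int` as used in these tests:

```python
import numpy as np
import jax
from jax import dtypes

# grad of x * y**2 at (x=2., y=3 with y an int) under allow_int: the integer
# input gets a float0 "gradient", an empty-width placeholder, not a float.
f = lambda x, y: x * (y ** 2)
gx, gy = jax.grad(f, argnums=(0, 1), allow_int=True)(2., 3)
print(gx)                              # 9.0
print(gy.dtype == dtypes.float0)       # True: symbolic-zero dtype for the int arg
print(np.zeros((), dtype=dtypes.float0))  # how the test constructs its expected value
```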
  def test_remat(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    @api.remat
    def g(x):
      return f(f(x))

    ans = g(2.)
    expected = np.sin(np.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g)(2.)
    expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_higher_order(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def g(x):
      return f(f(x))

    ans = api.grad(api.grad(api.remat(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.remat(api.grad(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(api.grad(api.remat(g))))(2.)
    expected = api.grad(api.grad(api.grad(g)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_bwd_nones(self):
    @api.custom_vjp
    def f(x, y):
      return x * jnp.sin(y)
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(cos, g):
      return (None, 2 * cos * g)
    f.defvjp(f_fwd, f_rev)

    ans = api.grad(lambda x: f(x, x))(3.)
    expected = 2 * jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_bwd_nones_vmap(self):
    @api.custom_vjp
    def f(x, y):
      return x * jnp.sin(y)
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(cos, g):
      return (None, 2 * cos * g)
    f.defvjp(f_fwd, f_rev)

    ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.))
    expected = 2 * jnp.cos(jnp.arange(3.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_bwd_nones_pytree(self):
    @api.custom_vjp
    def f(xs, y):
      x1, x2 = xs
      return x1 * x2 * jnp.sin(y)
    def f_fwd(xs, y):
      return f(xs, y), jnp.cos(y)
    def f_rev(cos, g):
      return (None, 2 * cos * g)
    f.defvjp(f_fwd, f_rev)

    ans = api.grad(lambda x: f((x, x), x))(3.)
    expected = 2 * jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_vjp_closure_4521(self):
    # https://github.com/google/jax/issues/4521
    @api.custom_vjp
    def g(x, y):
      return None
    def g_fwd(x, y):
      return None, y
    def g_bwd(residuals, z_bar):
      assert False

    g.defvjp(g_fwd, g_bwd)

    def f(xs, y):
      v_g = api.vmap(g, in_axes=(0, None), out_axes=None)
      v_g(xs, y)

    def scan_body(xs, _):
      y = jnp.zeros(1)
      _, vjp_f = api.vjp(f, xs, y)
      vjp_f(None)
      return xs, None

    lax.scan(scan_body, jnp.ones(5), None, 100)  # doesn't crash

  def test_float0_bwd_none(self):
    @api.custom_vjp
    def f(i, x):
@@ -3791,35 +4259,6 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
    self.assertAllClose(grad_ans, 3. * 4. + np.cos(np.sin(3. * 4)),
                        check_dtypes=False)

  # TODO
  # def test_defjvp_closure_error(self):
  #   def foo(x):
  #     @api.custom_transforms
  #     def bar(y):
  #       return x * y

  #     api.defjvp(bar, lambda y_dot, ans, y: x * y)
  #     return bar(x)
  #   jtu.check_raises(
  #       lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
  #       "Detected differentiation with respect to closed-over values with "
  #       "custom JVP rule, which isn't supported.")

  # TODO
  # def test_defvjp_closure_error(self):
  #   def foo(x):
  #     @api.custom_transforms
  #     def bar(y):
  #       return x * y

  #     api.defvjp(bar, lambda g, ans, y: x * y)
  #     return bar(x)
  #   jtu.check_raises(
  #       lambda: grad(foo)(1.,), ValueError,
  #       "Detected differentiation w.r.t. variables from outside "
  #       "the scope of <jax.custom_transforms function bar>, but defvjp and "
  #       "defvjp_all only support differentiation w.r.t. positional arguments.")

  def test_custom_transforms_eval_with_pytrees(self):
    @api.custom_transforms
    def f(x):
@@ -3930,8 +4369,6 @@ class DeprecatedCustomTransformsTest(jtu.JaxTestCase):

class BufferDonationTest(jtu.JaxTestCase):

  # === pmap ===

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_pmap_donate_argnums_invalidates_input(self):
    move = api.pmap(lambda x: x + x - x, donate_argnums=0)
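To close, a brief sketch of what the donation test above exercises: `donate_argnums` tells the transform it may reuse an input buffer for an output. This snippet is my own illustration, assuming a backend that supports input/output aliasing (per the skip decorator, not CPU) and a pmap leading axis matching `jax.device_count()`.

```python
import jax
import jax.numpy as jnp

n = jax.device_count()
x = jnp.ones((n, 4))

# donate_argnums=0: the buffer backing x may be reused for the output
move = jax.pmap(lambda x: x + x - x, donate_argnums=0)
y = move(x)

# On backends that honor donation, x's buffer is now invalid; touching it
# afterwards raises an error, which is what the test checks.
```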