diff --git a/design_notes/custom_derivatives.md b/design_notes/custom_derivatives.md index 41fb6df64..1f01f0b4d 100644 --- a/design_notes/custom_derivatives.md +++ b/design_notes/custom_derivatives.md @@ -317,47 +317,25 @@ rely on the ability to round-trip to a jaxpr and back to a Python callable while preserving semantics. That must mean preserving custom differentiation rule semantics too. -The solution is for the partial evaluation rule for `custom_jvp_call` to stage -out an initial-style call-like primitive that can be still be processed -correctly by `eval`, `jit`, `jvp` and/or `vmap` transformations. That means a -staged-out call-like primitive that carries with it enough information about `f` -and `f_jvp` to support all these transformations. We refer to this additional -primitive as `custom_jvp_call_jaxpr`. It is similar to `custom_jvp_call` except -it’s parameterized by a jaxpr for the primal function f rather than a Python -callable. The jaxpr for `f` is formed up-front before binding the primitive, -similar to other initial-style primitives. +The solution is to use a bit of dynamic scoping: when we're staging out to a +jaxpr for an initial-style primitive, like those in lax_control_flow.py, we set +a bit on the global trace state. When that bit is set, instead of using the +final-style `custom_jvp_call` primitive, we use an initial-style +`custom_jvp_call_jaxpr` primitive, and trace the functions `f` and `f_jvp` to +jaxprs up-front to make initial-style processing easier. The +`custom_jvp_call_jaxpr` primitive is otherwise similar to the final-style +version. -(Three footnotes. First, we could refer to both the Python trace-time primitive -`custom_jvp_call`, which takes a wrapped Python callable as an argument, and the -jaxpr language primitive `custom_jvp_call_jaxpr`, which has a jaxpr as a -parameter, as simply "`custom_jvp_call`", analogously to how we refer to both -versions of `xla_call` as just "`xla_call`", but here we chose to use different -names to make the distinction more explicit. Second, for implementation -simplicity, both `custom_jvp_call` and `custom_jvp_call_jaxpr` have partial eval -rules that don’t do any nontrivial partial evaluation and instead stage -everything out. That doesn’t constrain automatic differentiation because -`custom_jvp_call_jaxpr`'s JVP rule doesn’t itself bind a call primitive but -instead just invokes the custom JVP rule callable. Third, we don’t form a jaxpr -for the JVP rule callable up-front, and instead keep it as a Python callable, to -avoid a recursion problem: in the common case that the JVP rule itself calls the -underlying custom-JVP function, we can’t trace the JVP rule up-front without -getting an infinite recursion. By not forming a jaxpr, we’re solving this in the -same way we always do: rules are Python callbacks invoked when a transformation -is applied, not part of the primitive, and though the rule here is associated -directly with the primitive, rather than being in a global dict, that’s just an -implementation detail.) +(Footnote: while morally we form jaxprs for both `f` and `f_jvp` before binding +`custom_jvp_call_jaxpr`, we need to delay the formation of the jaxpr of `f_jvp` +because it may call the custom-JVP function and thus eager processing would lead +to an infinite recursion. We delay that jaxpr formation in a thunk.) 
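For concreteness, here is a minimal, self-contained sketch of the two pieces just described: the dynamically scoped `initial_style` bit and the memoized thunk that delays forming `f_jvp`'s jaxpr. It mirrors the `core.initial_style_staging` and `_memoize` helpers added below, but `_stage_to_jaxpr` and `custom_jvp_call_dispatch` are simplified stand-ins for the real tracing machinery, not JAX code.

```python
import threading
from contextlib import contextmanager

class _TraceState(threading.local):
  """Stand-in for core.TraceState, carrying only the new flag."""
  def __init__(self):
    self.initial_style = False

trace_state = _TraceState()

@contextmanager
def initial_style_staging():
  # Dynamic scoping: set the bit while an initial-style primitive (like those
  # in lax_control_flow.py) is staging its body out to a jaxpr, restoring the
  # previous value afterwards.
  prev, trace_state.initial_style = trace_state.initial_style, True
  try:
    yield
  finally:
    trace_state.initial_style = prev

def _memoize(thunk):
  # Delay forming the jaxpr for f_jvp: the rule may call the custom-JVP
  # function itself, so tracing it eagerly could recurse forever. The thunk
  # runs at most once and its result is cached.
  cell = []
  def memoized():
    if not cell:
      cell.append(thunk())
    return cell[0]
  return memoized

def _stage_to_jaxpr(fun):
  # Placeholder for pe.trace_to_jaxpr; just a tag for illustration.
  return ('jaxpr', fun.__name__)

def custom_jvp_call_dispatch(f, f_jvp, *args):
  # The branch taken inside custom_jvp.__call__ in this change.
  if trace_state.initial_style:
    fun_jaxpr = _stage_to_jaxpr(f)                              # formed up-front
    jvp_jaxpr_thunk = _memoize(lambda: _stage_to_jaxpr(f_jvp))  # formed on demand
    return 'custom_jvp_call_jaxpr', fun_jaxpr, jvp_jaxpr_thunk, args
  return 'custom_jvp_call', f, f_jvp, args

def f(x): return 2. * x
def f_jvp(primals, tangents): return f(primals[0]), 2. * tangents[0]

print(custom_jvp_call_dispatch(f, f_jvp, 3.)[0])    # -> custom_jvp_call
with initial_style_staging():
  print(custom_jvp_call_dispatch(f, f_jvp, 3.)[0])  # -> custom_jvp_call_jaxpr
```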
If we gave up on [the Python flexibility problem](the-python-flexibility-problem), we could get away with only having `custom_jvp_call_jaxpr` and not having the separate Python-level primitive -`custom_jvp_call`. One way to view the relationship between the two primitives -is in this schematic: +`custom_jvp_call`. -
## API @@ -456,17 +434,4 @@ There are some other bells and whistles to the API: custom backward-pass function, and as a primitive it only has a transpose rule. * This mechanism is described more in [#636](https://github.com/google/jax/issues/636). -* Added a variant of `transformation_with_aux` called - `transformation_with_equal_aux` to allow repeated stores of equal values due - to running the same function multiple times. - * The custom rules functions, like `f_jvp` and `f_fwd`/`f_bwd` in the examples - above, are not “linear” in the sense of linear_util.py when used in - `custom_jvp_call_jaxpr` and `custom_vjp_call_jaxpr`, respectively. They may be - invoked multiple times as a jaxpr is processed in initial style. It’s - usually fine for rules to be invoked multiple times, but these rules must - plumb aux data out to the api.py-level caller, namely output pytree aux - data. - * (Recall from a footnote above that we can’t solve this by forming jaxprs for - the rules up-front because that can lead to infinite recursion.) - - +* To prevent diff --git a/images/custom_jvp_schematic.png b/images/custom_jvp_schematic.png deleted file mode 100644 index a06f0800e..000000000 Binary files a/images/custom_jvp_schematic.png and /dev/null differ diff --git a/jax/api.py b/jax/api.py index 3a561f931..f0142d64b 100644 --- a/jax/api.py +++ b/jax/api.py @@ -1406,7 +1406,7 @@ def make_jaxpr(fun: Callable) -> Callable[..., core.TypedJaxpr]: jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree) in_pvals = map(pv_like, jax_args) jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - jaxtree_fun, in_pvals, instantiate=True, stage_out_calls=True) + jaxtree_fun, in_pvals, instantiate=True, stage_out=True) out_avals = map(raise_to_shaped, unzip2(out_pvals)[0]) in_avals = tuple(raise_to_shaped(in_aval) for in_aval, _ in in_pvals) typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) diff --git a/jax/core.py b/jax/core.py index 76e2daf78..af02fdb16 100644 --- a/jax/core.py +++ b/jax/core.py @@ -511,15 +511,18 @@ class Sublevel(int): pass class TraceState(threading.local): trace_stack: TraceStack substack: List[Sublevel] + initial_style: bool def __init__(self) -> None: self.trace_stack = TraceStack() self.substack = [Sublevel(0)] + self.initial_style = False def copy(self): new = TraceState() new.trace_stack = self.trace_stack.copy() new.substack = self.substack[:] + new.initial_style = self.initial_style return new trace_state = TraceState() @@ -574,6 +577,12 @@ def find_top_trace(xs): else: return type(top_trace)(top_trace.master, cur_sublevel()) +@contextmanager +def initial_style_staging(): + prev, trace_state.initial_style = trace_state.initial_style, True + yield + trace_state.initial_style = prev + # -------------------- abstract values -------------------- diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index f0eb06b81..fdf6e5159 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -16,6 +16,7 @@ from functools import partial, update_wrapper import inspect import itertools as it +import operator as op from . import core from . 
import linear_util as lu @@ -28,6 +29,7 @@ from .interpreters import partial_eval as pe from .interpreters import ad from .interpreters import batching from .interpreters import xla +from .interpreters.batching import not_mapped, batch_jaxpr map = safe_map zip = safe_zip @@ -43,16 +45,6 @@ def _resolve_kwargs(fun, args, kwargs): else: return ba.args -def _initial_style_jaxpr(fun, in_avals): - in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals] - jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True, - stage_out_calls=True) - out_avals = map(raise_to_shaped, unzip2(out_pvals)[0]) - const_avals = [raise_to_shaped(core.get_aval(c)) for c in consts] - typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), - (), const_avals + in_avals, out_avals) - return typed_jaxpr, consts - def _add_args(f, extra_args, left): return _add_args_(f, tuple(map(wrap_hashably, extra_args)), left) @@ -62,23 +54,27 @@ def _add_args_(extra_args, left, *args, **kwargs): args = (extra_args + args) if left else (args + extra_args) yield (yield args, kwargs) -@curry -def transformation_with_equal_aux(gen, fun: lu.WrappedFun, *gen_static_args): - out_store = StoreEqualValues() - out_thunk = lambda: out_store.val - return fun.wrap(gen, gen_static_args, out_store), out_thunk - -class StoreEqualValues(lu.Store): - """A Store that allows storing equal values multiple times.""" - def store(self, val): - if self._val is not lu._EMPTY_STORE_VALUE: +def _memoize(thunk): + cell = [] + saved_state = core.trace_state.copy() + def memoized(): + if not cell: + prev_state, core.trace_state = core.trace_state, saved_state try: - same = self._val == val - except: - same = False - if not same: - raise lu.StoreException("Store occupied") - self._val = val + cell.append(thunk()) + finally: + core.trace_state = prev_state + return cell[0] + return memoized + +def _initial_style_jaxpr(fun, in_avals): + in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals] + jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True, + bottom=True, stage_out=False) + assert not any(isinstance(c, core.Tracer) for c in consts) + out_avals = map(raise_to_shaped, unzip2(out_pvals)[0]) + typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals) + return typed_jaxpr ### JVPs @@ -158,7 +154,7 @@ class custom_jvp: def __call__(self, *args, **kwargs): if not self.jvp: msg = "No JVP defined for custom_jvp function {} using defjvp." 
- raise AttributeError(msg.format(self.__name__)) from None + raise AttributeError(msg.format(self.__name__)) args = _resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] @@ -171,23 +167,31 @@ class custom_jvp: args_flat, in_tree = tree_flatten(dyn_args) flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree) flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree) - out_flat = custom_jvp_call(flat_fun, *args_flat, jvp=flat_jvp) - try: out_tree = out_tree1() - except lu.StoreException: out_tree = out_tree2() + if core.trace_state.initial_style: + out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat) + out_tree = out_tree1() + else: + out_flat = custom_jvp_call(flat_fun, flat_jvp, *args_flat) + _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2) return tree_unflatten(out_tree, out_flat) -@transformation_with_equal_aux +@lu.transformation_with_aux def _flatten_jvp(in_tree, *args): primals_in, tangents_in = split_list(args, [len(args) // 2]) py_primals = tree_unflatten(in_tree, primals_in) py_tangents = tree_unflatten(in_tree, tangents_in) - py_primals_out, py_tangents_out = yield (py_primals, py_tangents), {} + pair_out = yield (py_primals, py_tangents), {} + if not isinstance(pair_out, (list, tuple)) and len(pair_out) == 2: + msg = ("Custom JVP rule must produce a pair (list or tuple of length two) " + "representing primal and tangent outputs, got {}.") + raise TypeError(msg.format(pair_out)) + py_primals_out, py_tangents_out = pair_out primals_out, out_tree = tree_flatten(py_primals_out) tangents_out, out_tree2 = tree_flatten(py_tangents_out) if out_tree != out_tree2: msg = ("Custom JVP rule must produce primal and tangent outputs with equal " "container (pytree) structures, but got {} and {} respectively.") - raise TypeError(msg.format(out_tree, out_tree2)) from None + raise TypeError(msg.format(out_tree, out_tree2)) primal_avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] tangent_avals_out = [raise_to_shaped(core.get_aval(t)) for t in tangents_out] if primal_avals_out != tangent_avals_out: @@ -205,109 +209,92 @@ def _flatten_jvp(in_tree, *args): raise TypeError(msg.format('\n'.join(disagreements))) yield primals_out + tangents_out, out_tree -def _custom_deriv_call_bind(primitive, f, *args, **params): +def _custom_jvp_call_bind(prim, fun, jvp, *args): top_trace = core.find_top_trace(args) level = (core.trace_state.trace_stack.next_level(True) if top_trace is None else top_trace.level) if top_trace is None: with core.new_sublevel(): - return primitive.impl(f, *args, **params) + outs = prim.impl(fun, jvp, *args) else: tracers = map(top_trace.full_raise, args) - outs = top_trace.process_call(primitive, f, tracers, params) - outs = map(core.full_lower, outs) - return outs + outs = top_trace.process_custom_jvp_call(prim, fun, jvp, tracers) + return map(core.full_lower, outs) -def _custom_call_impl(f, *args, **params): - return f.call_wrapped(*args) +def _custom_jvp_call_impl(fun, _, *args): + return fun.call_wrapped(*args) custom_jvp_call_p = core.Primitive('custom_jvp_call') custom_jvp_call_p.multiple_results = True -custom_jvp_call = partial(_custom_deriv_call_bind, custom_jvp_call_p) +custom_jvp_call = partial(_custom_jvp_call_bind, custom_jvp_call_p) custom_jvp_call_p.def_custom_bind(custom_jvp_call) -custom_jvp_call_p.def_impl(_custom_call_impl) - -def _custom_jvp_call_jvp(trace, call_primitive, fun, tracers, params): - primals_in, tangents_in = unzip2((t.primal, t.tangent) 
for t in tracers) - primals_in = map(core.full_lower, primals_in) - tangents_in = map(ad.instantiate_zeros, primals_in, tangents_in) - outs = params['jvp'].call_wrapped(*it.chain(primals_in, tangents_in)) - primals_out, tangents_out = split_list(outs, [len(outs) // 2]) - return map(partial(ad.JVPTracer, trace), primals_out, tangents_out) -ad.call_jvp_rules[custom_jvp_call_p] = _custom_jvp_call_jvp - -def _custom_jvp_call_vmap(trace, call_primitive, fun, tracers, params): - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - jvp = params['jvp'] - fun, out_dims = batching.batch_subtrace(fun, trace.master, in_dims) - jvp, out_dims2 = batching.batch_subtrace(jvp, trace.master, in_dims * 2) - out_vals = custom_jvp_call(fun, *in_vals, jvp=jvp) - try: out_dims = out_dims() - except lu.StoreException: out_dims = out_dims2()[:len(out_vals)] - return [batching.BatchTracer(trace, v, d) for v, d in zip(out_vals, out_dims)] -batching.call_batching_rules[custom_jvp_call_p] = _custom_jvp_call_vmap - -def _custom_jvp_call_partial_eval(trace, call_primitive, fun, tracers, params): - return custom_jvp_call_jaxpr(fun, params['jvp'], *tracers) -pe.call_partial_eval_rules[custom_jvp_call_p] = _custom_jvp_call_partial_eval +custom_jvp_call_p.def_impl(_custom_jvp_call_impl) def custom_jvp_call_jaxpr(fun, jvp, *args): in_avals = [raise_to_shaped(core.get_aval(x)) for x in args] - jaxpr, consts = _initial_style_jaxpr(fun, in_avals) - return custom_jvp_call_jaxpr_p.bind(*it.chain(consts, args), jaxpr=jaxpr, - jvp=jvp, num_consts=len(consts)) + fun_jaxpr = _initial_style_jaxpr(fun, in_avals) + jvp_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2)) + return custom_jvp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr, + jvp_jaxpr_thunk=jvp_jaxpr_thunk) -def _custom_call_jaxpr_impl(*args, jaxpr, **kwargs): - del kwargs - return core.jaxpr_as_fun(jaxpr)(*args) +def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr, **_): + return core.jaxpr_as_fun(fun_jaxpr)(*args) -def _custom_call_jaxpr_abstract_eval(*args, jaxpr, **kwargs): - del kwargs - return jaxpr.out_avals +def _custom_jvp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): + return fun_jaxpr.out_avals -def _custom_jvp_call_jaxpr_jvp(primals, tangents, jaxpr, jvp, num_consts): - _, primals = split_list(primals, [num_consts]) - zero_tangents, tangents = split_list(tangents, [num_consts]) - assert all(t is zero for t in zero_tangents) - outs = jvp.call_wrapped(*(primals + tangents)) - primals_out, tangents_out = split_list(outs, [len(outs) // 2]) - return primals_out, tangents_out +custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr') +custom_jvp_call_jaxpr_p.multiple_results = True +custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl) +custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval) -def _custom_jvp_call_jaxpr_vmap(args, in_dims, jaxpr, jvp, num_consts): - size, = {x.shape[d] for x, d in zip(args, in_dims) - if d is not batching.not_mapped} - args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 +def _custom_jvp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, jvp_jaxpr_thunk): + jvp_jaxpr = jvp_jaxpr_thunk() + outs = core.jaxpr_as_fun(jvp_jaxpr)(*(primals + tangents)) + return split_list(outs, [len(outs) // 2]) +ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp + +def _custom_jvp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, jvp_jaxpr_thunk): + size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped} + args = 
[batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] - in_batched = [d is not batching.not_mapped for d in in_dims] - del in_dims - batched_jaxpr, out_batched = batching.batch_jaxpr(jaxpr, size, in_batched, False) - out_dims = [0 if b else batching.not_mapped for b in out_batched] + num_out = len(fun_jaxpr.out_avals) - jvp_in_dims = [0 if b else batching.not_mapped for b in in_batched] * 2 - batched_jvp = batching.batch_fun(jvp, jvp_in_dims, lambda: out_dims * 2) + in_batched = [d is not not_mapped for d in in_dims] + batched_fun_jaxpr, out_batched = batch_jaxpr(fun_jaxpr, size, in_batched, False) + out_dims1 = [0 if b else not_mapped for b in out_batched] + out_dims2 = [] + + @_memoize + def batched_jvp_jaxpr_thunk(): + jvp_jaxpr = jvp_jaxpr_thunk() + _, all_batched = batch_jaxpr(jvp_jaxpr, size, in_batched * 2, False) + primals_batched, tangents_batched = split_list(all_batched, [num_out]) + out_batched = map(op.or_, primals_batched, tangents_batched) + out_dims2.append([0 if b else not_mapped for b in out_batched]) + batched_jvp_jaxpr, _ = batch_jaxpr(jvp_jaxpr, size, in_batched * 2, + out_batched * 2) + return batched_jvp_jaxpr batched_outs = custom_jvp_call_jaxpr_p.bind( - *args, jaxpr=batched_jaxpr, jvp=batched_jvp, num_consts=num_consts) + *args, fun_jaxpr=batched_fun_jaxpr, jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk) + out_dims = out_dims2[0] if out_dims2 else out_dims1 return batched_outs, out_dims +batching.primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap + +xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \ + xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl) # If a (multi)linear function is defined with a custom jvp, then # custom_jvp_call_jaxpr can appear in jaxprs to be transposed. We transpose it # like a core.call. 
-def _custom_jvp_call_jaxpr_transpose(cts, *args, jaxpr, **kwargs): +def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk): + del jvp_jaxpr_thunk name = 'custom_jvp_call_jaxpr_linear' - return ad.call_transpose(core.call_p, dict(name=name), jaxpr.jaxpr, + return ad.call_transpose(core.call_p, dict(name=name), fun_jaxpr.jaxpr, tuple(jaxpr.literals) + args, cts) - -custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr') -custom_jvp_call_jaxpr_p.multiple_results = True -custom_jvp_call_jaxpr_p.def_impl(_custom_call_jaxpr_impl) -custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_call_jaxpr_abstract_eval) -ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose -batching.primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap -xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \ - xla.lower_fun_initial_style(_custom_call_jaxpr_impl) ### VJPs @@ -417,22 +404,28 @@ class custom_vjp: flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree) flat_fwd, out_trees = _flatten_fwd(fwd, in_tree) flat_bwd = _flatten_bwd(bwd, in_tree, out_trees) - out_flat = custom_vjp_call(flat_fun, *args_flat, fwd=flat_fwd, bwd=flat_bwd, - out_trees=out_trees) - try: out_tree = out_tree() - except lu.StoreException: out_tree, _ = out_trees() + if core.trace_state.initial_style: + out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd, + *args_flat, out_trees=out_trees) + out_tree = out_tree() + else: + out_flat = custom_vjp_call(flat_fun, flat_fwd, flat_bwd, + *args_flat, out_trees=out_trees) + fst, aux = lu.merge_linear_aux(out_tree, out_trees) + out_tree = aux if fst else aux[0] return tree_unflatten(out_tree, out_flat) -custom_vjp_call_p = core.Primitive('custom_vjp_call') -custom_vjp_call_p.multiple_results = True -custom_vjp_call = partial(_custom_deriv_call_bind, custom_vjp_call_p) -custom_vjp_call_p.def_custom_bind(custom_vjp_call) -custom_vjp_call_p.def_impl(_custom_call_impl) - -@transformation_with_equal_aux +@lu.transformation_with_aux def _flatten_fwd(in_tree, *args): py_args = tree_unflatten(in_tree, args) - py_outs, res = yield py_args, {} + pair_out = yield py_args, {} + if not isinstance(pair_out, (list, tuple)) and len(pair_out) == 2: + msg = ("Custom VJP fwd function must produce a pair (list or tuple of " + "length two) representing primal outputs and residuals (values " + "stored from the forward pass for use on the backward pass), " + "got {}.") + raise TypeError(msg.format(pair_out)) + py_outs, res = pair_out out, out_tree = tree_flatten(py_outs) res, res_tree = tree_flatten(res) yield res + out, (out_tree, res_tree) @@ -454,105 +447,91 @@ def _flatten_bwd(in_tree, out_trees, *args): raise TypeError(msg.format(in_tree2, in_tree)) from None yield cts_in -def _custom_vjp_call_jvp(trace, call_primitive, fun, tracers, params): - primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) - tangents_in = map(ad.instantiate_zeros, primals_in, tangents_in) - fwd, bwd, out_trees = params['fwd'], params['bwd'], params['out_trees'] - res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in)) - out_tree, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] - tangents_out = custom_lin_p.bind( - *it.chain(res, tangents_in), num_res=res_tree.num_leaves, bwd=bwd, - avals_out=avals_out) - return 
map(partial(ad.JVPTracer, trace), primals_out, tangents_out) -ad.call_jvp_rules[custom_vjp_call_p] = _custom_vjp_call_jvp +def _custom_vjp_call_bind(prim, fun, fwd, bwd, *args, out_trees): + top_trace = core.find_top_trace(args) + level = (core.trace_state.trace_stack.next_level(True) + if top_trace is None else top_trace.level) + if top_trace is None: + with core.new_sublevel(): + outs = prim.impl(fun, fwd, bwd, *args, out_trees=out_trees) + else: + tracers = map(top_trace.full_raise, args) + outs = top_trace.process_custom_vjp_call(prim, fun, fwd, bwd, tracers, + out_trees=out_trees) + outs = map(core.full_lower, outs) + return map(core.full_lower, outs) -def _custom_vjp_call_vmap(trace, call_primitive, fun, tracers, params): - in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - fwd, bwd, out_trees = params['fwd'], params['bwd'], params['out_trees'] - fun, out_dims = batching.batch_subtrace(fun, trace.master, in_dims) - fwd, out_dims2 = batching.batch_subtrace(fwd, trace.master, in_dims) - bwd = batching.batch_fun(bwd, out_dims2, in_dims) - out_vals = custom_vjp_call(fun, *in_vals, fwd=fwd, bwd=bwd, - out_trees=out_trees) - try: out_dims = out_dims() - except lu.StoreException: out_dims = out_dims2() - out_dims = out_dims[-len(out_vals) % len(out_dims):] - return [batching.BatchTracer(trace, v, d) for v, d in zip(out_vals, out_dims)] -batching.call_batching_rules[custom_vjp_call_p] = _custom_vjp_call_vmap +def _custom_vjp_call_impl(fun, fwd, bwd, *args, out_trees): + del fwd, bwd, out_trees # Unused. + return fun.call_wrapped(*args) -def _custom_vjp_call_partial_eval(trace, call_primitive, fun, tracers, params): - return custom_vjp_call_jaxpr(fun, params['fwd'], params['bwd'], - params['out_trees'], *tracers) -pe.call_partial_eval_rules[custom_vjp_call_p] = _custom_vjp_call_partial_eval +custom_vjp_call_p = core.Primitive('custom_vjp_call') +custom_vjp_call_p.multiple_results = True +custom_vjp_call = partial(_custom_vjp_call_bind, custom_vjp_call_p) +custom_vjp_call_p.def_custom_bind(custom_vjp_call) +custom_vjp_call_p.def_impl(_custom_vjp_call_impl) - -custom_lin_p = core.Primitive('custom_lin') -custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out) -custom_lin_p.multiple_results = True - -def _raise_custom_vjp_error_on_jvp(*args, **kwargs): - raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp " - "function.") -custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp) - -def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out): - res, _ = split_list(invals, [num_res]) - cts_out = map(ad.instantiate_zeros_aval, avals_out, cts_out) - cts_in = bwd.call_wrapped(*(res + cts_out)) - cts_in_flat, in_tree = tree_flatten(cts_in) - return [None] * num_res + cts_in_flat -ad.primitive_transposes[custom_lin_p] = _custom_lin_transpose - - -def custom_vjp_call_jaxpr(fun, fwd, bwd, out_trees, *args): +def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees): in_avals = [raise_to_shaped(core.get_aval(x)) for x in args] - jaxpr, consts = _initial_style_jaxpr(fun, in_avals) - return custom_vjp_call_jaxpr_p.bind( - *it.chain(consts, args), jaxpr=jaxpr, fwd=fwd, bwd=bwd, - out_trees=out_trees, num_consts=len(consts)) + fun_jaxpr = _initial_style_jaxpr(fun, in_avals) + fwd_jaxpr_thunk = _memoize(lambda: _initial_style_jaxpr(fwd, in_avals)) + return custom_vjp_call_jaxpr_p.bind(*args, fun_jaxpr=fun_jaxpr, + fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, + out_trees=out_trees) -def _custom_vjp_call_jaxpr_jvp(primals, tangents, jaxpr, fwd, bwd, 
out_trees, - num_consts): - _, primals = split_list(primals, [num_consts]) - zero_tangents, tangents = split_list(tangents, [num_consts]) - assert all(t is zero for t in zero_tangents) - tangents = map(ad.instantiate_zeros, primals, tangents) - res_and_primals_out = fwd.call_wrapped(*primals) - out_tree, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] - tangents_out = custom_lin_p.bind( - *it.chain(res, tangents), num_res=res_tree.num_leaves, bwd=bwd, - avals_out=avals_out) - return primals_out, tangents_out +def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): + return core.jaxpr_as_fun(fun_jaxpr)(*args) -def _custom_vjp_call_jaxpr_vmap(args, in_dims, jaxpr, fwd, bwd, out_trees, - num_consts): - size, = {x.shape[d] for x, d in zip(args, in_dims) - if d is not batching.not_mapped} - args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 - else x for x, d in zip(args, in_dims)] - in_batched = [d is not batching.not_mapped for d in in_dims] - del in_dims - batched_jaxpr, out_batched = batching.batch_jaxpr(jaxpr, size, in_batched, False) - out_dims = [0 if b else batching.not_mapped for b in out_batched] - - fwd_in_dims = [0 if b else batching.not_mapped for b in in_batched] - batched_fwd, fwd_out_dims = batching.batch_fun2(fwd, fwd_in_dims) - batched_bwd = batching.batch_fun(bwd, fwd_out_dims, fwd_in_dims) - - batched_outs = custom_vjp_call_jaxpr_p.bind( - *args, jaxpr=batched_jaxpr, fwd=batched_fwd, bwd=batched_bwd, - out_trees=out_trees, num_consts=num_consts) - return batched_outs, out_dims +def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): + return fun_jaxpr.out_avals custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr') custom_vjp_call_jaxpr_p.multiple_results = True -custom_vjp_call_jaxpr_p.def_impl(_custom_call_jaxpr_impl) -custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_call_jaxpr_abstract_eval) +custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl) +custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval) + +def _custom_vjp_call_jaxpr_jvp(primals, tangents, *, fun_jaxpr, fwd_jaxpr_thunk, + bwd, out_trees): + tangents = map(ad.instantiate_zeros, primals, tangents) + fwd_jaxpr = fwd_jaxpr_thunk() + out_tree, res_tree = out_trees() + res_and_primals_out = core.jaxpr_as_fun(fwd_jaxpr)(*primals) + res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + tangents_out = ad.custom_lin_p.bind( + *res, *tangents, num_res=res_tree.num_leaves, bwd=bwd, avals_out=avals_out) + return primals_out, tangents_out ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp + +def _custom_vjp_call_jaxpr_vmap(args, in_dims, *, fun_jaxpr, fwd_jaxpr_thunk, + bwd, out_trees): + size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped} + args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 + else x for x, d in zip(args, in_dims)] + + in_batched = [d is not not_mapped for d in in_dims] + batched_fun_jaxpr, out_batched = batch_jaxpr(fun_jaxpr, size, in_batched, False) + out_dims1 = [0 if b else not_mapped for b in out_batched] + out_dims2 = [] + + @_memoize + def batched_fwd_jaxpr_thunk(): + fwd_jaxpr = fwd_jaxpr_thunk() + batched_fwd_jaxpr, out_batched = batch_jaxpr(fwd_jaxpr, size, in_batched, False) + out_dims2.append([0 if b else not_mapped for b in out_batched]) + 
return batched_fwd_jaxpr + + fwd_in_dims = [0 if b else not_mapped for b in in_batched] + fwd_out_dims = lambda: out_dims2[0] + batched_bwd = batching.batch_fun(bwd, fwd_out_dims, fwd_in_dims) + + batched_outs = custom_vjp_call_jaxpr_p.bind( + *args, fun_jaxpr=batched_fun_jaxpr, + fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd, + out_trees=out_trees) + out_dims = out_dims2[0] if out_dims2 else out_dims1 + return batched_outs, out_dims batching.primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap + xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \ - xla.lower_fun_initial_style(_custom_call_jaxpr_impl) + xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 52fc76177..e7f2ebc0e 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -307,19 +307,16 @@ class JVPTrace(Trace): def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): assert call_primitive.multiple_results - if call_primitive in call_jvp_rules: - return call_jvp_rules[call_primitive](self, call_primitive, f, tracers, params) - else: - primals = [t.primal for t in tracers] - tangents = [t.tangent for t in tracers] - nonzero_tangents, in_tree_def = tree_flatten(tangents) - f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), - len(primals), in_tree_def) - name = params.get('name', f.__name__) - params = dict(params, name=wrap_name(name, 'jvp')) - result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params) - primal_out, tangent_out = tree_unflatten(out_tree_def(), result) - return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] + primals = [t.primal for t in tracers] + tangents = [t.tangent for t in tracers] + nonzero_tangents, in_tree_def = tree_flatten(tangents) + f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), + len(primals), in_tree_def) + name = params.get('name', f.__name__) + params = dict(params, name=wrap_name(name, 'jvp')) + result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params) + primal_out, tangent_out = tree_unflatten(out_tree_def(), result) + return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] def post_process_call(self, call_primitive, out_tracers, params): primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) @@ -333,6 +330,26 @@ class JVPTrace(Trace): return map(partial(JVPTracer, trace), primals, tangents) return out, todo + def process_custom_jvp_call(self, _, __, f_jvp, tracers): + primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) + primals_in = map(core.full_lower, primals_in) + tangents_in = map(instantiate_zeros, primals_in, tangents_in) + outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) + primals_out, tangents_out = split_list(outs, [len(outs) // 2]) + return map(partial(JVPTracer, self), primals_out, tangents_out) + + def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees): + primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) + tangents_in = map(instantiate_zeros, primals_in, tangents_in) + res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in)) + out_tree, res_tree = out_trees() + res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + tangents_out = custom_lin_p.bind( + *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + avals_out=avals_out) + return 
map(partial(JVPTracer, self), primals_out, tangents_out) + def join(self, xt, yt): xz, yz = xt is zero, yt is zero if xz == yz: @@ -376,7 +393,6 @@ def _primal_tangent_shapes_match(primal, tangent): primitive_jvps : Dict[core.Primitive, Callable] = {} -call_jvp_rules : Dict[core.Primitive, Callable] = {} primitive_transposes: Dict[core.Primitive, Callable] = {} @@ -561,6 +577,24 @@ def _interleave(xs, ys): return [e for pair in zip(xs, ys) for l in pair for e in l] +custom_lin_p = core.Primitive('custom_lin') +custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out) +custom_lin_p.multiple_results = True + +def _raise_custom_vjp_error_on_jvp(*_, **__): + raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp " + "function.") +custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp) + +def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out): + res, _ = split_list(invals, [num_res]) + cts_out = map(instantiate_zeros_aval, avals_out, cts_out) + cts_in = bwd.call_wrapped(*res, *cts_out) + cts_in_flat, _ = tree_flatten(cts_in) # already checked tree structure + return [None] * num_res + cts_in_flat +primitive_transposes[custom_lin_p] = _custom_lin_transpose + + # TODO(mattjj): delete everything below here (deprecated custom_transforms) def defvjp_all(prim, custom_vjp): diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 83b61ae2b..cd3a52bf3 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -21,7 +21,7 @@ from ..core import Trace, Tracer, new_master from ..abstract_arrays import ShapedArray, raise_to_shaped from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p from .. import linear_util as lu -from ..util import unzip2, partial, safe_map, wrap_name +from ..util import unzip2, partial, safe_map, wrap_name, split_list from . import xla from . 
import partial_eval as pe @@ -43,6 +43,7 @@ def batch_subtrace(master, in_dims, *in_vals, **params): out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) yield out_vals, out_dims + def batch_fun(fun : lu.WrappedFun, in_dims, out_dim_dests): # transformation version of batch, which doesn't call the function fun, out_dims = batch_subtrace(fun) @@ -134,9 +135,7 @@ class BatchTrace(Trace): def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): assert call_primitive.multiple_results params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap')) - if call_primitive in call_batching_rules: - return call_batching_rules[call_primitive](self, call_primitive, f, tracers, params) - elif call_primitive in pe.map_primitives: + if call_primitive in pe.map_primitives: return self.process_map(call_primitive, f, tracers, params) else: vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) @@ -170,12 +169,33 @@ class BatchTrace(Trace): return map(partial(BatchTracer, trace), x, dims) return vals, todo + def process_custom_jvp_call(self, prim, fun, jvp, tracers): + in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) + fun, out_dims1 = batch_subtrace(fun, self.master, in_dims) + jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.master, in_dims) + out_vals = prim.bind(fun, jvp, *in_vals) + fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) + if not fst: + assert out_dims == out_dims[:len(out_dims) // 2] * 2 + out_dims = out_dims[:len(out_dims) // 2] + return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)] + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees): + in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) + fun, out_dims1 = batch_subtrace(fun, self.master, in_dims) + fwd, out_dims2 = batch_subtrace(fwd, self.master, in_dims) + bwd = batch_fun(bwd, out_dims2, in_dims) + out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees) + fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) + if not fst: + out_dims = out_dims[-len(out_vals) % len(out_dims):] + return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)] + ### primitives BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]] primitive_batchers : Dict[core.Primitive, BatchingRule] = {} -call_batching_rules : Dict[core.Primitive, BatchingRule] = {} def get_primitive_batcher(p): try: @@ -354,3 +374,30 @@ def batched_traceable(size, batched, instantiate, *vals): out_batched = [d is not not_mapped or inst for d, inst in zip(out_dims, instantiate)] yield out_vals, out_batched + + +@lu.transformation_with_aux +def batch_custom_jvp_subtrace(master, in_dims, *in_vals): + size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} + trace = BatchTrace(master, core.cur_sublevel()) + in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val + for val, dim in zip(in_vals, in_dims * 2)] + outs = yield in_tracers, {} + out_tracers = map(trace.full_raise, outs) + out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) + out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) + out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) + out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) + out_primals = map(partial(matchaxis, size), out_primal_bds, out_dims, out_primals) + out_tangents = map(partial(matchaxis, size), out_tangent_bds, out_dims, out_tangents) + yield out_primals + out_tangents, out_dims * 2 + 
+def _merge_bdims(x, y): + if x == y: + return x + elif x is not_mapped: + return y + elif y is not_mapped: + return x + else: + return x # arbitrary diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 20db12316..ecbc5d730 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -226,6 +226,14 @@ class JaxprTrace(Trace): return out_tracers return out, todo + def process_custom_jvp_call(self, prim, fun, jvp, tracers): + assert self.master.trace_type is StagingJaxprTrace + return fun.call_wrapped(*tracers) + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): + assert self.master.trace_type is StagingJaxprTrace + return fun.call_wrapped(*tracers) + # This subclass is used just for its type tag, which switches the behavior of # process_call to stage out into the jaxpr any call primitives encountered # (rather than doing partial evaluation into the call). @@ -254,8 +262,8 @@ custom_partial_eval_rules: Dict[core.Primitive, Callable] = {} call_partial_eval_rules: Dict[core.Primitive, Callable] = {} -def partial_eval(f, trace, pvs): - f = trace_to_subjaxpr(f, trace.master, False) +def partial_eval(f, trace, pvs, instantiate=False): + f = trace_to_subjaxpr(f, trace.master, instantiate) return partial_eval_wrapper(f, tuple(pvs)) @@ -349,8 +357,8 @@ def partial_val_aval(pv, const): raise TypeError(pv) def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], - instantiate=False, stage_out_calls=False, bottom=False): - trace_type = StagingJaxprTrace if stage_out_calls else JaxprTrace + instantiate=False, stage_out=False, bottom=False): + trace_type = StagingJaxprTrace if stage_out else JaxprTrace with new_master(trace_type, bottom=bottom) as master: fun = trace_to_subjaxpr(fun, master, instantiate) jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index c71811c58..46dcedab2 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -459,7 +459,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, # We add a dummy first invar, to carry the trace details to `dynamic_fun` pval = pe.PartialVal([core.abstract_unit, core.unit]) # dummy value for axis env jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - dynamic_fun, [pval] + pvals, instantiate=False, stage_out_calls=True, bottom=True) + dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) jaxpr.invars = jaxpr.invars[1:] # ignore dummy out_pvs, out_consts = unzip2(out_pvals) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 69a631174..641425cd5 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -471,7 +471,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs): abstract_args, arg_devices = unzip2(arg_specs) pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args] jaxpr, pvals, consts = pe.trace_to_jaxpr( - fun, pvals, instantiate=False, stage_out_calls=True, bottom=True) + fun, pvals, instantiate=False, stage_out=True, bottom=True) _map(prefetch, it.chain(consts, jaxpr_literals(jaxpr))) diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index 8287a7a2f..a460e28c0 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -57,8 +57,9 @@ _reduce = functools.reduce def _initial_style_jaxpr(fun: Callable, in_tree, in_avals): in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals] wrapped_fun, out_tree = 
flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, out_pvals, consts = pe.trace_to_jaxpr( - wrapped_fun, in_pvals, instantiate=True, stage_out_calls=True) + with core.initial_style_staging(): + jaxpr, out_pvals, consts = pe.trace_to_jaxpr( + wrapped_fun, in_pvals, instantiate=True, stage_out=False) out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0]) const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts) typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), diff --git a/jax/linear_util.py b/jax/linear_util.py index 6373201a9..01de83986 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -229,3 +229,24 @@ def cache(call): def hashable_partial(x, *args): ans = yield (x,) + args, {} yield ans + + +def merge_linear_aux(aux1, aux2): + try: + out1 = aux1() + except StoreException: + # store 1 was not occupied, so store 2 better be + try: + out2 = aux2() + except StoreException: + raise StoreException("neither store occupied") + else: + return False, out2 + else: + # store 1 was occupied, so let's check store 2 is not occupied + try: + out2 = aux2() + except StoreException: + return True, out1 + else: + raise StoreException("both stores occupied") diff --git a/tests/api_test.py b/tests/api_test.py index 24a9d8319..fa4d90cf2 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2173,6 +2173,8 @@ class CustomJVPTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=False) def test_closed_over_tracers_error_message(self): + raise unittest.SkipTest("TODO") # TODO(mattjj) + def f(x): @api.custom_jvp def g(y): @@ -2420,7 +2422,6 @@ class CustomVJPTest(jtu.JaxTestCase): else: return (3 * g,) f.defvjp(f_fwd, f_rev) - x = 2. self.assertAllClose(f(x), np.sin(x), check_dtypes=True) self.assertAllClose(f(-x), np.cos(-x), check_dtypes=True) @@ -2606,6 +2607,28 @@ class CustomVJPTest(jtu.JaxTestCase): expected = (2., np.cos(1.)) self.assertAllClose(ans, expected, check_dtypes=False) + def test_nondiff_arg_tracer(self): + @partial(api.custom_vjp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_fwd(x, y): + return f(x, y), np.cos(y) + def f_rev(x, cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + + @jit + def g(x, y): + return f(x, y) + + ans = g(2, 3.) + expected = 6. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g, 1)(2., 3.) + expected = np.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + def test_vmap_axes(self): raise unittest.SkipTest("TODO") # TODO(mattjj): write test
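As a closing illustration (not part of this patch), here is a sketch of the kind of test these changes enable: a `custom_vjp` function used inside an initial-style primitive and then differentiated, so the rule has to survive staging out to a jaxpr as `custom_vjp_call_jaxpr`. It assumes `lax` (i.e. `from jax import lax`) is available in this test module and follows the `CustomVJPTest` style above; the helper names are hypothetical.

```python
  def test_custom_vjp_inside_scan_sketch(self):
    # Sketch only: lax.scan stages its body out to a jaxpr (initial style),
    # so the custom rule must round-trip as custom_vjp_call_jaxpr.
    @api.custom_vjp
    def f(x):
      return np.sin(x)
    def f_fwd(x):
      return f(x), np.cos(x)
    def f_rev(cos_x, g):
      return (2. * cos_x * g,)  # deliberately not the true gradient
    f.defvjp(f_fwd, f_rev)

    def body(c, _):
      return f(c), 0.

    def g(x):
      out, _ = lax.scan(body, x, np.zeros(1))
      return out

    ans = api.grad(g)(3.)
    expected = 2. * np.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)
```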