Merge pull request #14728 from mattjj:custom-vjp-bwd-wrapped-fun

PiperOrigin-RevId: 513375826
Author: jax authors
Date: 2023-03-01 16:38:58 -08:00
commit a9421a806f
5 changed files with 30 additions and 10 deletions
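
Context for the diffs below: the user-visible bug fixed here is that the VJP function returned by jax.vjp for a vmapped custom_vjp function could not be called a second time; this commit threads the bwd rule through the internals as a plain callable instead of an lu.WrappedFun, so it can be invoked repeatedly. A minimal sketch of the repro, adapted from the regression test added at the end of this change (only jax and jax.numpy are needed):

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x):
      return x

    # fwd returns (primal output, residuals); bwd maps (residuals, output
    # cotangent) to a tuple of input cotangents.
    f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,))

    # vjp of the vmapped function; before this change the second f_vjp call
    # could crash, afterwards it is expected to succeed.
    _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.]))
    f_vjp(jnp.array([3.]))
    f_vjp(jnp.array([3.]))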


@@ -560,7 +560,7 @@ class custom_vjp(Generic[ReturnValue]):
       flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
       flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
                                          out_type)
-      flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
+      flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
       out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                         *args_flat, out_trees=out_trees)
       _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
@@ -680,7 +680,7 @@ class CustomVJPCallPrimitive(core.CallPrimitive):
     fwd, env_trace_todo2 = process_env_traces_fwd(
         fwd, top_trace and top_trace.level, out_trees)
     tracers = map(top_trace.full_raise, args)  # type: ignore
-    bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
+    bwd_ = lambda *args: bwd(*args)
     outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
                                              out_trees=out_trees)
     fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
@@ -749,7 +749,7 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
 def _custom_vjp_call_jaxpr_jvp(
     primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
-    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
+    bwd: Callable, out_trees: Callable, num_consts: int):
   _, args = split_list(primals, [num_consts])
   consts_dot, args_dot = split_list(tangents, [num_consts])
   if any(type(t) is not Zero for t in consts_dot):
@@ -772,7 +772,7 @@ ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
 def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
     axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
-    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
+    bwd: Callable, out_trees: Callable, num_consts: int):
   args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
           else x for x, d in zip(args, in_dims)]


@@ -748,7 +748,7 @@ custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)
 def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
   res, _ = split_list(invals, [num_res])
   cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
-  cts_in = bwd.call_wrapped(*res, *cts_out)
+  cts_in = bwd(*res, *cts_out)
   return [None] * num_res + list(cts_in)
 primitive_transposes[custom_lin_p] = _custom_lin_transpose


@@ -776,11 +776,16 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
                      out_tangent_bds, out_dims, out_tangents)
   yield out_primals + out_tangents, out_dims * 2
 
-def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
-  bwd, out_dims_thunk = batch_subtrace(bwd)
-  bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type,
-                      spmd_axis_name)
-  return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests)
+def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
+                         main_type, spmd_axis_name):
+  def new_bwd(*args):
+    bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd))
+    bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type,
+                        spmd_axis_name)
+    bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
+                               out_dim_dests)
+    return bwd_.call_wrapped(*args)
+  return new_bwd
 
 @lu.transformation
 def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):


@@ -308,6 +308,9 @@ class _HashableCallableShim:
       return self.fun == other.fun
     return self.fun == other
 
+  def __repr__(self):
+    return f'_HashableCallableShim({repr(self.fun)})'
+
 
 class Partial(functools.partial):
   """A version of functools.partial that works in pytrees.


@@ -8385,6 +8385,18 @@ class CustomVJPTest(jtu.JaxTestCase):
     jax.grad(f)(A([1.]))  # doesn't crash
 
+  def test_vmap_vjp_called_twice(self):
+    # https://github.com/google/jax/pull/14728
+    @jax.custom_vjp
+    def f(x):
+      return x
+    f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,))
+
+    _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.]))
+    f_vjp(jnp.array([3.]))
+    f_vjp(jnp.array([3.]))  # doesn't crash
+
 
 def transpose_unary(f, x_example):
   def transposed(y):
     x, = api.linear_transpose(f, x_example)(y)