Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Merge pull request #14728 from mattjj:custom-vjp-bwd-wrapped-fun

PiperOrigin-RevId: 513375826

commit a9421a806f
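In short, this change passes the user's bwd function around as a plain Python callable rather than a lu.WrappedFun, so the batching transformations are re-applied on every invocation instead of being baked into a single-use wrapped function. The symptom it fixes, per the regression test added at the bottom of this diff, is a crash when the VJP of a vmapped custom_vjp function is called more than once:

    import jax
    import jax.numpy as jnp

    # Minimal reproduction, lifted from the regression test in this commit.
    @jax.custom_vjp
    def f(x):
      return x

    f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,))

    _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.]))
    f_vjp(jnp.array([3.]))  # first call always worked
    f_vjp(jnp.array([3.]))  # second call crashed before this change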
@@ -560,7 +560,7 @@ class custom_vjp(Generic[ReturnValue]):
     flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
                                        out_type)
-    flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
+    flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
     out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                       *args_flat, out_trees=out_trees)
     _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
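Note on the hunk above (a sketch, assuming the lu API of this era): _flatten_bwd returns a lu.WrappedFun, and taking its bound .call_wrapped method without invoking it yields a plain Python callable, which is what the rest of this patch now expects bwd to be.

    from jax import linear_util as lu  # import path at the time of this commit

    def bwd(res, ct):
      return (ct,)

    wrapped = lu.wrap_init(bwd)   # a lu.WrappedFun; transformations apply lazily
    plain = wrapped.call_wrapped  # the bound method is a plain callable
    plain(None, 1.0)              # -> (1.0,)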
@@ -680,7 +680,7 @@ class CustomVJPCallPrimitive(core.CallPrimitive):
     fwd, env_trace_todo2 = process_env_traces_fwd(
         fwd, top_trace and top_trace.level, out_trees)
     tracers = map(top_trace.full_raise, args)  # type: ignore
-    bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
+    bwd_ = lambda *args: bwd(*args)
     outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
                                              out_trees=out_trees)
     fst, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
@@ -749,7 +749,7 @@ mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun(
 def _custom_vjp_call_jaxpr_jvp(
     primals, tangents, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
-    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
+    bwd: Callable, out_trees: Callable, num_consts: int):
   _, args = split_list(primals, [num_consts])
   consts_dot, args_dot = split_list(tangents, [num_consts])
   if any(type(t) is not Zero for t in consts_dot):
@@ -772,7 +772,7 @@ ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
 def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
     axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
-    bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
+    bwd: Callable, out_trees: Callable, num_consts: int):
   args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
           else x for x, d in zip(args, in_dims)]
@@ -748,7 +748,7 @@ custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)
 def _custom_lin_transpose(cts_out, *invals, num_res, bwd, out_avals):
   res, _ = split_list(invals, [num_res])
   cts_out = map(instantiate_zeros_aval, out_avals, cts_out)
-  cts_in = bwd.call_wrapped(*res, *cts_out)
+  cts_in = bwd(*res, *cts_out)
   return [None] * num_res + list(cts_in)
 primitive_transposes[custom_lin_p] = _custom_lin_transpose
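The transpose rule above now invokes bwd directly instead of going through .call_wrapped. For reference, the calling convention it relies on (names below are illustrative, not JAX internals): bwd receives the saved residuals followed by the output cotangents, and returns one cotangent per primal input.

    def bwd(res, ct_out):
      return (res * ct_out,)  # one cotangent per primal input

    residuals, cts_out = (2.0,), (1.0,)
    cts_in = bwd(*residuals, *cts_out)  # mirrors bwd(*res, *cts_out) above
    # cts_in == (2.0,)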
@@ -776,11 +776,16 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
       out_tangent_bds, out_dims, out_tangents)
   yield out_primals + out_tangents, out_dims * 2

-def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
-  bwd, out_dims_thunk = batch_subtrace(bwd)
-  bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type,
-                      spmd_axis_name)
-  return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests)
+def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests,
+                         main_type, spmd_axis_name):
+  def new_bwd(*args):
+    bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd))
+    bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims, main_type,
+                        spmd_axis_name)
+    bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
+                               out_dim_dests)
+    return bwd_.call_wrapped(*args)
+  return new_bwd

 @lu.transformation
 def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
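Why rebuilding per call helps (a sketch, assuming the single-assignment auxiliary Store used by lu.transformation_with_aux at the time, which is what made the old stored bwd effectively single-use): batch_subtrace's out_dims_thunk is populated as a side effect of running the transformed function, and that store may only be written once. new_bwd sidesteps this by constructing a fresh transformed WrappedFun on every invocation.

    from jax import linear_util as lu  # import path at the time of this commit

    # A toy transformation-with-aux, analogous in shape to batch_subtrace.
    @lu.transformation_with_aux
    def count_outputs(*args):
      outs = yield args, {}
      yield outs, len(outs)

    one_shot, aux = count_outputs(lu.wrap_init(lambda x: (x, x)))
    one_shot.call_wrapped(1.)    # fine; populates the aux store
    # one_shot.call_wrapped(1.)  # would fail: the aux store is already occupied

    # The fix's structure: re-apply the transformation inside the returned
    # function, so each call gets a fresh WrappedFun and a fresh aux store.
    def make_reusable(f):
      def new_f(*args):
        g, aux_thunk = count_outputs(lu.wrap_init(f))
        return g.call_wrapped(*args), aux_thunk()
      return new_f

    h = make_reusable(lambda x: (x, x))
    h(1.)  # ((1.0, 1.0), 2)
    h(1.)  # repeated calls work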
@@ -308,6 +308,9 @@ class _HashableCallableShim:
       return self.fun == other.fun
     return self.fun == other

+  def __repr__(self):
+    return f'_HashableCallableShim({repr(self.fun)})'
+

 class Partial(functools.partial):
   """A version of functools.partial that works in pytrees.
@@ -8385,6 +8385,18 @@ class CustomVJPTest(jtu.JaxTestCase):

     jax.grad(f)(A([1.]))  # doesn't crash

+  def test_vmap_vjp_called_twice(self):
+    # https://github.com/google/jax/pull/14728
+    @jax.custom_vjp
+    def f(x):
+      return x
+    f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,))
+
+    _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.]))
+    f_vjp(jnp.array([3.]))
+    f_vjp(jnp.array([3.]))  # doesn't crash
+

 def transpose_unary(f, x_example):
   def transposed(y):
     x, = api.linear_transpose(f, x_example)(y)