add BatchTrace.process_custom_vjp_call

It was an oversight not to include this! Notice we have
BatchTrace.process_custom_jvp_call. In fact, we can use the same
function! We just needed the simplest possible post-process-call which
just peels and packages.

fixes #5440

Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2021-01-19 19:08:23 -08:00
parent 34e798ff26
commit 0f704514e3
2 changed files with 43 additions and 0 deletions

View File

@ -259,6 +259,8 @@ class BatchTrace(Trace):
out_dims = out_dims[-len(out_vals) % len(out_dims):]
return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
post_process_custom_vjp_call = post_process_custom_jvp_call
def _main_trace_for_axis_names(main_trace: core.MainTrace,
axis_name: Union[core.AxisName, Tuple[core.AxisName, ...]]
) -> bool:

View File

@ -3866,6 +3866,47 @@ class CustomVJPTest(jtu.JaxTestCase):
expected = jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_closed_over_tracer2(self):
def outer(x):
@api.custom_vjp
def f(y):
return x * y
def f_fwd(y):
return f(y), jnp.cos(y)
def f_rev(cos_y, g):
return (cos_y * g,)
f.defvjp(f_fwd, f_rev)
return f
@api.vmap
def g(x):
return outer(x)(3.)
ans = g(np.arange(3.))
expected = np.arange(3.) * 3
self.assertAllClose(ans, expected, check_dtypes=False)
def test_closed_over_tracer3(self):
def outer(x):
@api.custom_vjp
def f(y):
return x * y
def f_fwd(y):
return f(y), (x, jnp.cos(y))
def f_rev(res, g):
x, cos_y = res
return (cos_y * g * x,)
f.defvjp(f_fwd, f_rev)
return api.grad(f)
@api.vmap
def g(x):
return outer(x)(3.)
ans = g(np.arange(3.))
expected = np.cos(3.) * np.arange(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_tracer_error(self):
# This is similar to the old (now skipped) test_nondiff_arg_tracer, except
# we're testing for the error message that that usage pattern now raises.