mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add BatchTrace.process_custom_vjp_call
It was an oversight not to include this! Notice we have BatchTrace.process_custom_jvp_call. In fact, we can use the same function! We just needed the simplest possible post-process-call which just peels and packages. fixes #5440 Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
34e798ff26
commit
0f704514e3
@ -259,6 +259,8 @@ class BatchTrace(Trace):
|
||||
out_dims = out_dims[-len(out_vals) % len(out_dims):]
|
||||
return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
|
||||
|
||||
post_process_custom_vjp_call = post_process_custom_jvp_call
|
||||
|
||||
def _main_trace_for_axis_names(main_trace: core.MainTrace,
|
||||
axis_name: Union[core.AxisName, Tuple[core.AxisName, ...]]
|
||||
) -> bool:
|
||||
|
@ -3866,6 +3866,47 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
expected = jnp.cos(3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_closed_over_tracer2(self):
|
||||
def outer(x):
|
||||
@api.custom_vjp
|
||||
def f(y):
|
||||
return x * y
|
||||
def f_fwd(y):
|
||||
return f(y), jnp.cos(y)
|
||||
def f_rev(cos_y, g):
|
||||
return (cos_y * g,)
|
||||
f.defvjp(f_fwd, f_rev)
|
||||
return f
|
||||
|
||||
@api.vmap
|
||||
def g(x):
|
||||
return outer(x)(3.)
|
||||
|
||||
ans = g(np.arange(3.))
|
||||
expected = np.arange(3.) * 3
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_closed_over_tracer3(self):
|
||||
def outer(x):
|
||||
@api.custom_vjp
|
||||
def f(y):
|
||||
return x * y
|
||||
def f_fwd(y):
|
||||
return f(y), (x, jnp.cos(y))
|
||||
def f_rev(res, g):
|
||||
x, cos_y = res
|
||||
return (cos_y * g * x,)
|
||||
f.defvjp(f_fwd, f_rev)
|
||||
return api.grad(f)
|
||||
|
||||
@api.vmap
|
||||
def g(x):
|
||||
return outer(x)(3.)
|
||||
|
||||
ans = g(np.arange(3.))
|
||||
expected = np.cos(3.) * np.arange(3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_nondiff_arg_tracer_error(self):
|
||||
# This is similar to the old (now skipped) test_nondiff_arg_tracer, except
|
||||
# we're testing for the error message that that usage pattern now raises.
|
||||
|
Loading…
x
Reference in New Issue
Block a user