custom_vjp: automatically handle float0 cotangents

This commit is contained in:
Matthew Johnson 2021-08-17 16:18:57 -07:00
parent 0adbe563aa
commit b90daf9cda
2 changed files with 19 additions and 2 deletions

View File

@ -553,8 +553,10 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
"number of arguments to the primal function, but got VJP output "
"structure {} for primal input structure {}.")
raise TypeError(msg.format(in_tree2, in_tree)) from None
yield [zeros_like_aval(aval.at_least_vspace()) if ct is zero else ct
for aval, ct in zip(in_avals, cts_in_flat)]
# Ignore any None cotangents, and any corresponding to inputs for which the
# type doesn't equal the tangent type (i.e. float0s)
yield [zeros_like_aval(a.at_least_vspace()) if ct is zero or a != a.at_least_vspace()
else ct for a, ct in zip(in_avals, cts_in_flat)]
class CustomVJPCallPrimitive(core.CallPrimitive):

View File

@ -5351,6 +5351,21 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
self.assertAllClose(g_x, 17. * x, check_dtypes=False)
def test_float0_cotangents_automatically_handled(self):
@jax.custom_vjp
def f(x, y):
return x
def f_fwd(x, y):
return x, None
def f_bwd(_, zbar):
return (0., 1)
f.defvjp(f_fwd, f_bwd)
jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash
class CustomTransposeTest(jtu.JaxTestCase):