mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
custom_vjp: automatically handle float0 cotangents
This commit is contained in:
parent
0adbe563aa
commit
b90daf9cda
@ -553,8 +553,10 @@ def _flatten_bwd(in_tree, in_avals, out_trees, *args):
|
||||
"number of arguments to the primal function, but got VJP output "
|
||||
"structure {} for primal input structure {}.")
|
||||
raise TypeError(msg.format(in_tree2, in_tree)) from None
|
||||
yield [zeros_like_aval(aval.at_least_vspace()) if ct is zero else ct
|
||||
for aval, ct in zip(in_avals, cts_in_flat)]
|
||||
# Ignore any None cotangents, and any corresponding to inputs for which the
|
||||
# type doesn't equal the tangent type (i.e. float0s)
|
||||
yield [zeros_like_aval(a.at_least_vspace()) if ct is zero or a != a.at_least_vspace()
|
||||
else ct for a, ct in zip(in_avals, cts_in_flat)]
|
||||
|
||||
|
||||
class CustomVJPCallPrimitive(core.CallPrimitive):
|
||||
|
@ -5351,6 +5351,21 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
|
||||
self.assertAllClose(g_x, 17. * x, check_dtypes=False)
|
||||
|
||||
def test_float0_cotangents_automatically_handled(self):
|
||||
@jax.custom_vjp
|
||||
def f(x, y):
|
||||
return x
|
||||
|
||||
def f_fwd(x, y):
|
||||
return x, None
|
||||
|
||||
def f_bwd(_, zbar):
|
||||
return (0., 1)
|
||||
|
||||
f.defvjp(f_fwd, f_bwd)
|
||||
|
||||
jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash
|
||||
|
||||
|
||||
class CustomTransposeTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user