remove an extraneous replace_float0s

This caused a test failure when trying to land #4008.
This commit is contained in:
Matthew Johnson 2020-10-15 16:18:43 -07:00
parent 7f4e115a6a
commit 8678287644
2 changed files with 14 additions and 3 deletions

View File

@ -323,9 +323,6 @@ class JVPTrace(Trace):
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
tangents_in = map(instantiate_zeros, tangents_in)
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
# currently handle float0s
tangents_in = replace_float0s(primals_in, tangents_in)
res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
out_tree, res_tree = out_trees()
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])

View File

@ -3571,6 +3571,20 @@ class CustomVJPTest(jtu.JaxTestCase):
self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y),
(2., np.zeros(shape=(), dtype=float0)))
def test_float0_bwd_none(self):
@api.custom_vjp
def f(i, x):
return jnp.sin(x)
def f_fwd(i, x):
return f(i, x), jnp.cos(x)
def f_rev(cos_x, g):
return (None, 2 * cos_x * g)
f.defvjp(f_fwd, f_rev)
ans = api.grad(f, 1)(jnp.array([1, 2]), 3.) # doesn't crash
expected = 2 * jnp.cos(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
class InvertibleADTest(jtu.JaxTestCase):