mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
remove an extraneous replace_float0s
This caused a test failure when trying to land #4008.
This commit is contained in:
parent
7f4e115a6a
commit
8678287644
@ -323,9 +323,6 @@ class JVPTrace(Trace):
|
||||
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
|
||||
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
|
||||
tangents_in = map(instantiate_zeros, tangents_in)
|
||||
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
|
||||
# currently handle float0s
|
||||
tangents_in = replace_float0s(primals_in, tangents_in)
|
||||
res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
|
||||
out_tree, res_tree = out_trees()
|
||||
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
|
||||
|
@ -3571,6 +3571,20 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y),
|
||||
(2., np.zeros(shape=(), dtype=float0)))
|
||||
|
||||
def test_float0_bwd_none(self):
|
||||
@api.custom_vjp
|
||||
def f(i, x):
|
||||
return jnp.sin(x)
|
||||
def f_fwd(i, x):
|
||||
return f(i, x), jnp.cos(x)
|
||||
def f_rev(cos_x, g):
|
||||
return (None, 2 * cos_x * g)
|
||||
f.defvjp(f_fwd, f_rev)
|
||||
|
||||
ans = api.grad(f, 1)(jnp.array([1, 2]), 3.) # doesn't crash
|
||||
expected = 2 * jnp.cos(3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
|
||||
class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user