diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 6c4a0744c..85ade5e02 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -511,7 +511,7 @@ class JVPTrace(Trace): return map(partial(maybe_jvp_tracer, self), ps_out, ts_out) def maybe_jvp_tracer(trace, primal, tangent): - if type(tangent) is Zero: + if type(tangent) is Zero or dtype(tangent) == float0: return primal else: return JVPTracer(trace, primal, tangent)