Avoid creating float0 JVPTracers

This commit is contained in:
Dougal 2025-01-10 10:43:54 -05:00
parent c61b2f6b81
commit ba9b2ca5f6

View File

@ -511,7 +511,7 @@ class JVPTrace(Trace):
return map(partial(maybe_jvp_tracer, self), ps_out, ts_out)
def maybe_jvp_tracer(trace, primal, tangent):
if type(tangent) is Zero:
if type(tangent) is Zero or dtype(tangent) == float0:
return primal
else:
return JVPTracer(trace, primal, tangent)