mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Avoid creating float0 JVPTracers
This commit is contained in:
parent
c61b2f6b81
commit
ba9b2ca5f6
@ -511,7 +511,7 @@ class JVPTrace(Trace):
|
||||
return map(partial(maybe_jvp_tracer, self), ps_out, ts_out)
|
||||
|
||||
def maybe_jvp_tracer(trace, primal, tangent):
|
||||
if type(tangent) is Zero:
|
||||
if type(tangent) is Zero or dtype(tangent) == float0:
|
||||
return primal
|
||||
else:
|
||||
return JVPTracer(trace, primal, tangent)
|
||||
|
Loading…
x
Reference in New Issue
Block a user