mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #25831 from jax-ml:avoid-float0-tracers
PiperOrigin-RevId: 714058085
This commit is contained in:
commit
4f106b8a27
@ -511,7 +511,7 @@ class JVPTrace(Trace):
|
||||
return map(partial(maybe_jvp_tracer, self), ps_out, ts_out)
|
||||
|
||||
def maybe_jvp_tracer(trace, primal, tangent):
|
||||
if type(tangent) is Zero:
|
||||
if type(tangent) is Zero or dtype(tangent) == float0:
|
||||
return primal
|
||||
else:
|
||||
return JVPTracer(trace, primal, tangent)
|
||||
|
Loading…
x
Reference in New Issue
Block a user