mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
f08c3b746f
commit
1171626b99
@ -1363,7 +1363,7 @@ ad.primitive_transposes[add_p] = _add_transpose
|
||||
|
||||
def _sub_transpose(t, x, y):
|
||||
assert x is None and y is None # computation must be linear, not affine
|
||||
return [t, neg(t)]
|
||||
return [t, neg(t) if t is not ad_util.zero else ad_util.zero]
|
||||
|
||||
sub_p = standard_binop([_num, _num], 'sub')
|
||||
ad.defjvp(sub_p,
|
||||
|
Loading…
x
Reference in New Issue
Block a user