mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add brcast to deal with inconsistent shapes.
This commit is contained in:
parent
dfd3d93350
commit
bbf0d5c55e
@ -1442,8 +1442,8 @@ ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
||||
|
||||
atan2_p = standard_binop([_float, _float], 'atan2')
|
||||
ad.defjvp(atan2_p,
|
||||
lambda g, x, y: g * y / (square(x) + square(y)),
|
||||
lambda g, x, y: g * -x / (square(x) + square(y)))
|
||||
lambda g, x, y: _brcast(g, y) * (y / (square(x) + square(y))),
|
||||
lambda g, x, y: _brcast(g, x) * -x / (square(x) + square(y)))
|
||||
|
||||
lgamma_p = standard_unop(_float, 'lgamma')
|
||||
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
||||
|
Loading…
x
Reference in New Issue
Block a user