mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Merge pull request #17295 from jakevdp:lax-pow-jvp
PiperOrigin-RevId: 560133324
This commit is contained in:
commit
c71eedf529
@ -1989,6 +1989,10 @@ def _pow_jvp_lhs(g, ans, x, y):
|
||||
y_dtype = dtypes.dtype(y)
|
||||
x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this
|
||||
if dtypes.issubdtype(y_dtype, np.integer):
|
||||
if x.shape != y.shape:
|
||||
shape = broadcast_shapes(x.shape, y.shape)
|
||||
x = _maybe_broadcast(shape, x)
|
||||
y = _maybe_broadcast(shape, y)
|
||||
jac = select(eq(y, _const(y, 0)), _ones(y),
|
||||
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
|
||||
else:
|
||||
|
@ -1140,6 +1140,14 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
jax.jacrev(f)(x)
|
||||
|
||||
def testPowShapeMismatch(self):
|
||||
# Regression test for https://github.com/google/jax/issues/17294
|
||||
x = lax.iota('float32', 4)
|
||||
y = 2
|
||||
actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error
|
||||
expected = jax.numpy.diag(y * x ** (y - 1))
|
||||
self.assertArraysEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user