lax.pow: fix shape mismatch failure in jvp rule

This commit is contained in:
Jake VanderPlas 2023-08-25 10:05:55 -07:00
parent 3ea0a74fcc
commit 6cec5d4416
2 changed files with 12 additions and 0 deletions

View File

@ -1989,6 +1989,10 @@ def _pow_jvp_lhs(g, ans, x, y):
y_dtype = dtypes.dtype(y)
x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this
if dtypes.issubdtype(y_dtype, np.integer):
if x.shape != y.shape:
shape = broadcast_shapes(x.shape, y.shape)
x = _maybe_broadcast(shape, x)
y = _maybe_broadcast(shape, y)
jac = select(eq(y, _const(y, 0)), _ones(y),
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
else:

View File

@ -1140,6 +1140,14 @@ class LaxAutodiffTest(jtu.JaxTestCase):
with self.assertRaises(NotImplementedError):
jax.jacrev(f)(x)
def testPowShapeMismatch(self):
# Regression test for https://github.com/google/jax/issues/17294
x = lax.iota('float32', 4)
y = 2
actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error
expected = jax.numpy.diag(y * x ** (y - 1))
self.assertArraysEqual(actual, expected)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())