From 6cec5d4416ad8b757b493b34a6463aed7a03a4b7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 25 Aug 2023 10:05:55 -0700 Subject: [PATCH] lax.pow: fix shape mismatch failure in jvp rule --- jax/_src/lax/lax.py | 4 ++++ tests/lax_autodiff_test.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 78d52cef4..3dc7566ad 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1989,6 +1989,10 @@ def _pow_jvp_lhs(g, ans, x, y): y_dtype = dtypes.dtype(y) x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this if dtypes.issubdtype(y_dtype, np.integer): + if x.shape != y.shape: + shape = broadcast_shapes(x.shape, y.shape) + x = _maybe_broadcast(shape, x) + y = _maybe_broadcast(shape, y) jac = select(eq(y, _const(y, 0)), _ones(y), mul(_replace_zero(y), pow(x, sub(y, _ones(y))))) else: diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index faa22afb1..45aac1988 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -1140,6 +1140,14 @@ class LaxAutodiffTest(jtu.JaxTestCase): with self.assertRaises(NotImplementedError): jax.jacrev(f)(x) + def testPowShapeMismatch(self): + # Regression test for https://github.com/google/jax/issues/17294 + x = lax.iota('float32', 4) + y = 2 + actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error + expected = jax.numpy.diag(y * x ** (y - 1)) + self.assertArraysEqual(actual, expected) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())