From 17f5658db8078a540a7475119656f02b6413880d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 16 Jan 2024 08:46:44 -0800 Subject: [PATCH] jnp.diff: support scalar prepend/append --- jax/_src/numpy/lax_numpy.py | 4 ++-- tests/lax_numpy_test.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 29479cb60..4b1c40e3f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -670,7 +670,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, combined: list[Array] = [] if prepend is not None: util.check_arraylike("diff", prepend) - if isscalar(prepend): + if not ndim(prepend): shape = list(arr.shape) shape[axis] = 1 prepend = broadcast_to(prepend, tuple(shape)) @@ -680,7 +680,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, if append is not None: util.check_arraylike("diff", append) - if isscalar(append): + if not ndim(append): shape = list(arr.shape) shape[axis] = 1 append = broadcast_to(append, tuple(shape)) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7f8ad632e..29db81e3d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2744,6 +2744,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) + def testDiffPrepoendScalar(self): + # Regression test for https://github.com/google/jax/issues/19362 + x = jnp.arange(10) + result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) + + x = np.array(x) + result_numpy = np.diff(x, prepend=x[0], append=x[-1]) + + self.assertArraysEqual(result_jax, result_numpy) + @jtu.sample_product( op=["zeros", "ones"], shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32),