Merge pull request #19381 from jakevdp:fix-diff

PiperOrigin-RevId: 598884533
This commit is contained in:
jax authors 2024-01-16 10:35:05 -08:00
commit c0d51e7dde
2 changed files with 12 additions and 2 deletions

View File

@ -670,7 +670,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
combined: list[Array] = []
if prepend is not None:
util.check_arraylike("diff", prepend)
if isscalar(prepend):
if not ndim(prepend):
shape = list(arr.shape)
shape[axis] = 1
prepend = broadcast_to(prepend, tuple(shape))
@ -680,7 +680,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
if append is not None:
util.check_arraylike("diff", append)
if isscalar(append):
if not ndim(append):
shape = list(arr.shape)
shape[axis] = 1
append = broadcast_to(append, tuple(shape))

View File

@ -2744,6 +2744,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
def testDiffPrepoendScalar(self):
# Regression test for https://github.com/google/jax/issues/19362
x = jnp.arange(10)
result_jax = jnp.diff(x, prepend=x[0], append=x[-1])
x = np.array(x)
result_numpy = np.diff(x, prepend=x[0], append=x[-1])
self.assertArraysEqual(result_jax, result_numpy)
@jtu.sample_product(
op=["zeros", "ones"],
shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32),