mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge pull request #19381 from jakevdp:fix-diff
PiperOrigin-RevId: 598884533
This commit is contained in:
commit
c0d51e7dde
@ -670,7 +670,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
|
||||
combined: list[Array] = []
|
||||
if prepend is not None:
|
||||
util.check_arraylike("diff", prepend)
|
||||
if isscalar(prepend):
|
||||
if not ndim(prepend):
|
||||
shape = list(arr.shape)
|
||||
shape[axis] = 1
|
||||
prepend = broadcast_to(prepend, tuple(shape))
|
||||
@ -680,7 +680,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
|
||||
|
||||
if append is not None:
|
||||
util.check_arraylike("diff", append)
|
||||
if isscalar(append):
|
||||
if not ndim(append):
|
||||
shape = list(arr.shape)
|
||||
shape[axis] = 1
|
||||
append = broadcast_to(append, tuple(shape))
|
||||
|
@ -2744,6 +2744,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testDiffPrepoendScalar(self):
|
||||
# Regression test for https://github.com/google/jax/issues/19362
|
||||
x = jnp.arange(10)
|
||||
result_jax = jnp.diff(x, prepend=x[0], append=x[-1])
|
||||
|
||||
x = np.array(x)
|
||||
result_numpy = np.diff(x, prepend=x[0], append=x[-1])
|
||||
|
||||
self.assertArraysEqual(result_jax, result_numpy)
|
||||
|
||||
@jtu.sample_product(
|
||||
op=["zeros", "ones"],
|
||||
shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32),
|
||||
|
Loading…
x
Reference in New Issue
Block a user