From a353e3eafa3a9b0449dd99115e6f7ec5ae4f63e7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 14 Oct 2021 11:54:28 -0700 Subject: [PATCH] jnp.take/jnp.take_along_axis: require array inputs --- CHANGELOG.md | 3 +++ jax/_src/numpy/lax_numpy.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8605b0e9..2e4c60d8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * New features: * `jax.random.choice` and `jax.random.permutation` now support multidimensional arrays and an optional `axis` argument ({jax-issue}`#8158`) +* Breaking changes: + * `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs + (see {jax-issue}`#7737`) ## jaxlib 0.1.73 (Unreleased) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 326d58f33..0e16e64aa 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5419,7 +5419,7 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None): def _take(a, indices, axis: Optional[int] = None, out=None, mode=None): if out is not None: raise NotImplementedError("The 'out' argument to jnp.take is not supported.") - _check_arraylike("take", a) + _check_arraylike("take", a, indices) a = asarray(a) indices = asarray(indices) @@ -5474,7 +5474,7 @@ def _normalize_index(index, axis_size): @_wraps(np.take_along_axis, update_doc=False) @partial(jit, static_argnames=('axis',)) def take_along_axis(arr, indices, axis: Optional[int]): - _check_arraylike("take_along_axis", arr) + _check_arraylike("take_along_axis", arr, indices) if axis is None: if ndim(indices) != 1: msg = "take_along_axis indices must be 1D if axis=None, got shape {}"