jnp.take/jnp.take_along_axis: require array inputs

This commit is contained in:
Jake VanderPlas 2021-10-14 11:54:28 -07:00
parent 267e1ec5a4
commit a353e3eafa
2 changed files with 5 additions and 2 deletions

View File

@ -15,6 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:
* `jax.random.choice` and `jax.random.permutation` now support
multidimensional arrays and an optional `axis` argument ({jax-issue}`#8158`)
* Breaking changes:
* `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs
(see {jax-issue}`#7737`)
## jaxlib 0.1.73 (Unreleased)

View File

@ -5419,7 +5419,7 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
_check_arraylike("take", a)
_check_arraylike("take", a, indices)
a = asarray(a)
indices = asarray(indices)
@ -5474,7 +5474,7 @@ def _normalize_index(index, axis_size):
@_wraps(np.take_along_axis, update_doc=False)
@partial(jit, static_argnames=('axis',))
def take_along_axis(arr, indices, axis: Optional[int]):
_check_arraylike("take_along_axis", arr)
_check_arraylike("take_along_axis", arr, indices)
if axis is None:
if ndim(indices) != 1:
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"