mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jnp.take/jnp.take_along_axis: require array inputs
This commit is contained in:
parent
267e1ec5a4
commit
a353e3eafa
@ -15,6 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* New features:
|
||||
* `jax.random.choice` and `jax.random.permutation` now support
|
||||
multidimensional arrays and an optional `axis` argument ({jax-issue}`#8158`)
|
||||
* Breaking changes:
|
||||
* `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs
|
||||
(see {jax-issue}`#7737`)
|
||||
|
||||
## jaxlib 0.1.73 (Unreleased)
|
||||
|
||||
|
@ -5419,7 +5419,7 @@ def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
def _take(a, indices, axis: Optional[int] = None, out=None, mode=None):
|
||||
if out is not None:
|
||||
raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
|
||||
_check_arraylike("take", a)
|
||||
_check_arraylike("take", a, indices)
|
||||
a = asarray(a)
|
||||
indices = asarray(indices)
|
||||
|
||||
@ -5474,7 +5474,7 @@ def _normalize_index(index, axis_size):
|
||||
@_wraps(np.take_along_axis, update_doc=False)
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def take_along_axis(arr, indices, axis: Optional[int]):
|
||||
_check_arraylike("take_along_axis", arr)
|
||||
_check_arraylike("take_along_axis", arr, indices)
|
||||
if axis is None:
|
||||
if ndim(indices) != 1:
|
||||
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user