mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Temporarily disable integer index check in jnp.take_along_axis.
This check broke some JAX users; disable it to give time to fix them. PiperOrigin-RevId: 441993808
This commit is contained in:
parent
375777f43c
commit
0c1021ad4b
@ -3425,10 +3425,11 @@ def _normalize_index(index, axis_size):
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def take_along_axis(arr, indices, axis: Optional[int]):
|
||||
_check_arraylike("take_along_axis", arr, indices)
|
||||
index_dtype = dtypes.dtype(indices)
|
||||
if not dtypes.issubdtype(index_dtype, integer):
|
||||
raise TypeError("take_along_axis indices must be of integer type, got "
|
||||
f"{str(index_dtype)}")
|
||||
# index_dtype = dtypes.dtype(indices)
|
||||
# TODO(phawkins): reenalbe this check after fixing callers
|
||||
# if not dtypes.issubdtype(index_dtype, integer):
|
||||
# raise TypeError("take_along_axis indices must be of integer type, got "
|
||||
# f"{str(index_dtype)}")
|
||||
if axis is None:
|
||||
if ndim(indices) != 1:
|
||||
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user