Temporarily disable integer index check in jnp.take_along_axis.

This check broke some JAX users; disable it to give time to fix them.

PiperOrigin-RevId: 441993808
This commit is contained in:
Peter Hawkins 2022-04-15 05:44:49 -07:00 committed by jax authors
parent 375777f43c
commit 0c1021ad4b

View File

@ -3425,10 +3425,11 @@ def _normalize_index(index, axis_size):
@partial(jit, static_argnames=('axis',))
def take_along_axis(arr, indices, axis: Optional[int]):
_check_arraylike("take_along_axis", arr, indices)
index_dtype = dtypes.dtype(indices)
if not dtypes.issubdtype(index_dtype, integer):
raise TypeError("take_along_axis indices must be of integer type, got "
f"{str(index_dtype)}")
# index_dtype = dtypes.dtype(indices)
# TODO(phawkins): reenalbe this check after fixing callers
# if not dtypes.issubdtype(index_dtype, integer):
# raise TypeError("take_along_axis indices must be of integer type, got "
# f"{str(index_dtype)}")
if axis is None:
if ndim(indices) != 1:
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"