Fix dtype canonicalization in jnp.indices.

`jnp.indices` was hard coded to default to `dtype = np.int32`, but it
should default to the canonicalized `np.int64`.

Fixes https://github.com/google/jax/issues/22501
This commit is contained in:
Dan Foreman-Mackey 2024-07-18 10:54:42 -04:00
parent d7b821b04d
commit 991187aaa8
3 changed files with 13 additions and 7 deletions

View File

@ -4517,17 +4517,19 @@ def ix_(*args: ArrayLike) -> tuple[Array, ...]:
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: Literal[False] = False) -> Array: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
*, sparse: Literal[True]) -> tuple[Array, ...]: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: bool = False) -> Array | tuple[Array, ...]: ...
@util.implements(np.indices)
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: bool = False) -> Array | tuple[Array, ...]:
dtypes.check_user_dtype_supported(dtype, "indices")
dtype = dtype or dtypes.canonicalize_dtype(int_)
dimensions = tuple(
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
for d in dimensions)

View File

@ -462,13 +462,13 @@ def imag(x: ArrayLike, /) -> Array: ...
index_exp = _np.index_exp
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: Literal[False] = False) -> Array: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
*, sparse: Literal[True]) -> tuple[Array, ...]: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
sparse: builtins.bool = False) -> Array | tuple[Array, ...]: ...
inexact = _np.inexact

View File

@ -4613,6 +4613,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testIndiciesDefaultDtype(self):
self.assertEqual(jnp.indices((2, 3)).dtype,
dtypes.canonicalize_dtype(np.int64))
@jtu.sample_product(
shape=nonzerodim_shapes,
dtype=all_dtypes,