mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix dtype canonicalization in jnp.indices
.
`jnp.indices` was hard coded to default to `dtype = np.int32`, but it should default to the canonicalized `np.int64`. Fixes https://github.com/google/jax/issues/22501
This commit is contained in:
parent
d7b821b04d
commit
991187aaa8
@ -4517,17 +4517,19 @@ def ix_(*args: ArrayLike) -> tuple[Array, ...]:
|
||||
|
||||
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
|
||||
sparse: Literal[False] = False) -> Array: ...
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
|
||||
*, sparse: Literal[True]) -> tuple[Array, ...]: ...
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
|
||||
sparse: bool = False) -> Array | tuple[Array, ...]: ...
|
||||
@util.implements(np.indices)
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
|
||||
sparse: bool = False) -> Array | tuple[Array, ...]:
|
||||
dtypes.check_user_dtype_supported(dtype, "indices")
|
||||
dtype = dtype or dtypes.canonicalize_dtype(int_)
|
||||
dimensions = tuple(
|
||||
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
|
||||
for d in dimensions)
|
||||
|
@ -462,13 +462,13 @@ def imag(x: ArrayLike, /) -> Array: ...
|
||||
index_exp = _np.index_exp
|
||||
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
|
||||
sparse: Literal[False] = False) -> Array: ...
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
|
||||
*, sparse: Literal[True]) -> tuple[Array, ...]: ...
|
||||
@overload
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
|
||||
def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None,
|
||||
sparse: builtins.bool = False) -> Array | tuple[Array, ...]: ...
|
||||
|
||||
inexact = _np.inexact
|
||||
|
@ -4613,6 +4613,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testIndiciesDefaultDtype(self):
|
||||
self.assertEqual(jnp.indices((2, 3)).dtype,
|
||||
dtypes.canonicalize_dtype(np.int64))
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=nonzerodim_shapes,
|
||||
dtype=all_dtypes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user