replace int with operator.index part2

This change align the behavior of `ravel_multi_index`, `split` and `indices` to their `numpy` counterparts.
Also ensure size argument of `nonzero` should be integer.
The changes with `*space` are only simplification
This commit is contained in:
YouJiacheng 2022-04-21 13:15:03 +08:00
parent 0ed29b63f0
commit 667d63aa2d
2 changed files with 24 additions and 11 deletions

View File

@ -25,6 +25,22 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
are not of an integer type, matching the behavior of
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
cast to integers.
* {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument
is not of an integer type, matching the behavior of
{func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently
cast to integers.
* {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument
is not of an integer type, matching the behavior of
{func}`numpy.split`. Previously non-integer `axis` was silently
cast to integers.
* {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions
are not of an integer type, matching the behavior of
{func}`numpy.indices`. Previously non-integer dimensions were silently
cast to integers.
* {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument
is not of an integer type, matching the behavior of
{func}`numpy.diag`. Previously non-integer `k` was silently
cast to integers.
* Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,

View File

@ -774,7 +774,7 @@ def ravel(a, order="C"):
@_wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims)
dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
_check_arraylike("ravel_multi_index", *multi_index)
for index in multi_index:
if mode == 'raise':
@ -1057,7 +1057,7 @@ The JAX version does not necessarily return a view of the input.
def _split(op, ary, indices_or_sections, axis=0):
_check_arraylike(op, ary)
ary = asarray(ary)
axis = core.concrete_or_error(int, axis, f"in jax.numpy.{op} argument `axis`")
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
size = ary.shape[axis]
if isinstance(indices_or_sections, (tuple, list)):
indices_or_sections = np.array(
@ -1216,7 +1216,7 @@ def nonzero(a, *, size=None, fill_value=None):
mask = a != 0
if size is None:
size = mask.sum()
size = core.concrete_or_error(int, size,
size = core.concrete_or_error(operator.index, size,
"The size argument of jnp.nonzero must be statically specified "
"to use jnp.nonzero within JAX transformations.")
if a.size == 0 or size == 0:
@ -2097,8 +2097,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis: int = 0):
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, int(num), endpoint, retstep, dtype,
operator.index(axis))
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
@ -2160,8 +2159,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
axis: int = 0):
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace")
return _logspace(start, stop, int(num), endpoint, base, dtype,
operator.index(axis))
return _logspace(start, stop, num, endpoint, base, dtype, axis)
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
@ -2184,8 +2182,7 @@ def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace")
return _geomspace(start, stop, int(num), endpoint, dtype,
operator.index(axis))
return _geomspace(start, stop, num, endpoint, dtype, axis)
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
@ -2267,7 +2264,7 @@ def ix_(*args):
@_wraps(np.indices)
def indices(dimensions, dtype=int32, sparse=False):
dimensions = tuple(
core.concrete_or_error(int, d, "dimensions argument of jnp.indices")
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
for d in dimensions)
N = len(dimensions)
output = []
@ -2486,7 +2483,7 @@ def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
@_wraps(np.diag, lax_description=_ARRAY_VIEW_DOC)
def diag(v, k=0):
return _diag(v, int(k))
return _diag(v, operator.index(k))
@partial(jit, static_argnames=('k',))
def _diag(v, k):