From 667d63aa2d4fbf7c9da73aab0e24c5c4c33cb5ba Mon Sep 17 00:00:00 2001 From: YouJiacheng <83971976+YouJiacheng@users.noreply.github.com> Date: Thu, 21 Apr 2022 13:15:03 +0800 Subject: [PATCH] replace int with operator.index part2 This change align the behavior of `ravel_multi_index`, `split` and `indices` to their `numpy` counterparts. Also ensure size argument of `nonzero` should be integer. The changes with `*space` are only simplification --- CHANGELOG.md | 16 ++++++++++++++++ jax/_src/numpy/lax_numpy.py | 19 ++++++++----------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9897b7c8d..65e965b2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,22 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. are not of an integer type, matching the behavior of {func}`numpy.take_along_axis`. Previously non-integer indices were silently cast to integers. + * {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument + is not of an integer type, matching the behavior of + {func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently + cast to integers. + * {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument + is not of an integer type, matching the behavior of + {func}`numpy.split`. Previously non-integer `axis` was silently + cast to integers. + * {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions + are not of an integer type, matching the behavior of + {func}`numpy.indices`. Previously non-integer dimensions were silently + cast to integers. + * {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument + is not of an integer type, matching the behavior of + {func}`numpy.diag`. Previously non-integer `k` was silently + cast to integers. * Deprecations * Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7712c8fa9..91e928075 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -774,7 +774,7 @@ def ravel(a, order="C"): @_wraps(np.ravel_multi_index) def ravel_multi_index(multi_index, dims, mode='raise', order='C'): assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" - dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims) + dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims) _check_arraylike("ravel_multi_index", *multi_index) for index in multi_index: if mode == 'raise': @@ -1057,7 +1057,7 @@ The JAX version does not necessarily return a view of the input. def _split(op, ary, indices_or_sections, axis=0): _check_arraylike(op, ary) ary = asarray(ary) - axis = core.concrete_or_error(int, axis, f"in jax.numpy.{op} argument `axis`") + axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`") size = ary.shape[axis] if isinstance(indices_or_sections, (tuple, list)): indices_or_sections = np.array( @@ -1216,7 +1216,7 @@ def nonzero(a, *, size=None, fill_value=None): mask = a != 0 if size is None: size = mask.sum() - size = core.concrete_or_error(int, size, + size = core.concrete_or_error(operator.index, size, "The size argument of jnp.nonzero must be statically specified " "to use jnp.nonzero within JAX transformations.") if a.size == 0 or size == 0: @@ -2097,8 +2097,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis: int = 0): num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") - return _linspace(start, stop, int(num), endpoint, retstep, dtype, - operator.index(axis)) + return _linspace(start, stop, num, endpoint, retstep, dtype, axis) @partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis')) def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, @@ -2160,8 +2159,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis: int = 0): num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") - return _logspace(start, stop, int(num), endpoint, base, dtype, - operator.index(axis)) + return _logspace(start, stop, num, endpoint, base, dtype, axis) @partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, @@ -2184,8 +2182,7 @@ def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0): num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") - return _geomspace(start, stop, int(num), endpoint, dtype, - operator.index(axis)) + return _geomspace(start, stop, num, endpoint, dtype, axis) @partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0): @@ -2267,7 +2264,7 @@ def ix_(*args): @_wraps(np.indices) def indices(dimensions, dtype=int32, sparse=False): dimensions = tuple( - core.concrete_or_error(int, d, "dimensions argument of jnp.indices") + core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices") for d in dimensions) N = len(dimensions) output = [] @@ -2486,7 +2483,7 @@ def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1): @_wraps(np.diag, lax_description=_ARRAY_VIEW_DOC) def diag(v, k=0): - return _diag(v, int(k)) + return _diag(v, operator.index(k)) @partial(jit, static_argnames=('k',)) def _diag(v, k):