mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
replace int with operator.index part2
This change align the behavior of `ravel_multi_index`, `split` and `indices` to their `numpy` counterparts. Also ensure size argument of `nonzero` should be integer. The changes with `*space` are only simplification
This commit is contained in:
parent
0ed29b63f0
commit
667d63aa2d
16
CHANGELOG.md
16
CHANGELOG.md
@ -25,6 +25,22 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
are not of an integer type, matching the behavior of
|
||||
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
|
||||
cast to integers.
|
||||
* {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument
|
||||
is not of an integer type, matching the behavior of
|
||||
{func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently
|
||||
cast to integers.
|
||||
* {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument
|
||||
is not of an integer type, matching the behavior of
|
||||
{func}`numpy.split`. Previously non-integer `axis` was silently
|
||||
cast to integers.
|
||||
* {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions
|
||||
are not of an integer type, matching the behavior of
|
||||
{func}`numpy.indices`. Previously non-integer dimensions were silently
|
||||
cast to integers.
|
||||
* {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument
|
||||
is not of an integer type, matching the behavior of
|
||||
{func}`numpy.diag`. Previously non-integer `k` was silently
|
||||
cast to integers.
|
||||
* Deprecations
|
||||
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
|
||||
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
|
||||
|
@ -774,7 +774,7 @@ def ravel(a, order="C"):
|
||||
@_wraps(np.ravel_multi_index)
|
||||
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
|
||||
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
|
||||
dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims)
|
||||
dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
|
||||
_check_arraylike("ravel_multi_index", *multi_index)
|
||||
for index in multi_index:
|
||||
if mode == 'raise':
|
||||
@ -1057,7 +1057,7 @@ The JAX version does not necessarily return a view of the input.
|
||||
def _split(op, ary, indices_or_sections, axis=0):
|
||||
_check_arraylike(op, ary)
|
||||
ary = asarray(ary)
|
||||
axis = core.concrete_or_error(int, axis, f"in jax.numpy.{op} argument `axis`")
|
||||
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
|
||||
size = ary.shape[axis]
|
||||
if isinstance(indices_or_sections, (tuple, list)):
|
||||
indices_or_sections = np.array(
|
||||
@ -1216,7 +1216,7 @@ def nonzero(a, *, size=None, fill_value=None):
|
||||
mask = a != 0
|
||||
if size is None:
|
||||
size = mask.sum()
|
||||
size = core.concrete_or_error(int, size,
|
||||
size = core.concrete_or_error(operator.index, size,
|
||||
"The size argument of jnp.nonzero must be statically specified "
|
||||
"to use jnp.nonzero within JAX transformations.")
|
||||
if a.size == 0 or size == 0:
|
||||
@ -2097,8 +2097,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
|
||||
axis: int = 0):
|
||||
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
|
||||
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
|
||||
return _linspace(start, stop, int(num), endpoint, retstep, dtype,
|
||||
operator.index(axis))
|
||||
return _linspace(start, stop, num, endpoint, retstep, dtype, axis)
|
||||
|
||||
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
|
||||
def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
|
||||
@ -2160,8 +2159,7 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
|
||||
axis: int = 0):
|
||||
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace")
|
||||
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace")
|
||||
return _logspace(start, stop, int(num), endpoint, base, dtype,
|
||||
operator.index(axis))
|
||||
return _logspace(start, stop, num, endpoint, base, dtype, axis)
|
||||
|
||||
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
|
||||
def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
|
||||
@ -2184,8 +2182,7 @@ def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
|
||||
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
|
||||
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace")
|
||||
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace")
|
||||
return _geomspace(start, stop, int(num), endpoint, dtype,
|
||||
operator.index(axis))
|
||||
return _geomspace(start, stop, num, endpoint, dtype, axis)
|
||||
|
||||
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
|
||||
def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
|
||||
@ -2267,7 +2264,7 @@ def ix_(*args):
|
||||
@_wraps(np.indices)
|
||||
def indices(dimensions, dtype=int32, sparse=False):
|
||||
dimensions = tuple(
|
||||
core.concrete_or_error(int, d, "dimensions argument of jnp.indices")
|
||||
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
|
||||
for d in dimensions)
|
||||
N = len(dimensions)
|
||||
output = []
|
||||
@ -2486,7 +2483,7 @@ def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
|
||||
|
||||
@_wraps(np.diag, lax_description=_ARRAY_VIEW_DOC)
|
||||
def diag(v, k=0):
|
||||
return _diag(v, int(k))
|
||||
return _diag(v, operator.index(k))
|
||||
|
||||
@partial(jit, static_argnames=('k',))
|
||||
def _diag(v, k):
|
||||
|
Loading…
x
Reference in New Issue
Block a user