mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jnp.unique: make return_inverse shape match NumPy 2.0
This commit is contained in:
parent
a20fc45d65
commit
fa6d3f26ff
@ -45,6 +45,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
`from jax.experimental import export`. The old way of importing will
|
||||
continue to work for a deprecation period of 3 months.
|
||||
* Added {func}`jax.scipy.stats.sem`.
|
||||
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
|
||||
reshaped to the dimension of the input, following a similar change to
|
||||
{func}`numpy.unique` in NumPy 2.0.
|
||||
|
||||
* Deprecations & Removals
|
||||
* A number of previously deprecated functions have been removed, following a
|
||||
|
@ -35,6 +35,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
from jax._src.numpy.reductions import any, cumsum
|
||||
from jax._src.numpy.ufuncs import isnan
|
||||
from jax._src.numpy.util import check_arraylike, _wraps
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@ -256,6 +257,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
||||
"""
|
||||
Find the unique elements of an array along a particular axis.
|
||||
"""
|
||||
axis = canonicalize_axis(axis, ar.ndim)
|
||||
|
||||
if ar.shape[axis] == 0 and size and fill_value is None:
|
||||
raise ValueError(
|
||||
"jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
|
||||
@ -289,6 +292,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
||||
inv_idx = inv_idx.at[perm].set(imask)
|
||||
else:
|
||||
inv_idx = zeros(ar.shape[axis], dtype=int)
|
||||
if ar.ndim > 1:
|
||||
inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],)
|
||||
ret += (inv_idx,)
|
||||
if return_counts:
|
||||
if aux.size:
|
||||
@ -332,12 +337,18 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
size = core.concrete_or_error(operator.index, size,
|
||||
"The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
|
||||
arr = asarray(ar)
|
||||
arr_shape = arr.shape
|
||||
if axis is None:
|
||||
axis = 0
|
||||
axis_int: int = 0
|
||||
arr = arr.flatten()
|
||||
axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
|
||||
return _unique(arr, axis_int, return_index, return_inverse,
|
||||
return_counts, equal_nan=equal_nan, size=size, fill_value=fill_value)
|
||||
else:
|
||||
axis_int = canonicalize_axis(axis, arr.ndim)
|
||||
result = _unique(arr, axis_int, return_index, return_inverse, return_counts,
|
||||
equal_nan=equal_nan, size=size, fill_value=fill_value)
|
||||
if return_inverse and axis is None:
|
||||
idx = 2 if return_index else 1
|
||||
result = (*result[:idx], result[idx].reshape(arr_shape), *result[idx + 1:])
|
||||
return result
|
||||
|
||||
|
||||
class _UniqueAllResult(NamedTuple):
|
||||
@ -362,7 +373,6 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
|
||||
check_arraylike("unique_all", x)
|
||||
values, indices, inverse_indices, counts = unique(
|
||||
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False)
|
||||
inverse_indices = inverse_indices.reshape(np.shape(x))
|
||||
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
|
||||
|
||||
|
||||
@ -377,7 +387,6 @@ def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
|
||||
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
|
||||
check_arraylike("unique_inverse", x)
|
||||
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
|
||||
inverse_indices = inverse_indices.reshape(np.shape(x))
|
||||
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
|
||||
|
||||
|
||||
|
@ -1453,6 +1453,9 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
|
||||
# TODO: check if `indices_sorted` is True.
|
||||
out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index,
|
||||
return_true_size=return_true_size, size=props.nse, fill_value=fill_value)
|
||||
if return_inverse:
|
||||
idx = 2 if return_index else 1
|
||||
out = (*out[:idx], out[idx].ravel(), *out[idx + 1:])
|
||||
if return_true_size:
|
||||
nse = out[-1]
|
||||
nse = nse - (indices == fill_value).any().astype(nse.dtype)
|
||||
|
@ -87,6 +87,23 @@ python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
|
||||
# uint64 is problematic because with any uint type it promotes to float:
|
||||
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
|
||||
|
||||
def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
|
||||
axis=None, **kwds):
|
||||
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
|
||||
result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
|
||||
return_counts=return_counts, axis=axis, **kwds)
|
||||
if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse:
|
||||
return result
|
||||
|
||||
idx = 2 if return_index else 1
|
||||
inverse_indices = result[idx]
|
||||
if axis is None:
|
||||
inverse_indices = inverse_indices.reshape(np.shape(ar))
|
||||
else:
|
||||
inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis])
|
||||
return (*result[:idx], inverse_indices, *result[idx + 1:])
|
||||
|
||||
|
||||
def _indexer_with_default_outputs(indexer, use_defaults=True):
|
||||
"""Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs"""
|
||||
class Indexer:
|
||||
@ -1818,7 +1835,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
extra_args = (return_index, return_inverse, return_counts)
|
||||
use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False
|
||||
np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults)
|
||||
np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults)
|
||||
jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@ -1827,10 +1844,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_some_equal(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
if jtu.numpy_version() < (2, 0, 0):
|
||||
def np_fun(x):
|
||||
values, indices, inverse_indices, counts = np.unique(
|
||||
x, return_index=True, return_inverse=True, return_counts=True)
|
||||
return values, indices, inverse_indices.reshape(np.shape(x)), counts
|
||||
np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True)
|
||||
else:
|
||||
np_fun = np.unique_all
|
||||
self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker)
|
||||
@ -1850,9 +1864,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_some_equal(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
if jtu.numpy_version() < (2, 0, 0):
|
||||
def np_fun(x):
|
||||
values, inverse_indices = np.unique(x, return_inverse=True)
|
||||
return values, inverse_indices.reshape(np.shape(x))
|
||||
np_fun = partial(np_unique_backport, return_inverse=True)
|
||||
else:
|
||||
np_fun = np.unique_inverse
|
||||
self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker)
|
||||
@ -1888,7 +1900,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
|
||||
def np_fun(x, fill_value=fill_value):
|
||||
u, ind, inv, counts = np.unique(x, **kwds)
|
||||
u, ind, inv, counts = np_unique_backport(x, **kwds)
|
||||
axis = kwds['axis']
|
||||
if axis is None:
|
||||
x = x.ravel()
|
||||
|
Loading…
x
Reference in New Issue
Block a user