jnp.unique: make return_inverse shape match NumPy 2.0

Jake VanderPlas 2024-01-11 12:15:52 -08:00
parent a20fc45d65
commit fa6d3f26ff
4 changed files with 42 additions and 15 deletions


@@ -45,6 +45,9 @@ Remember to align the itemized text with the first line of an item within a list
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Added {func}`jax.scipy.stats.sem`.
+* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
+reshaped to the dimensions of the input, following a similar change to
+{func}`numpy.unique` in NumPy 2.0.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
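As a quick illustration of the new behavior described by the changelog entry above (a minimal sketch, not part of the diff):

import jax.numpy as jnp

x = jnp.array([[1, 2, 1],
               [2, 3, 1]])
vals, inv = jnp.unique(x, return_inverse=True)
# With this change, `inv` has x's shape (2, 3) rather than the flat shape (6,),
# so `vals[inv]` reconstructs x directly, matching NumPy 2.0.
print(inv.shape)                     # (2, 3)
print(bool((vals[inv] == x).all()))  # True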


@@ -35,6 +35,7 @@ from jax._src.numpy.lax_numpy import (
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import check_arraylike, _wraps
+from jax._src.util import canonicalize_axis
from jax._src.typing import Array, ArrayLike
@@ -256,6 +257,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
"""
Find the unique elements of an array along a particular axis.
"""
+axis = canonicalize_axis(axis, ar.ndim)
if ar.shape[axis] == 0 and size and fill_value is None:
raise ValueError(
"jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
@@ -289,6 +292,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
inv_idx = inv_idx.at[perm].set(imask)
else:
inv_idx = zeros(ar.shape[axis], dtype=int)
+if ar.ndim > 1:
+inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis])
ret += (inv_idx,)
if return_counts:
if aux.size:
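To make the expand_dims above concrete (a sketch, not part of the diff): for a 3-D input and axis=1, the 1-D inverse gains size-1 dimensions on every axis except the chosen one, so it broadcasts against the input along that axis only.

from jax import lax
import jax.numpy as jnp

inv = jnp.arange(4)                 # 1-D inverse for an input with shape[1] == 4
inv = lax.expand_dims(inv, [0, 2])  # every axis except axis=1 of a 3-D input
print(inv.shape)                    # (1, 4, 1)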
@@ -332,12 +337,18 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
size = core.concrete_or_error(operator.index, size,
"The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
arr = asarray(ar)
+arr_shape = arr.shape
if axis is None:
-axis = 0
+axis_int: int = 0
arr = arr.flatten()
-axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
-return _unique(arr, axis_int, return_index, return_inverse,
-return_counts, equal_nan=equal_nan, size=size, fill_value=fill_value)
+else:
+axis_int = canonicalize_axis(axis, arr.ndim)
+result = _unique(arr, axis_int, return_index, return_inverse, return_counts,
+equal_nan=equal_nan, size=size, fill_value=fill_value)
+if return_inverse and axis is None:
+idx = 2 if return_index else 1
+result = (*result[:idx], result[idx].reshape(arr_shape), *result[idx + 1:])
+return result
class _UniqueAllResult(NamedTuple):
@@ -362,7 +373,6 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
check_arraylike("unique_all", x)
values, indices, inverse_indices, counts = unique(
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False)
-inverse_indices = inverse_indices.reshape(np.shape(x))
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
@@ -377,7 +387,6 @@ def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
check_arraylike("unique_inverse", x)
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
-inverse_indices = inverse_indices.reshape(np.shape(x))
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
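Taken together, the setops.py changes above implement the following reshaping rule for the inverse indices. A minimal NumPy-only sketch (not part of the diff), where `inverse_flat` stands for the 1-D inverse that `_unique` computes internally and `axis` is assumed non-negative:

import numpy as np

def reshape_inverse(x, inverse_flat, axis=None):
  if axis is None:
    # The input was flattened, so give the inverse the input's full shape.
    return inverse_flat.reshape(np.shape(x))
  # Axis case: keep the length along `axis` and add size-1 dimensions elsewhere,
  # so the inverse broadcasts against the input along that axis.
  return np.expand_dims(inverse_flat, [i for i in range(np.ndim(x)) if i != axis])

The same rule is mirrored by the `np_unique_backport` test helper added below.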


@@ -1453,6 +1453,9 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
# TODO: check if `indices_sorted` is True.
out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index,
return_true_size=return_true_size, size=props.nse, fill_value=fill_value)
+if return_inverse:
+idx = 2 if return_index else 1
+out = (*out[:idx], out[idx].ravel(), *out[idx + 1:])
if return_true_size:
nse = out[-1]
nse = nse - (indices == fill_value).any().astype(nse.dtype)
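The reason for the `.ravel()` above: the sparse code calls `_unique` with `axis=0` on the 2-D BCOO index array (one row per stored element), so after this change the inverse comes back with an extra size-1 trailing dimension, while the BCOO machinery only needs the flat per-row mapping. A tiny sketch (not part of the diff; the shapes are illustrative):

import jax.numpy as jnp

inv = jnp.zeros((5, 1), dtype=int)  # inverse as now returned for indices with 5 rows
inv_flat = inv.ravel()              # shape (5,): one entry per stored element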


@@ -87,6 +87,23 @@ python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
# uint64 is problematic because with any uint type it promotes to float:
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
+def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
+axis=None, **kwds):
+# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
+result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
+return_counts=return_counts, axis=axis, **kwds)
+if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse:
+return result
+idx = 2 if return_index else 1
+inverse_indices = result[idx]
+if axis is None:
+inverse_indices = inverse_indices.reshape(np.shape(ar))
+else:
+inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis])
+return (*result[:idx], inverse_indices, *result[idx + 1:])
def _indexer_with_default_outputs(indexer, use_defaults=True):
"""Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs"""
class Indexer:
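A quick usage sketch (not part of the diff) of the `np_unique_backport` helper introduced in the hunk above:

import numpy as np

x = np.array([[3, 1],
              [1, 3]])
vals, inv = np_unique_backport(x, return_inverse=True)
# On NumPy < 2.0 the helper reshapes the flat inverse to x.shape, so the result
# matches NumPy 2.0 and `vals[inv]` reconstructs x with its original (2, 2) shape.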
@@ -1818,7 +1835,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
args_maker = lambda: [rng(shape, dtype)]
extra_args = (return_index, return_inverse, return_counts)
use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False
-np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults)
+np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults)
jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@@ -1827,10 +1844,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
-def np_fun(x):
-values, indices, inverse_indices, counts = np.unique(
-x, return_index=True, return_inverse=True, return_counts=True)
-return values, indices, inverse_indices.reshape(np.shape(x)), counts
+np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True)
else:
np_fun = np.unique_all
self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker)
@@ -1850,9 +1864,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
-def np_fun(x):
-values, inverse_indices = np.unique(x, return_inverse=True)
-return values, inverse_indices.reshape(np.shape(x))
+np_fun = partial(np_unique_backport, return_inverse=True)
else:
np_fun = np.unique_inverse
self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker)
@@ -1888,7 +1900,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
def np_fun(x, fill_value=fill_value):
-u, ind, inv, counts = np.unique(x, **kwds)
+u, ind, inv, counts = np_unique_backport(x, **kwds)
axis = kwds['axis']
if axis is None:
x = x.ravel()