mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jnp.unique: make return_inverse shape match NumPy 2.0
This commit is contained in:
parent
a20fc45d65
commit
fa6d3f26ff
@ -45,6 +45,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
`from jax.experimental import export`. The old way of importing will
|
`from jax.experimental import export`. The old way of importing will
|
||||||
continue to work for a deprecation period of 3 months.
|
continue to work for a deprecation period of 3 months.
|
||||||
* Added {func}`jax.scipy.stats.sem`.
|
* Added {func}`jax.scipy.stats.sem`.
|
||||||
|
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
|
||||||
|
reshaped to the dimension of the input, following a similar change to
|
||||||
|
{func}`numpy.unique` in NumPy 2.0.
|
||||||
|
|
||||||
* Deprecations & Removals
|
* Deprecations & Removals
|
||||||
* A number of previously deprecated functions have been removed, following a
|
* A number of previously deprecated functions have been removed, following a
|
||||||
|
@ -35,6 +35,7 @@ from jax._src.numpy.lax_numpy import (
|
|||||||
from jax._src.numpy.reductions import any, cumsum
|
from jax._src.numpy.reductions import any, cumsum
|
||||||
from jax._src.numpy.ufuncs import isnan
|
from jax._src.numpy.ufuncs import isnan
|
||||||
from jax._src.numpy.util import check_arraylike, _wraps
|
from jax._src.numpy.util import check_arraylike, _wraps
|
||||||
|
from jax._src.util import canonicalize_axis
|
||||||
from jax._src.typing import Array, ArrayLike
|
from jax._src.typing import Array, ArrayLike
|
||||||
|
|
||||||
|
|
||||||
@ -256,6 +257,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
|||||||
"""
|
"""
|
||||||
Find the unique elements of an array along a particular axis.
|
Find the unique elements of an array along a particular axis.
|
||||||
"""
|
"""
|
||||||
|
axis = canonicalize_axis(axis, ar.ndim)
|
||||||
|
|
||||||
if ar.shape[axis] == 0 and size and fill_value is None:
|
if ar.shape[axis] == 0 and size and fill_value is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
|
"jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
|
||||||
@ -289,6 +292,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
|||||||
inv_idx = inv_idx.at[perm].set(imask)
|
inv_idx = inv_idx.at[perm].set(imask)
|
||||||
else:
|
else:
|
||||||
inv_idx = zeros(ar.shape[axis], dtype=int)
|
inv_idx = zeros(ar.shape[axis], dtype=int)
|
||||||
|
if ar.ndim > 1:
|
||||||
|
inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],)
|
||||||
ret += (inv_idx,)
|
ret += (inv_idx,)
|
||||||
if return_counts:
|
if return_counts:
|
||||||
if aux.size:
|
if aux.size:
|
||||||
@ -332,12 +337,18 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
|||||||
size = core.concrete_or_error(operator.index, size,
|
size = core.concrete_or_error(operator.index, size,
|
||||||
"The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
|
"The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
|
||||||
arr = asarray(ar)
|
arr = asarray(ar)
|
||||||
|
arr_shape = arr.shape
|
||||||
if axis is None:
|
if axis is None:
|
||||||
axis = 0
|
axis_int: int = 0
|
||||||
arr = arr.flatten()
|
arr = arr.flatten()
|
||||||
axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
|
else:
|
||||||
return _unique(arr, axis_int, return_index, return_inverse,
|
axis_int = canonicalize_axis(axis, arr.ndim)
|
||||||
return_counts, equal_nan=equal_nan, size=size, fill_value=fill_value)
|
result = _unique(arr, axis_int, return_index, return_inverse, return_counts,
|
||||||
|
equal_nan=equal_nan, size=size, fill_value=fill_value)
|
||||||
|
if return_inverse and axis is None:
|
||||||
|
idx = 2 if return_index else 1
|
||||||
|
result = (*result[:idx], result[idx].reshape(arr_shape), *result[idx + 1:])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class _UniqueAllResult(NamedTuple):
|
class _UniqueAllResult(NamedTuple):
|
||||||
@ -362,7 +373,6 @@ def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
|
|||||||
check_arraylike("unique_all", x)
|
check_arraylike("unique_all", x)
|
||||||
values, indices, inverse_indices, counts = unique(
|
values, indices, inverse_indices, counts = unique(
|
||||||
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False)
|
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False)
|
||||||
inverse_indices = inverse_indices.reshape(np.shape(x))
|
|
||||||
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
|
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
|
||||||
|
|
||||||
|
|
||||||
@ -377,7 +387,6 @@ def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
|
|||||||
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
|
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
|
||||||
check_arraylike("unique_inverse", x)
|
check_arraylike("unique_inverse", x)
|
||||||
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
|
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
|
||||||
inverse_indices = inverse_indices.reshape(np.shape(x))
|
|
||||||
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
|
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1453,6 +1453,9 @@ def _unique_indices_unbatched(indices, *, shape, return_inverse=False,
|
|||||||
# TODO: check if `indices_sorted` is True.
|
# TODO: check if `indices_sorted` is True.
|
||||||
out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index,
|
out = _unique(indices, axis=0, return_inverse=return_inverse, return_index=return_index,
|
||||||
return_true_size=return_true_size, size=props.nse, fill_value=fill_value)
|
return_true_size=return_true_size, size=props.nse, fill_value=fill_value)
|
||||||
|
if return_inverse:
|
||||||
|
idx = 2 if return_index else 1
|
||||||
|
out = (*out[:idx], out[idx].ravel(), *out[idx + 1:])
|
||||||
if return_true_size:
|
if return_true_size:
|
||||||
nse = out[-1]
|
nse = out[-1]
|
||||||
nse = nse - (indices == fill_value).any().astype(nse.dtype)
|
nse = nse - (indices == fill_value).any().astype(nse.dtype)
|
||||||
|
@ -87,6 +87,23 @@ python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
|
|||||||
# uint64 is problematic because with any uint type it promotes to float:
|
# uint64 is problematic because with any uint type it promotes to float:
|
||||||
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
|
int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64]
|
||||||
|
|
||||||
|
def np_unique_backport(ar, return_index=False, return_inverse=False, return_counts=False,
|
||||||
|
axis=None, **kwds):
|
||||||
|
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
|
||||||
|
result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
|
||||||
|
return_counts=return_counts, axis=axis, **kwds)
|
||||||
|
if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse:
|
||||||
|
return result
|
||||||
|
|
||||||
|
idx = 2 if return_index else 1
|
||||||
|
inverse_indices = result[idx]
|
||||||
|
if axis is None:
|
||||||
|
inverse_indices = inverse_indices.reshape(np.shape(ar))
|
||||||
|
else:
|
||||||
|
inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis])
|
||||||
|
return (*result[:idx], inverse_indices, *result[idx + 1:])
|
||||||
|
|
||||||
|
|
||||||
def _indexer_with_default_outputs(indexer, use_defaults=True):
|
def _indexer_with_default_outputs(indexer, use_defaults=True):
|
||||||
"""Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs"""
|
"""Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs"""
|
||||||
class Indexer:
|
class Indexer:
|
||||||
@ -1818,7 +1835,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
extra_args = (return_index, return_inverse, return_counts)
|
extra_args = (return_index, return_inverse, return_counts)
|
||||||
use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False
|
use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False
|
||||||
np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults)
|
np_fun = jtu.with_jax_dtype_defaults(lambda x: np_unique_backport(x, *extra_args, axis=axis), use_defaults)
|
||||||
jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis)
|
jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis)
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
|
|
||||||
@ -1827,10 +1844,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
rng = jtu.rand_some_equal(self.rng())
|
rng = jtu.rand_some_equal(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
if jtu.numpy_version() < (2, 0, 0):
|
if jtu.numpy_version() < (2, 0, 0):
|
||||||
def np_fun(x):
|
np_fun = partial(np_unique_backport, return_index=True, return_inverse=True, return_counts=True)
|
||||||
values, indices, inverse_indices, counts = np.unique(
|
|
||||||
x, return_index=True, return_inverse=True, return_counts=True)
|
|
||||||
return values, indices, inverse_indices.reshape(np.shape(x)), counts
|
|
||||||
else:
|
else:
|
||||||
np_fun = np.unique_all
|
np_fun = np.unique_all
|
||||||
self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker)
|
self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker)
|
||||||
@ -1850,9 +1864,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
rng = jtu.rand_some_equal(self.rng())
|
rng = jtu.rand_some_equal(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
if jtu.numpy_version() < (2, 0, 0):
|
if jtu.numpy_version() < (2, 0, 0):
|
||||||
def np_fun(x):
|
np_fun = partial(np_unique_backport, return_inverse=True)
|
||||||
values, inverse_indices = np.unique(x, return_inverse=True)
|
|
||||||
return values, inverse_indices.reshape(np.shape(x))
|
|
||||||
else:
|
else:
|
||||||
np_fun = np.unique_inverse
|
np_fun = np.unique_inverse
|
||||||
self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker)
|
self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker)
|
||||||
@ -1888,7 +1900,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
|
|
||||||
@partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
|
@partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True))
|
||||||
def np_fun(x, fill_value=fill_value):
|
def np_fun(x, fill_value=fill_value):
|
||||||
u, ind, inv, counts = np.unique(x, **kwds)
|
u, ind, inv, counts = np_unique_backport(x, **kwds)
|
||||||
axis = kwds['axis']
|
axis = kwds['axis']
|
||||||
if axis is None:
|
if axis is None:
|
||||||
x = x.ravel()
|
x = x.ravel()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user