mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
jnp.unique: add support for axis argument
This commit is contained in:
parent
90d606fe25
commit
bb543f2b5b
@ -15,7 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
keyword arguments. A new `static_argnames` option has been added to specify
|
||||
keyword arguments as static.
|
||||
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
|
||||
be used within `jit` ({jax-issue}`6501`)
|
||||
be used within `jit` ({jax-issue}`#6501`)
|
||||
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
|
||||
* Breaking changes:
|
||||
* The following function names have changed. There are still aliases, so this
|
||||
should not break existing code, but the aliases will eventually be removed
|
||||
|
@ -4513,27 +4513,19 @@ def _unique1d_sorted_mask(ar, optional_indices=False):
|
||||
perm = ar.argsort()
|
||||
aux = ar[perm]
|
||||
else:
|
||||
perm = np.empty(0, dtype=int)
|
||||
aux = ar.sort()
|
||||
|
||||
mask = ones(aux.shape, dtype=bool_).at[1:].set(aux[1:] != aux[:-1])
|
||||
|
||||
if optional_indices:
|
||||
return aux, mask, perm
|
||||
else:
|
||||
return aux, mask
|
||||
|
||||
def _unique1d(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False):
|
||||
"""
|
||||
Find the unique elements of an array, ignoring shape.
|
||||
"""
|
||||
|
||||
optional_indices = return_index or return_inverse
|
||||
|
||||
if optional_indices:
|
||||
aux, mask, perm = _unique1d_sorted_mask(ar, optional_indices)
|
||||
else:
|
||||
aux, mask = _unique1d_sorted_mask(ar, optional_indices)
|
||||
aux, mask, perm = _unique1d_sorted_mask(ar, return_index or return_inverse)
|
||||
|
||||
ret = (aux[mask],)
|
||||
if return_index:
|
||||
@ -4541,11 +4533,61 @@ def _unique1d(ar, return_index=False, return_inverse=False,
|
||||
if return_inverse:
|
||||
imask = cumsum(mask) - 1
|
||||
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
|
||||
inv_idx = ops.index_update(inv_idx, perm, imask)
|
||||
inv_idx = inv_idx.at[perm].set(imask)
|
||||
ret += (inv_idx,)
|
||||
if return_counts:
|
||||
idx = concatenate(nonzero(mask) + (array([mask.size]),))
|
||||
ret += (diff(idx),)
|
||||
|
||||
return ret
|
||||
|
||||
@partial(jit, static_argnums=1)
|
||||
def _unique_axis_sorted_mask(ar, axis):
|
||||
aux = moveaxis(ar, axis, 0)
|
||||
size, *out_shape = aux.shape
|
||||
aux = aux.reshape(size, _prod(out_shape)).T
|
||||
if aux.shape[0] == 0:
|
||||
perm = zeros(1, dtype=int)
|
||||
else:
|
||||
perm = lexsort(aux[::-1])
|
||||
aux = aux[:, perm]
|
||||
if aux.size:
|
||||
mask = ones(size, dtype=bool).at[1:].set(any(aux[:, 1:] != aux[:, :-1], 0))
|
||||
else:
|
||||
mask = zeros(size, dtype=bool)
|
||||
return aux, mask, perm, out_shape
|
||||
|
||||
def _unique_axis(ar, axis, return_index=False, return_inverse=False,
|
||||
return_counts=False):
|
||||
"""
|
||||
Find the unique elements of an array along a particular axis.
|
||||
"""
|
||||
aux, mask, perm, out_shape = _unique_axis_sorted_mask(ar, axis)
|
||||
result = moveaxis(aux[:, mask].T.reshape(mask.sum() or aux.shape[1], *out_shape), 0, axis)
|
||||
|
||||
ret = (result,)
|
||||
if return_index:
|
||||
if aux.size:
|
||||
ret += (perm[mask],)
|
||||
else:
|
||||
ret += (perm,)
|
||||
if return_inverse:
|
||||
if aux.size:
|
||||
imask = cumsum(mask) - 1
|
||||
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
|
||||
inv_idx = inv_idx.at[perm].set(imask)
|
||||
else:
|
||||
inv_idx = zeros(ar.shape[axis], dtype=int)
|
||||
ret += (inv_idx,)
|
||||
if return_counts:
|
||||
if aux.size:
|
||||
idx = concatenate(nonzero(mask) + (array([mask.size]),))
|
||||
ret += (diff(idx),)
|
||||
elif ar.shape[axis]:
|
||||
ret += (array([ar.shape[axis]]),)
|
||||
else:
|
||||
ret += (empty(0, dtype=int),)
|
||||
|
||||
return ret
|
||||
|
||||
@_wraps(np.unique, skip_params=['axis'])
|
||||
@ -4553,16 +4595,12 @@ def unique(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False, axis: Optional[int] = None):
|
||||
ar = core.concrete_or_error(asarray, ar, "The error arose in jnp.unique()")
|
||||
|
||||
if axis is not None:
|
||||
raise NotImplementedError(
|
||||
"np.unique is not implemented for the axis argument")
|
||||
|
||||
if axis is None:
|
||||
ret = _unique1d(ar, return_index, return_inverse, return_counts)
|
||||
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
else:
|
||||
return ret
|
||||
ret = _unique_axis(ar, axis, return_index, return_inverse, return_counts)
|
||||
|
||||
return ret[0] if len(ret) == 1 else ret
|
||||
|
||||
### Indexing
|
||||
|
||||
|
@ -2119,22 +2119,25 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_ind={}_inv={}_count={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
{"testcase_name": "_{}_axis={}_ind={}_inv={}_count={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis,
|
||||
return_index, return_inverse, return_counts),
|
||||
"shape": shape, "dtype": dtype,
|
||||
"shape": shape, "dtype": dtype, "axis": axis,
|
||||
"return_index": return_index, "return_inverse": return_inverse,
|
||||
"return_counts": return_counts}
|
||||
for dtype in number_dtypes
|
||||
for shape in all_shapes
|
||||
for axis in [None] + list(range(len(shape)))
|
||||
for return_index in [False, True]
|
||||
for return_inverse in [False, True]
|
||||
for return_counts in [False, True]))
|
||||
def testUnique(self, shape, dtype, return_index, return_inverse, return_counts):
|
||||
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
|
||||
if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0:
|
||||
self.skipTest("zero-sized axis in unique leads to error in older numpy.")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts)
|
||||
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts)
|
||||
np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts, axis=axis)
|
||||
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts, axis=axis)
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user