jnp.unique: add support for axis argument

This commit is contained in:
Jake VanderPlas 2021-04-21 16:00:14 -07:00
parent 90d606fe25
commit bb543f2b5b
3 changed files with 70 additions and 28 deletions

View File

@ -15,7 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
keyword arguments. A new `static_argnames` option has been added to specify
keyword arguments as static.
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`6501`)
be used within `jit` ({jax-issue}`#6501`)
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
* Breaking changes:
* The following function names have changed. There are still aliases, so this
should not break existing code, but the aliases will eventually be removed

View File

@ -4513,27 +4513,19 @@ def _unique1d_sorted_mask(ar, optional_indices=False):
perm = ar.argsort()
aux = ar[perm]
else:
perm = np.empty(0, dtype=int)
aux = ar.sort()
mask = ones(aux.shape, dtype=bool_).at[1:].set(aux[1:] != aux[:-1])
if optional_indices:
return aux, mask, perm
else:
return aux, mask
def _unique1d(ar, return_index=False, return_inverse=False,
return_counts=False):
"""
Find the unique elements of an array, ignoring shape.
"""
optional_indices = return_index or return_inverse
if optional_indices:
aux, mask, perm = _unique1d_sorted_mask(ar, optional_indices)
else:
aux, mask = _unique1d_sorted_mask(ar, optional_indices)
aux, mask, perm = _unique1d_sorted_mask(ar, return_index or return_inverse)
ret = (aux[mask],)
if return_index:
@ -4541,11 +4533,61 @@ def _unique1d(ar, return_index=False, return_inverse=False,
if return_inverse:
imask = cumsum(mask) - 1
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
inv_idx = ops.index_update(inv_idx, perm, imask)
inv_idx = inv_idx.at[perm].set(imask)
ret += (inv_idx,)
if return_counts:
idx = concatenate(nonzero(mask) + (array([mask.size]),))
ret += (diff(idx),)
return ret
@partial(jit, static_argnums=1)
def _unique_axis_sorted_mask(ar, axis):
aux = moveaxis(ar, axis, 0)
size, *out_shape = aux.shape
aux = aux.reshape(size, _prod(out_shape)).T
if aux.shape[0] == 0:
perm = zeros(1, dtype=int)
else:
perm = lexsort(aux[::-1])
aux = aux[:, perm]
if aux.size:
mask = ones(size, dtype=bool).at[1:].set(any(aux[:, 1:] != aux[:, :-1], 0))
else:
mask = zeros(size, dtype=bool)
return aux, mask, perm, out_shape
def _unique_axis(ar, axis, return_index=False, return_inverse=False,
return_counts=False):
"""
Find the unique elements of an array along a particular axis.
"""
aux, mask, perm, out_shape = _unique_axis_sorted_mask(ar, axis)
result = moveaxis(aux[:, mask].T.reshape(mask.sum() or aux.shape[1], *out_shape), 0, axis)
ret = (result,)
if return_index:
if aux.size:
ret += (perm[mask],)
else:
ret += (perm,)
if return_inverse:
if aux.size:
imask = cumsum(mask) - 1
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
inv_idx = inv_idx.at[perm].set(imask)
else:
inv_idx = zeros(ar.shape[axis], dtype=int)
ret += (inv_idx,)
if return_counts:
if aux.size:
idx = concatenate(nonzero(mask) + (array([mask.size]),))
ret += (diff(idx),)
elif ar.shape[axis]:
ret += (array([ar.shape[axis]]),)
else:
ret += (empty(0, dtype=int),)
return ret
@_wraps(np.unique, skip_params=['axis'])
@ -4553,16 +4595,12 @@ def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis: Optional[int] = None):
ar = core.concrete_or_error(asarray, ar, "The error arose in jnp.unique()")
if axis is not None:
raise NotImplementedError(
"np.unique is not implemented for the axis argument")
if axis is None:
ret = _unique1d(ar, return_index, return_inverse, return_counts)
if len(ret) == 1:
return ret[0]
else:
return ret
ret = _unique_axis(ar, axis, return_index, return_inverse, return_counts)
return ret[0] if len(ret) == 1 else ret
### Indexing

View File

@ -2119,22 +2119,25 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_ind={}_inv={}_count={}".format(
jtu.format_shape_dtype_string(shape, dtype),
{"testcase_name": "_{}_axis={}_ind={}_inv={}_count={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis,
return_index, return_inverse, return_counts),
"shape": shape, "dtype": dtype,
"shape": shape, "dtype": dtype, "axis": axis,
"return_index": return_index, "return_inverse": return_inverse,
"return_counts": return_counts}
for dtype in number_dtypes
for shape in all_shapes
for axis in [None] + list(range(len(shape)))
for return_index in [False, True]
for return_inverse in [False, True]
for return_counts in [False, True]))
def testUnique(self, shape, dtype, return_index, return_inverse, return_counts):
def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts):
if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0:
self.skipTest("zero-sized axis in unique leads to error in older numpy.")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts)
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts)
np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts, axis=axis)
jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(