From bb543f2b5b69c39d2c2ccd442b0305f241893b9c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 21 Apr 2021 16:00:14 -0700 Subject: [PATCH] jnp.unique: add support for axis argument --- CHANGELOG.md | 3 +- jax/_src/numpy/lax_numpy.py | 80 +++++++++++++++++++++++++++---------- tests/lax_numpy_test.py | 15 ++++--- 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f2ad693a..8f4534e30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. keyword arguments. A new `static_argnames` option has been added to specify keyword arguments as static. * {func}`jax.nonzero` has a new optional `size` argument that allows it to - be used within `jit` ({jax-issue}`6501`) + be used within `jit` ({jax-issue}`#6501`) + * {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`). * Breaking changes: * The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 2b3751177..567ec6000 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4513,27 +4513,19 @@ def _unique1d_sorted_mask(ar, optional_indices=False): perm = ar.argsort() aux = ar[perm] else: + perm = np.empty(0, dtype=int) aux = ar.sort() mask = ones(aux.shape, dtype=bool_).at[1:].set(aux[1:] != aux[:-1]) - if optional_indices: - return aux, mask, perm - else: - return aux, mask + return aux, mask, perm def _unique1d(ar, return_index=False, return_inverse=False, return_counts=False): """ Find the unique elements of an array, ignoring shape. """ - - optional_indices = return_index or return_inverse - - if optional_indices: - aux, mask, perm = _unique1d_sorted_mask(ar, optional_indices) - else: - aux, mask = _unique1d_sorted_mask(ar, optional_indices) + aux, mask, perm = _unique1d_sorted_mask(ar, return_index or return_inverse) ret = (aux[mask],) if return_index: @@ -4541,11 +4533,61 @@ def _unique1d(ar, return_index=False, return_inverse=False, if return_inverse: imask = cumsum(mask) - 1 inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_)) - inv_idx = ops.index_update(inv_idx, perm, imask) + inv_idx = inv_idx.at[perm].set(imask) ret += (inv_idx,) if return_counts: idx = concatenate(nonzero(mask) + (array([mask.size]),)) ret += (diff(idx),) + + return ret + +@partial(jit, static_argnums=1) +def _unique_axis_sorted_mask(ar, axis): + aux = moveaxis(ar, axis, 0) + size, *out_shape = aux.shape + aux = aux.reshape(size, _prod(out_shape)).T + if aux.shape[0] == 0: + perm = zeros(1, dtype=int) + else: + perm = lexsort(aux[::-1]) + aux = aux[:, perm] + if aux.size: + mask = ones(size, dtype=bool).at[1:].set(any(aux[:, 1:] != aux[:, :-1], 0)) + else: + mask = zeros(size, dtype=bool) + return aux, mask, perm, out_shape + +def _unique_axis(ar, axis, return_index=False, return_inverse=False, + return_counts=False): + """ + Find the unique elements of an array along a particular axis. + """ + aux, mask, perm, out_shape = _unique_axis_sorted_mask(ar, axis) + result = moveaxis(aux[:, mask].T.reshape(mask.sum() or aux.shape[1], *out_shape), 0, axis) + + ret = (result,) + if return_index: + if aux.size: + ret += (perm[mask],) + else: + ret += (perm,) + if return_inverse: + if aux.size: + imask = cumsum(mask) - 1 + inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_)) + inv_idx = inv_idx.at[perm].set(imask) + else: + inv_idx = zeros(ar.shape[axis], dtype=int) + ret += (inv_idx,) + if return_counts: + if aux.size: + idx = concatenate(nonzero(mask) + (array([mask.size]),)) + ret += (diff(idx),) + elif ar.shape[axis]: + ret += (array([ar.shape[axis]]),) + else: + ret += (empty(0, dtype=int),) + return ret @_wraps(np.unique, skip_params=['axis']) @@ -4553,16 +4595,12 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis: Optional[int] = None): ar = core.concrete_or_error(asarray, ar, "The error arose in jnp.unique()") - if axis is not None: - raise NotImplementedError( - "np.unique is not implemented for the axis argument") - - ret = _unique1d(ar, return_index, return_inverse, return_counts) - - if len(ret) == 1: - return ret[0] + if axis is None: + ret = _unique1d(ar, return_index, return_inverse, return_counts) else: - return ret + ret = _unique_axis(ar, axis, return_index, return_inverse, return_counts) + + return ret[0] if len(ret) == 1 else ret ### Indexing diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 1565fb41e..e619a0e2c 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2119,22 +2119,25 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_ind={}_inv={}_count={}".format( - jtu.format_shape_dtype_string(shape, dtype), + {"testcase_name": "_{}_axis={}_ind={}_inv={}_count={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, return_index, return_inverse, return_counts), - "shape": shape, "dtype": dtype, + "shape": shape, "dtype": dtype, "axis": axis, "return_index": return_index, "return_inverse": return_inverse, "return_counts": return_counts} for dtype in number_dtypes for shape in all_shapes + for axis in [None] + list(range(len(shape))) for return_index in [False, True] for return_inverse in [False, True] for return_counts in [False, True])) - def testUnique(self, shape, dtype, return_index, return_inverse, return_counts): + def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): + if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0: + self.skipTest("zero-sized axis in unique leads to error in older numpy.") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts) - jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts) + np_fun = lambda x: np.unique(x, return_index, return_inverse, return_counts, axis=axis) + jnp_fun = lambda x: jnp.unique(x, return_index, return_inverse, return_counts, axis=axis) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list(