From 0b93c46c71723ec175b47ee08f2071c815d1e041 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Oct 2021 16:28:36 -0700 Subject: [PATCH] jnp.unique: add fill_value for when size is not None --- CHANGELOG.md | 1 + jax/_src/numpy/lax_numpy.py | 25 ++++++++++++++++--------- tests/lax_numpy_test.py | 17 +++++++++-------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2971b89e9..66d06f023 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * An optimized C++ code-path improving the dispatch time for `pmap` is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable). + * `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`#8121`) ## jax 0.2.21 (Sept 23, 2021) * [GitHub diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ab75ca14d..9f3911e74 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5535,7 +5535,7 @@ def _unique1d_sorted_mask(ar, optional_indices=False): return aux, mask, perm def _unique1d(ar, return_index=False, return_inverse=False, - return_counts=False, size=None): + return_counts=False, size=None, fill_value=None): """ Find the unique elements of an array, ignoring shape. """ @@ -5545,9 +5545,16 @@ def _unique1d(ar, return_index=False, return_inverse=False, aux, mask, perm = _unique1d_sorted_mask(ar, return_index or return_inverse) ind = mask if size is None else nonzero(mask, size=size) - ret = (aux[ind],) + result = aux[ind] + if size is not None and fill_value is not None: + result = where(arange(size) >= mask.sum(), fill_value, result) + + ret = (result,) if return_index: - ret += (perm[ind],) + perm_ind = perm[ind] + if size is not None and fill_value is not None: + perm_ind = where(arange(size) >= mask.sum(), fill_value, perm_ind) + ret += (perm_ind,) if return_inverse: imask = cumsum(mask) - 1 inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_)) @@ -5616,17 +5623,17 @@ def _unique_axis(ar, axis, return_index=False, return_inverse=False, _UNIQUE_DOC = """\ Because the size of the output of ``unique`` is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional `size` argument which -specifies the size of the data-dependent output arrays: it must be specified statically for -``jnp.unique`` to be traced. If specified, the first `size` unique elements will be returned; -if there are fewer unique elements than `size` indicates, the return value will be padded with -the minimum value in the input array. +specifies the size of the data-dependent output arrays: it must be specified statically +for ``jnp.unique`` to be traced. If specified, the first `size` unique elements will be +returned; if there are fewer unique elements than `size` indicates, the return value will +be padded with `fill_value`, which defaults to the minimum value in the input array. The `size` cannot currently be used with the `axis` argument.""" @_wraps(np.unique, skip_params=['axis'], lax_description=_UNIQUE_DOC) def unique(ar, return_index=False, return_inverse=False, - return_counts=False, axis: Optional[int] = None, *, size=None): + return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None): _check_arraylike("unique", ar) # TODO(jakevdp): call _check_arraylike on input. @@ -5641,7 +5648,7 @@ def unique(ar, return_index=False, return_inverse=False, ar = asarray(ar) if axis is None: - ret = _unique1d(ar, return_index, return_inverse, return_counts, size=size) + ret = _unique1d(ar, return_index, return_inverse, return_counts, size=size, fill_value=fill_value) else: axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()") ret = _unique_axis(ar, axis, return_index, return_inverse, return_counts) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index eb5105028..1fee83b8d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2312,13 +2312,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_size={}".format( - jtu.format_shape_dtype_string(shape, dtype), size), - "shape": shape, "dtype": dtype, "size": size} + {"testcase_name": "_{}_size={}_fill_value={}".format( + jtu.format_shape_dtype_string(shape, dtype), size, fill_value), + "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value} for dtype in number_dtypes for size in [1, 5, 10] + for fill_value in [None, -1] for shape in nonempty_array_shapes)) - def testUniqueSize(self, shape, dtype, size): + def testUniqueSize(self, shape, dtype, size, fill_value): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] kwds = dict(return_index=True, return_inverse=True, return_counts=True) @@ -2329,12 +2330,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): u, ind, counts = u[:size], ind[:size], counts[:size] else: extra = size - len(u) - u = np.concatenate([u, np.full(extra, u[0], u.dtype)]) - ind = np.concatenate([ind, np.full(extra, ind[0], ind.dtype)]) - counts = np.concatenate([counts, np.zeros(extra, counts.dtype)]) + u = np.pad(u, (0, extra), constant_values=u[0] if fill_value is None else fill_value) + ind = np.pad(ind, (0, extra), constant_values=ind[0] if fill_value is None else fill_value) + counts = np.pad(counts, (0, extra), constant_values=0) return u, ind, inv, counts - jnp_fun = lambda x: jnp.unique(x, size=size, **kwds) + jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker)