jnp.unique: add fill_value for when size is not None

This commit is contained in:
Jake VanderPlas 2021-10-06 16:28:36 -07:00
parent 3c117fd6ed
commit 0b93c46c71
3 changed files with 26 additions and 17 deletions

View File

@ -38,6 +38,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* An optimized C++ code-path improving the dispatch time for `pmap` is now the
default when using jaxlib 0.1.72 or newer. The feature can be disabled using
the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
* `jax.numpy.unique` now supports an optional `fill_value` argument, used to pad the output when `size` is specified ({jax-issue}`#8121`)
## jax 0.2.21 (Sept 23, 2021)
* [GitHub

View File

@ -5535,7 +5535,7 @@ def _unique1d_sorted_mask(ar, optional_indices=False):
return aux, mask, perm
def _unique1d(ar, return_index=False, return_inverse=False,
return_counts=False, size=None):
return_counts=False, size=None, fill_value=None):
"""
Find the unique elements of an array, ignoring shape.
"""
@ -5545,9 +5545,16 @@ def _unique1d(ar, return_index=False, return_inverse=False,
aux, mask, perm = _unique1d_sorted_mask(ar, return_index or return_inverse)
ind = mask if size is None else nonzero(mask, size=size)
ret = (aux[ind],)
result = aux[ind]
if size is not None and fill_value is not None:
result = where(arange(size) >= mask.sum(), fill_value, result)
ret = (result,)
if return_index:
ret += (perm[ind],)
perm_ind = perm[ind]
if size is not None and fill_value is not None:
perm_ind = where(arange(size) >= mask.sum(), fill_value, perm_ind)
ret += (perm_ind,)
if return_inverse:
imask = cumsum(mask) - 1
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
@ -5616,17 +5623,17 @@ def _unique_axis(ar, axis, return_index=False, return_inverse=False,
_UNIQUE_DOC = """\
Because the size of the output of ``unique`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the data-dependent output arrays: it must be specified statically for
``jnp.unique`` to be traced. If specified, the first `size` unique elements will be returned;
if there are fewer unique elements than `size` indicates, the return value will be padded with
the minimum value in the input array.
specifies the size of the data-dependent output arrays: it must be specified statically
for ``jnp.unique`` to be traced. If specified, the first `size` unique elements will be
returned; if there are fewer unique elements than `size` indicates, the return value will
be padded with `fill_value`, which defaults to the minimum value in the input array.
The `size` cannot currently be used with the `axis` argument."""
@_wraps(np.unique, skip_params=['axis'], lax_description=_UNIQUE_DOC)
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis: Optional[int] = None, *, size=None):
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
_check_arraylike("unique", ar)
# TODO(jakevdp): call _check_arraylike on input.
@ -5641,7 +5648,7 @@ def unique(ar, return_index=False, return_inverse=False,
ar = asarray(ar)
if axis is None:
ret = _unique1d(ar, return_index, return_inverse, return_counts, size=size)
ret = _unique1d(ar, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
else:
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
ret = _unique_axis(ar, axis, return_index, return_inverse, return_counts)

View File

@ -2312,13 +2312,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_size={}".format(
jtu.format_shape_dtype_string(shape, dtype), size),
"shape": shape, "dtype": dtype, "size": size}
{"testcase_name": "_{}_size={}_fill_value={}".format(
jtu.format_shape_dtype_string(shape, dtype), size, fill_value),
"shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value}
for dtype in number_dtypes
for size in [1, 5, 10]
for fill_value in [None, -1]
for shape in nonempty_array_shapes))
def testUniqueSize(self, shape, dtype, size):
def testUniqueSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
kwds = dict(return_index=True, return_inverse=True, return_counts=True)
@ -2329,12 +2330,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
u, ind, counts = u[:size], ind[:size], counts[:size]
else:
extra = size - len(u)
u = np.concatenate([u, np.full(extra, u[0], u.dtype)])
ind = np.concatenate([ind, np.full(extra, ind[0], ind.dtype)])
counts = np.concatenate([counts, np.zeros(extra, counts.dtype)])
u = np.pad(u, (0, extra), constant_values=u[0] if fill_value is None else fill_value)
ind = np.pad(ind, (0, extra), constant_values=ind[0] if fill_value is None else fill_value)
counts = np.pad(counts, (0, extra), constant_values=0)
return u, ind, inv, counts
jnp_fun = lambda x: jnp.unique(x, size=size, **kwds)
jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)