mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jnp.unique: add fill_value for when size is not None
This commit is contained in:
parent
3c117fd6ed
commit
0b93c46c71
@ -38,6 +38,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* An optimized C++ code-path improving the dispatch time for `pmap` is now the
|
||||
default when using jaxlib 0.1.72 or newer. The feature can be disabled using
|
||||
the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
|
||||
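* `jax.numpy.unique` now supports an optional `fill_value` argument, which sets the value used to pad the output when `size` is specified and the input has fewer than `size` unique elements ({jax-issue}`#8121`)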
|
||||
|
||||
## jax 0.2.21 (Sept 23, 2021)
|
||||
* [GitHub
|
||||
|
@ -5535,7 +5535,7 @@ def _unique1d_sorted_mask(ar, optional_indices=False):
|
||||
return aux, mask, perm
|
||||
|
||||
def _unique1d(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False, size=None):
|
||||
return_counts=False, size=None, fill_value=None):
|
||||
"""
|
||||
Find the unique elements of an array, ignoring shape.
|
||||
"""
|
||||
@ -5545,9 +5545,16 @@ def _unique1d(ar, return_index=False, return_inverse=False,
|
||||
aux, mask, perm = _unique1d_sorted_mask(ar, return_index or return_inverse)
|
||||
ind = mask if size is None else nonzero(mask, size=size)
|
||||
|
||||
ret = (aux[ind],)
|
||||
result = aux[ind]
|
||||
if size is not None and fill_value is not None:
|
||||
result = where(arange(size) >= mask.sum(), fill_value, result)
|
||||
|
||||
ret = (result,)
|
||||
if return_index:
|
||||
ret += (perm[ind],)
|
||||
perm_ind = perm[ind]
|
||||
if size is not None and fill_value is not None:
|
||||
perm_ind = where(arange(size) >= mask.sum(), fill_value, perm_ind)
|
||||
ret += (perm_ind,)
|
||||
if return_inverse:
|
||||
imask = cumsum(mask) - 1
|
||||
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
|
||||
@ -5616,17 +5623,17 @@ def _unique_axis(ar, axis, return_index=False, return_inverse=False,
|
||||
_UNIQUE_DOC = """\
|
||||
Because the size of the output of ``unique`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional `size` argument which
|
||||
specifies the size of the data-dependent output arrays: it must be specified statically for
|
||||
``jnp.unique`` to be traced. If specified, the first `size` unique elements will be returned;
|
||||
if there are fewer unique elements than `size` indicates, the return value will be padded with
|
||||
the minimum value in the input array.
|
||||
specifies the size of the data-dependent output arrays: it must be specified statically
|
||||
for ``jnp.unique`` to be traced. If specified, the first `size` unique elements will be
|
||||
returned; if there are fewer unique elements than `size` indicates, the return value will
|
||||
be padded with `fill_value`, which defaults to the minimum value in the input array.
|
||||
|
||||
The `size` cannot currently be used with the `axis` argument."""
|
||||
|
||||
|
||||
@_wraps(np.unique, skip_params=['axis'], lax_description=_UNIQUE_DOC)
|
||||
def unique(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False, axis: Optional[int] = None, *, size=None):
|
||||
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
|
||||
_check_arraylike("unique", ar)
|
||||
|
||||
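# TODO(jakevdp): call _check_arraylike on input.
# NOTE(review): _check_arraylike("unique", ar) is already called above, so this TODO looks stale — confirm and remove.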
|
||||
@ -5641,7 +5648,7 @@ def unique(ar, return_index=False, return_inverse=False,
|
||||
|
||||
ar = asarray(ar)
|
||||
if axis is None:
|
||||
ret = _unique1d(ar, return_index, return_inverse, return_counts, size=size)
|
||||
ret = _unique1d(ar, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
|
||||
else:
|
||||
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
|
||||
ret = _unique_axis(ar, axis, return_index, return_inverse, return_counts)
|
||||
|
@ -2312,13 +2312,14 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_size={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), size),
|
||||
"shape": shape, "dtype": dtype, "size": size}
|
||||
{"testcase_name": "_{}_size={}_fill_value={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), size, fill_value),
|
||||
"shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value}
|
||||
for dtype in number_dtypes
|
||||
for size in [1, 5, 10]
|
||||
for fill_value in [None, -1]
|
||||
for shape in nonempty_array_shapes))
|
||||
def testUniqueSize(self, shape, dtype, size):
|
||||
def testUniqueSize(self, shape, dtype, size, fill_value):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
kwds = dict(return_index=True, return_inverse=True, return_counts=True)
|
||||
@ -2329,12 +2330,12 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
u, ind, counts = u[:size], ind[:size], counts[:size]
|
||||
else:
|
||||
extra = size - len(u)
|
||||
u = np.concatenate([u, np.full(extra, u[0], u.dtype)])
|
||||
ind = np.concatenate([ind, np.full(extra, ind[0], ind.dtype)])
|
||||
counts = np.concatenate([counts, np.zeros(extra, counts.dtype)])
|
||||
u = np.pad(u, (0, extra), constant_values=u[0] if fill_value is None else fill_value)
|
||||
ind = np.pad(ind, (0, extra), constant_values=ind[0] if fill_value is None else fill_value)
|
||||
counts = np.pad(counts, (0, extra), constant_values=0)
|
||||
return u, ind, inv, counts
|
||||
|
||||
jnp_fun = lambda x: jnp.unique(x, size=size, **kwds)
|
||||
jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds)
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
Loading…
x
Reference in New Issue
Block a user