jnp.unique: improve efficiency & consolidate implementation

This commit is contained in:
Jake VanderPlas 2021-10-14 15:28:32 -07:00
parent aaf3bb789e
commit a3a6a5b137

View File

@ -5544,89 +5544,34 @@ def take_along_axis(arr, indices, axis: Optional[int]):
### SetOps
@partial(jit, static_argnums=1)
def _unique1d_sorted_mask(ar, optional_indices=False):
    """Sort the flattened input and flag where each run of equal values starts.

    Jit-able helper for ``unique``.

    Args:
      ar: input array; it is flattened before sorting.
      optional_indices: static flag (argument 1 is marked static for ``jit``).
        When true, the sorting permutation is computed together with the
        sorted values.

    Returns:
      A tuple ``(aux, mask, perm)``:
        * ``aux``: the flattened input in sorted order.
        * ``mask``: boolean array with the shape of ``aux``. It starts as all
          True and every entry from index 1 on is replaced by
          ``aux[i] != aux[i - 1]``, so True marks the first element of each
          run of equal values.
        * ``perm``: the indices that were sorted alongside the values when
          ``optional_indices`` is true, otherwise an empty integer array.
    """
    ar = ar.flatten()
    if optional_indices:
        # Sort the values and carry their original positions along so that
        # the permutation comes out of the same sort.
        aux, perm = lax.sort_key_val(ar, lax.iota(int, len(ar)))
    else:
        # Placeholder so the return signature is the same in both branches.
        perm = np.empty(0, dtype=int)
        aux = ar.sort()
    mask = ones(aux.shape, dtype=bool_).at[1:].set(aux[1:] != aux[:-1])
    return aux, mask, perm
def _unique1d(ar, return_index=False, return_inverse=False,
              return_counts=False, size=None, fill_value=None):
    """Find the unique elements of an array, ignoring shape.

    Args:
      ar: input array; it is flattened by ``_unique1d_sorted_mask``.
      return_index: if true, also return, for each unique value, the entry of
        the sort permutation at the start of that value's run.
      return_inverse: if true, also return an integer array shaped like the
        flattened input that maps every input element to the position of its
        value in the unique output.
      return_counts: if true, also return the length of each run of equal
        values in the sorted input.
      size: optional number of unique values to return. When given, index
        arrays are obtained via ``nonzero(..., size=...)`` instead of boolean
        masking.
      fill_value: when both ``size`` and ``fill_value`` are given, output
        slots at or beyond the true number of unique values are set to this
        value.

    Returns:
      A tuple whose first entry is the array of unique values, followed by the
      requested extras in the order index, inverse, counts.

    Raises:
      ValueError: if the input has zero elements and ``size`` is positive.
    """
    if np.size(ar) == 0 and size is not None and size > 0:
        raise ValueError("jnp.unique(): Cannot pass nonzero size for zero-sized array.")

    # The permutation is only needed for the index / inverse outputs.
    aux, mask, perm = _unique1d_sorted_mask(ar, return_index or return_inverse)

    # Without ``size`` select by boolean mask; with ``size`` select by the
    # result of ``nonzero`` (used as-is as an index for the 1-D arrays below).
    ind = mask if size is None else nonzero(mask, size=size)
    result = aux[ind]
    if size is not None and fill_value is not None:
        # Slots past the number of distinct values get the fill value.
        result = where(arange(size) >= mask.sum(), fill_value, result)

    ret = (result,)
    if return_index:
        ret += (perm[ind],)
    if return_inverse:
        # cumsum(mask) - 1 numbers the runs in sorted order; scattering it
        # through ``perm`` puts each run number back at the input position.
        imask = cumsum(mask) - 1
        inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
        inv_idx = inv_idx.at[perm].set(imask)
        ret += (inv_idx,)
    if return_counts:
        # Counts are the differences between consecutive run-start positions,
        # with the total length acting as the end of the last run.
        if size is None:
            idx = append(nonzero(mask)[0], mask.size)
        else:
            idx = nonzero(mask, size=size + 1)[0]
            # Zero entries after the first slot (presumably the padding that
            # ``nonzero`` adds to reach ``size + 1`` -- confirm) are replaced
            # by the total length.
            idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
        ret += (diff(idx),)
    return ret
@partial(jit, static_argnums=1)
def _unique_axis_sorted_mask(ar, axis):
def _unique_sorted_mask(ar, axis):
aux = moveaxis(ar, axis, 0)
size, *out_shape = aux.shape
aux = aux.reshape(size, _prod(out_shape)).T
if aux.shape[0] == 0:
if _prod(out_shape) == 0:
size = 1
perm = zeros(1, dtype=int)
else:
perm = lexsort(aux[::-1])
aux = aux[:, perm]
perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
aux = aux[perm]
if aux.size:
mask = ones(size, dtype=bool).at[1:].set(any(aux[:, 1:] != aux[:, :-1], 0))
mask = ones(size, dtype=bool).at[1:].set(any(aux[1:] != aux[:-1], tuple(range(1, aux.ndim))))
else:
mask = zeros(size, dtype=bool)
return aux, mask, perm
def _unique_axis(ar, axis, return_index=False, return_inverse=False,
return_counts=False, size=None, fill_value=None):
def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False,
size=None, fill_value=None, return_true_size=False):
"""
Find the unique elements of an array along a particular axis.
"""
aux, mask, perm = _unique_axis_sorted_mask(ar, axis)
out_shape = ar.shape[:axis] + ar.shape[axis + 1:]
aux, mask, perm = _unique_sorted_mask(ar, axis)
ind = mask if size is None else nonzero(mask, size=size)[0]
result = aux[:, ind]
result = aux[ind] if aux.size else aux
if size is not None and fill_value is not None:
if _ndim(fill_value):
fill_value = broadcast_to(fill_value, out_shape).reshape(_prod(out_shape), 1)
result = where(arange(size) >= mask.sum(), fill_value, result)
leading_dim = size if size is not None else mask.sum() or aux.shape[1]
result = moveaxis(result.T.reshape(leading_dim, *out_shape), 0, axis)
valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim)))
result = where(valid, result, fill_value)
result = moveaxis(result, 0, axis)
ret = (result,)
if return_index:
@ -5654,8 +5599,11 @@ def _unique_axis(ar, axis, return_index=False, return_inverse=False,
ret += (array([ar.shape[axis]]),)
else:
ret += (empty(0, dtype=int),)
if return_true_size:
# Useful for internal uses of unique().
ret += (mask.sum(),)
return ret[0] if len(ret) == 1 else ret
return ret
_UNIQUE_DOC = """\
Because the size of the output of ``unique`` is data-dependent, the function is not
@ -5671,20 +5619,16 @@ along the specified axis of the input."""
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
_check_arraylike("unique", ar)
if size is None:
ar = core.concrete_or_error(None, ar, "The error arose for the first argument of jnp.unique()")
else:
size = core.concrete_or_error(operator.index, size, "The error arose for the size argument of jnp.unique()")
ar = asarray(ar)
if axis is None:
ret = _unique1d(ar, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
else:
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
ret = _unique_axis(ar, axis, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
return ret[0] if len(ret) == 1 else ret
axis = 0
ar = ar.flatten()
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
return _unique(ar, axis, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
### Indexing