rocm_jax/jax/_src/numpy/setops.py
2024-06-04 17:39:13 +04:00

971 lines
39 KiB
Python

# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial
import math
import operator
from typing import cast, NamedTuple
import numpy as np
from jax import jit
from jax import lax
from jax._src import core
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import (
append, arange, array, asarray, concatenate, diff,
empty, full_like, lexsort, moveaxis, nonzero, ones, ravel,
sort, where, zeros)
from jax._src.numpy.reductions import any, cumsum
from jax._src.numpy.ufuncs import isnan
from jax._src.numpy.util import check_arraylike
from jax._src.util import canonicalize_axis
from jax._src.typing import Array, ArrayLike
# Module-local alias for the private ``_const`` helper of ``jax._src.lax.lax``;
# it is called in this file as ``_lax_const(example_array, value)``.
_lax_const = lax_internal._const
@partial(jit, static_argnames=('invert',))
def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
check_arraylike("in1d", ar1, ar2)
ar1_flat = ravel(ar1)
ar2_flat = ravel(ar2)
# Note: an algorithm based on searchsorted has better scaling, but in practice
# is very slow on accelerators because it relies on lax control flow. If XLA
# ever supports binary search natively, we should switch to this:
# ar2_flat = jnp.sort(ar2_flat)
# ind = jnp.searchsorted(ar2_flat, ar1_flat)
# if invert:
# return ar1_flat != ar2_flat[ind]
# else:
# return ar1_flat == ar2_flat[ind]
if invert:
return (ar1_flat[:, None] != ar2_flat[None, :]).all(-1)
else:
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)
def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
              *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
  """Compute the set difference of two 1D arrays.

  JAX implementation of :func:`numpy.setdiff1d`.

  Because the size of the output of ``setdiff1d`` is data-dependent, the function
  semantics are not typically compatible with :func:`~jax.jit` and other JAX
  transformations. The JAX version adds the optional ``size`` argument which
  must be specified statically for ``jnp.setdiff1d`` to be used in such contexts.

  Args:
    ar1: first array of elements to be differenced. Multi-dimensional input is
      flattened.
    ar2: second array of elements to be differenced. Multi-dimensional input is
      flattened.
    assume_unique: if True, assume the input arrays contain unique values. This allows
      a more efficient implementation, but if ``assume_unique`` is True and the input
      arrays contain duplicates, the behavior is undefined. default: False.
    size: if specified, return only the first ``size`` sorted elements. If there are fewer
      elements than ``size`` indicates, the return value will be padded with ``fill_value``.
    fill_value: when ``size`` is specified and there are fewer than the indicated number of
      elements, fill the remaining entries ``fill_value``. Defaults to 0.

  Returns:
    an array containing the set difference of elements in the input array: i.e. the elements
    in ``ar1`` that are not contained in ``ar2``.

  See also:
    - :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays.
    - :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays.
    - :func:`jax.numpy.union1d`: the set union of two 1D arrays.

  Examples:
    Computing the set difference of two arrays:

    >>> ar1 = jnp.array([1, 2, 3, 4])
    >>> ar2 = jnp.array([3, 4, 5, 6])
    >>> jnp.setdiff1d(ar1, ar2)
    Array([1, 2], dtype=int32)

    Because the output shape is dynamic, this will fail under :func:`~jax.jit` and other
    transformations:

    >>> jax.jit(jnp.setdiff1d)(ar1, ar2)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
       ...
    ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
    The error occurred while tracing the function setdiff1d for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

    In order to ensure statically-known output shapes, you can pass a static ``size``
    argument:

    >>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size'])
    >>> jit_setdiff1d(ar1, ar2, size=2)
    Array([1, 2], dtype=int32)

    If ``size`` is too small, the difference is truncated:

    >>> jit_setdiff1d(ar1, ar2, size=1)
    Array([1], dtype=int32)

    If ``size`` is too large, then the output is padded with ``fill_value``:

    >>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0)
    Array([1, 2, 0, 0], dtype=int32)
  """
  check_arraylike("setdiff1d", ar1, ar2)
  if size is None:
    # Without a static size the output length depends on the values of ar1.
    ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
  else:
    size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
  arr1 = asarray(ar1)
  fill_value = asarray(0 if fill_value is None else fill_value, dtype=arr1.dtype)
  if arr1.size == 0 or size == 0:
    # Nothing to select: either the input is empty or zero outputs were requested.
    # Returning here also keeps the padding logic below from indexing ``arr1[0]``
    # on an empty array.
    return full_like(arr1, fill_value, shape=size or 0)
  if assume_unique:
    # Like numpy.setdiff1d, operate on the flattened input so that the 1D mask
    # produced by _in1d lines up with arr1 for multi-dimensional arguments.
    arr1 = ravel(arr1)
  else:
    # With a static size, request arr1.size entries so no unique value is dropped;
    # unique() pads the tail with the smallest unique value (arr1[0]).
    arr1 = cast(Array, unique(arr1, size=size and arr1.size))
  mask = _in1d(arr1, ar2, invert=True)
  if size is None:
    return arr1[mask]
  if not assume_unique:
    # Set mask to zero at locations corresponding to unique() padding.
    n_unique = arr1.size + 1 - (arr1 == arr1[0]).sum()
    mask = where(arange(arr1.size) < n_unique, mask, False)
  return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
def union1d(ar1: ArrayLike, ar2: ArrayLike,
            *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
  """Compute the set union of two 1D arrays.

  JAX implementation of :func:`numpy.union1d`.

  Because the size of the output of ``union1d`` is data-dependent, the function
  semantics are not typically compatible with :func:`~jax.jit` and other JAX
  transformations. The JAX version adds the optional ``size`` argument which
  must be specified statically for ``jnp.union1d`` to be used in such contexts.

  Args:
    ar1: first array of elements to be unioned.
    ar2: second array of elements to be unioned.
    size: if specified, return only the first ``size`` sorted elements. If there are fewer
      elements than ``size`` indicates, the return value will be padded with ``fill_value``.
    fill_value: when ``size`` is specified and there are fewer than the indicated number of
      elements, fill the remaining entries ``fill_value``. Defaults to the minimum value.

  Returns:
    an array containing the union of elements in the input array.

  See also:
    - :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays.
    - :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays.
    - :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays.

  Examples:
    Computing the union of two arrays:

    >>> ar1 = jnp.array([1, 2, 3, 4])
    >>> ar2 = jnp.array([3, 4, 5, 6])
    >>> jnp.union1d(ar1, ar2)
    Array([1, 2, 3, 4, 5, 6], dtype=int32)

    Because the output shape is dynamic, this will fail under :func:`~jax.jit` and other
    transformations:

    >>> jax.jit(jnp.union1d)(ar1, ar2)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
       ...
    ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
    The error occurred while tracing the function union1d for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

    In order to ensure statically-known output shapes, you can pass a static ``size``
    argument:

    >>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size'])
    >>> jit_union1d(ar1, ar2, size=6)
    Array([1, 2, 3, 4, 5, 6], dtype=int32)

    If ``size`` is too small, the union is truncated:

    >>> jit_union1d(ar1, ar2, size=4)
    Array([1, 2, 3, 4], dtype=int32)

    If ``size`` is too large, then the output is padded with ``fill_value``:

    >>> jit_union1d(ar1, ar2, size=8, fill_value=0)
    Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)
  """
  check_arraylike("union1d", ar1, ar2)
  if size is not None:
    size = core.concrete_or_error(operator.index, size, "The error arose in union1d()")
  else:
    # Without a static size, both inputs must be concrete values.
    ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
    ar2 = core.concrete_or_error(None, ar2, "The error arose in union1d()")
  # axis=None flattens both inputs before joining them.
  combined = concatenate((ar1, ar2), axis=None)
  return cast(Array, unique(combined, size=size, fill_value=fill_value))
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array:
  """Compute the set-wise xor of elements in two arrays.

  JAX implementation of :func:`numpy.setxor1d`.

  Because the size of the output of ``setxor1d`` is data-dependent, the function is not
  compatible with JIT or other JAX transformations.

  Args:
    ar1: first array of values to compare.
    ar2: second array of values to compare.
    assume_unique: if True, assume the input arrays contain unique values. This allows
      a more efficient implementation, but if ``assume_unique`` is True and the input
      arrays contain duplicates, the behavior is undefined. default: False.

  Returns:
    An array of values that are found in exactly one of the input arrays.

  See also:
    - :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays.
    - :func:`jax.numpy.union1d`: the set union of two 1D arrays.
    - :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays.

  Examples:
    >>> ar1 = jnp.array([1, 2, 3, 4])
    >>> ar2 = jnp.array([3, 4, 5, 6])
    >>> jnp.setxor1d(ar1, ar2)
    Array([1, 2, 5, 6], dtype=int32)
  """
  check_arraylike("setxor1d", ar1, ar2)
  ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
  ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")

  ar1 = ravel(ar1)
  ar2 = ravel(ar2)
  if not assume_unique:
    ar1 = unique(ar1)
    ar2 = unique(ar2)

  merged = concatenate((ar1, ar2))
  if merged.size == 0:
    return merged

  merged = sort(merged)
  # boundaries[i] is True where sorted element i differs from element i - 1;
  # both ends are treated as boundaries.
  edge = array([True])
  boundaries = concatenate((edge, merged[1:] != merged[:-1], edge))
  # A value with a boundary on both sides occurs exactly once in the merged array.
  occurs_once = boundaries[1:] & boundaries[:-1]
  return merged[occurs_once]
@partial(jit, static_argnames=['return_indices'])
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> tuple[Array, ...]:
# JIT-compatible helper function for intersect1d
ar = concatenate((ar1, ar2))
if return_indices:
iota = lax.broadcasted_iota(np.int64, np.shape(ar), dimension=0)
aux, indices = lax.sort_key_val(ar, iota)
else:
aux = sort(ar)
mask = aux[1:] == aux[:-1]
if return_indices:
return aux, mask, indices
else:
return aux, mask
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
                return_indices: bool = False) -> Array | tuple[Array, Array, Array]:
  """Compute the set intersection of two 1D arrays.

  JAX implementation of :func:`numpy.intersect1d`.

  Because the size of the output of ``intersect1d`` is data-dependent, the function is not
  compatible with JIT or other JAX transformations.

  Args:
    ar1: first array of values to intersect.
    ar2: second array of values to intersect.
    assume_unique: if True, assume the input arrays contain unique values. This allows
      a more efficient implementation, but if ``assume_unique`` is True and the input
      arrays contain duplicates, the behavior is undefined. default: False.
    return_indices: If True, return arrays of indices specifying where the intersected
      values first appear in the input arrays.

  Returns:
    An array ``intersection``, or if ``return_indices=True``, a tuple of arrays
    ``(intersection, ar1_indices, ar2_indices)``. Returned values are

    - ``intersection``:
      A 1D array containing each value that appears in both ``ar1`` and ``ar2``.
    - ``ar1_indices``:
      *(returned if return_indices=True)* an array of shape ``intersection.shape`` containing
      the indices in flattened ``ar1`` of values in ``intersection``. For 1D inputs,
      ``intersection`` is equivalent to ``ar1[ar1_indices]``.
    - ``ar2_indices``:
      *(returned if return_indices=True)* an array of shape ``intersection.shape`` containing
      the indices in flattened ``ar2`` of values in ``intersection``. For 1D inputs,
      ``intersection`` is equivalent to ``ar2[ar2_indices]``.

  See also:
    - :func:`jax.numpy.union1d`: the set union of two 1D arrays.
    - :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays.
    - :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays.

  Examples:
    >>> ar1 = jnp.array([1, 2, 3, 4])
    >>> ar2 = jnp.array([3, 4, 5, 6])
    >>> jnp.intersect1d(ar1, ar2)
    Array([3, 4], dtype=int32)

    Computing intersection with indices:

    >>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
    >>> intersection
    Array([3, 4], dtype=int32)

    ``ar1_indices`` gives the indices of the intersected values within ``ar1``:

    >>> ar1_indices
    Array([2, 3], dtype=int32)
    >>> jnp.all(intersection == ar1[ar1_indices])
    Array(True, dtype=bool)

    ``ar2_indices`` gives the indices of the intersected values within ``ar2``:

    >>> ar2_indices
    Array([0, 1], dtype=int32)
    >>> jnp.all(intersection == ar2[ar2_indices])
    Array(True, dtype=bool)
  """
  check_arraylike("intersect1d", ar1, ar2)
  ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
  ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")

  if assume_unique:
    ar1 = ravel(ar1)
    ar2 = ravel(ar2)
  elif return_indices:
    # Keep the index of each unique value so results can be mapped back to the inputs.
    ar1, ind1 = unique(ar1, return_index=True)
    ar2, ind2 = unique(ar2, return_index=True)
  else:
    ar1 = unique(ar1)
    ar2 = unique(ar2)

  if not return_indices:
    aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)
    return aux[:-1][mask]

  aux, mask, sort_order = _intersect1d_sorted_mask(ar1, ar2, return_indices)
  common = aux[:-1][mask]
  # For each flagged adjacent pair, the first entry is taken as the ar1 position and
  # the second as the ar2 position, offset by len(ar1) in the concatenated array.
  ar1_indices = sort_order[:-1][mask]
  ar2_indices = sort_order[1:][mask] - np.size(ar1)
  if not assume_unique:
    # Translate positions within the unique arrays back to positions in the inputs.
    ar1_indices = ind1[ar1_indices]
    ar2_indices = ind2[ar2_indices]
  return common, ar1_indices, ar2_indices
def isin(element: ArrayLike, test_elements: ArrayLike,
         assume_unique: bool = False, invert: bool = False) -> Array:
  """Determine whether elements in ``element`` appear in ``test_elements``.

  JAX implementation of :func:`numpy.isin`.

  Args:
    element: input array of elements for which membership will be checked.
    test_elements: N-dimensional array of test values to check for the presence of
      each element.
    assume_unique: unused by JAX
    invert: If True, return ``~isin(element, test_elements)``. Default is False.

  Returns:
    A boolean array of shape ``element.shape`` that specifies whether each element
    appears in ``test_elements``.

  Examples:
    >>> elements = jnp.array([1, 2, 3, 4])
    >>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
    >>> jnp.isin(elements, test_elements)
    Array([ True, False,  True, False], dtype=bool)
  """
  del assume_unique  # unused
  check_arraylike("isin", element, test_elements)
  out_shape = np.shape(element)
  # _in1d works on flattened data; restore the shape of ``element`` afterwards.
  membership = _in1d(element, test_elements, invert=invert)
  return membership.reshape(out_shape)
### SetOps

# Hint appended to the concretization error messages built in unique() and _unique().
UNIQUE_SIZE_HINT = (
  "To make jnp.unique() compatible with JIT and other transforms, "
  "you can specify a concrete value for the size argument, "
  "which will determine the output size.")
@partial(jit, static_argnames=['axis', 'equal_nan'])
def _unique_sorted_mask(ar: Array, axis: int, equal_nan: bool) -> tuple[Array, Array, Array]:
aux = moveaxis(ar, axis, 0)
if np.issubdtype(aux.dtype, np.complexfloating):
# Work around issue in sorting of complex numbers with Nan only in the
# imaginary component. This can be removed if sorting in this situation
# is fixed to match numpy.
aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
size, *out_shape = aux.shape
if math.prod(out_shape) == 0:
size = 1
perm = zeros(1, dtype=int)
else:
perm = lexsort(aux.reshape(size, math.prod(out_shape)).T[::-1])
aux = aux[perm]
if aux.size:
if dtypes.issubdtype(aux.dtype, np.inexact) and equal_nan:
# This is appropriate for both float and complex due to the documented behavior of np.unique:
# See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
else:
neq = lax.ne
mask = ones(size, dtype=bool).at[1:].set(any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
else:
mask = zeros(size, dtype=bool)
return aux, mask, perm
def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bool = False,
            return_counts: bool = False, equal_nan: bool = True, size: int | None = None,
            fill_value: ArrayLike | None = None, return_true_size: bool = False
            ) -> Array | tuple[Array, ...]:
  """
  Find the unique elements of an array along a particular axis.

  Args:
    ar: array to process.
    axis: axis along which slices are compared; canonicalized against ``ar.ndim``.
    return_index: if True, also return the entries of the sort permutation that
      select the retained slices.
    return_inverse: if True, also return, for each position along ``axis``, its
      index into the unique values.
    return_counts: if True, also return the length of each run of equal slices.
    equal_nan: forwarded to ``_unique_sorted_mask``.
    size: if given, the number of entries requested from ``nonzero``, which fixes
      the output length. If None, the uniqueness mask must be a concrete value.
    fill_value: if given together with ``size``, the value written into output
      entries beyond the number of unique slices found.
    return_true_size: if True, append ``mask.sum()`` as the last output.

  Returns:
    The unique values alone when no optional output is requested; otherwise a
    tuple holding the unique values followed by the requested outputs in the
    order index, inverse, counts, true size.

  Raises:
    ValueError: if ``ar`` has length zero along ``axis``, ``size`` is nonzero and
      ``fill_value`` is None.
  """
  axis = canonicalize_axis(axis, ar.ndim)

  if ar.shape[axis] == 0 and size and fill_value is None:
    raise ValueError(
      "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")

  # aux: input with `axis` moved to the front and sorted; mask: True at the first
  # slice of each run of equal slices; perm: the sort permutation.
  aux, mask, perm = _unique_sorted_mask(ar, axis, equal_nan)
  if size is None:
    # The boolean mask itself is used as the index, so it must be concrete.
    ind = core.concrete_or_error(None, mask,
        "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
  else:
    # Fixed-length integer indices of the first `size` True entries of the mask.
    ind = nonzero(mask, size=size)[0]
  result = aux[ind] if aux.size else aux
  if size is not None and fill_value is not None:
    fill_value = asarray(fill_value, dtype=result.dtype)
    if result.shape[0]:
      # Overwrite every leading-axis entry past the number of unique slices.
      valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim)))
      result = where(valid, result, fill_value)
    else:
      # No leading entries to select from: build the padded output directly.
      result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
  # Undo the moveaxis performed in _unique_sorted_mask.
  result = moveaxis(result, 0, axis)

  ret: tuple[Array, ...] = (result,)
  if return_index:
    if aux.size:
      ret += (perm[ind],)
    else:
      ret += (perm,)
  if return_inverse:
    if aux.size:
      # cumsum(mask) - 1 numbers the runs in sorted order; scattering through perm
      # assigns each original position the number of the run it belongs to.
      imask = cumsum(mask) - 1
      inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(dtypes.int_))
      inv_idx = inv_idx.at[perm].set(imask)
    else:
      inv_idx = zeros(ar.shape[axis], dtype=int)
    if ar.ndim > 1:
      # Insert singleton dimensions for every axis other than `axis`.
      inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],)
    ret += (inv_idx,)
  if return_counts:
    if aux.size:
      # idx holds the start offset of each run followed by an end marker of
      # mask.size; the differences between consecutive offsets are the counts.
      if size is None:
        idx = append(nonzero(mask)[0], mask.size)
      else:
        idx = nonzero(mask, size=size + 1)[0]
        # Entries of idx[1:] equal to 0 (presumably nonzero()'s padding -- TODO
        # confirm its default fill) are replaced by mask.size.
        idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
      ret += (diff(idx),)
    elif ar.shape[axis]:
      # Empty slices but nonzero length along `axis`: a single count of that length.
      ret += (array([ar.shape[axis]], dtype=dtypes.canonicalize_dtype(dtypes.int_)),)
    else:
      ret += (empty(0, dtype=int),)
  if return_true_size:
    # Useful for internal uses of unique().
    ret += (mask.sum(),)
  return ret[0] if len(ret) == 1 else ret
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
           return_counts: bool = False, axis: int | None = None,
           *, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None):
  """Return the unique values from an array.

  JAX implementation of :func:`numpy.unique`.

  Because the size of the output of ``unique`` is data-dependent, the function
  semantics are not typically compatible with :func:`~jax.jit` and other JAX
  transformations. The JAX version adds the optional ``size`` argument which
  must be specified statically for ``jnp.unique`` to be used in such contexts.

  Args:
    ar: N-dimensional array from which unique values will be extracted.
    return_index: if True, also return the indices in ``ar`` where each value occurs
    return_inverse: if True, also return the indices that can be used to reconstruct
      ``ar`` from the unique values.
    return_counts: if True, also return the number of occurrences of each unique value.
    axis: if specified, compute unique values along the specified axis. If None (default),
      then flatten ``ar`` before computing the unique values.
    equal_nan: if True, consider NaN values equivalent when determining uniqueness.
    size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
      unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
    fill_value: when ``size`` is specified and there are fewer than the indicated number of
      elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.

  Returns:
    An array or tuple of arrays, depending on the values of ``return_index``, ``return_inverse``,
    and ``return_counts``. Returned values are

    - ``unique_values``:
      if ``axis`` is None, a 1D array of length ``n_unique``, If ``axis`` is
      specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``.
    - ``unique_index``:
      *(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains
      the indices of the first occurrence of each unique value in ``ar``. For 1D inputs,
      ``ar[unique_index]`` is equivalent to ``unique_values``.
    - ``unique_inverse``:
      *(returned only if return_inverse is True)* An array of shape ``ar.shape`` if ``axis``
      is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
      Contains the indices within ``unique_values`` of each value in ``ar``. For 1D inputs,
      ``unique_values[unique_inverse]`` is equivalent to ``ar``.
    - ``unique_counts``:
      *(returned only if return_counts is True)* An array of shape ``(n_unique,)``.
      Contains the number of occurrences of each unique value in ``ar``.

  See also:
    - :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``.
    - :func:`jax.numpy.unique_inverse`: shortcut to ``unique(arr, return_inverse=True)``.
    - :func:`jax.numpy.unique_all`: shortcut to ``unique`` with all return values.
    - :func:`jax.numpy.unique_values`: like ``unique``, but no optional return values.

  Examples:
    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> jnp.unique(x)
    Array([1, 3, 4], dtype=int32)

    **JIT compilation & the size argument**

    If you try this under :func:`~jax.jit` or another transformation, you will get an
    error because the output shape is dynamic:

    >>> jax.jit(jnp.unique)(x)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
       ...
    jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
    The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.

    The issue is that the output of transformed functions must have static shapes.
    In order to make this work, you can pass a static ``size`` parameter:

    >>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
    >>> jit_unique(x, size=3)
    Array([1, 3, 4], dtype=int32)

    If your static size is smaller than the true number of unique values, they will be truncated.

    >>> jit_unique(x, size=2)
    Array([1, 3], dtype=int32)

    If the static size is larger than the true number of unique values, they will be padded with
    ``fill_value``, which defaults to the minimum unique value:

    >>> jit_unique(x, size=5)
    Array([1, 3, 4, 1, 1], dtype=int32)
    >>> jit_unique(x, size=5, fill_value=0)
    Array([1, 3, 4, 0, 0], dtype=int32)

    **Multi-dimensional unique values**

    If you pass a multi-dimensional array to ``unique``, it will be flattened by default:

    >>> M = jnp.array([[1, 2],
    ...                [2, 3],
    ...                [1, 2]])
    >>> jnp.unique(M)
    Array([1, 2, 3], dtype=int32)

    If you pass an ``axis`` keyword, you can find unique *slices* of the array along
    that axis:

    >>> jnp.unique(M, axis=0)
    Array([[1, 2],
           [2, 3]], dtype=int32)

    **Returning indices**

    If you set ``return_index=True``, then ``unique`` returns the indices of the
    first occurrence of each unique value:

    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> values, indices = jnp.unique(x, return_index=True)
    >>> print(values)
    [1 3 4]
    >>> print(indices)
    [2 0 1]
    >>> jnp.all(values == x[indices])
    Array(True, dtype=bool)

    In multiple dimensions, the unique values can be extracted with :func:`jax.numpy.take`
    evaluated along the specified axis:

    >>> values, indices = jnp.unique(M, axis=0, return_index=True)
    >>> jnp.all(values == jnp.take(M, indices, axis=0))
    Array(True, dtype=bool)

    **Returning inverse**

    If you set ``return_inverse=True``, then ``unique`` returns the indices within the
    unique values for every entry in the input array:

    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> values, inverse = jnp.unique(x, return_inverse=True)
    >>> print(values)
    [1 3 4]
    >>> print(inverse)
    [1 2 0 1 0]
    >>> jnp.all(values[inverse] == x)
    Array(True, dtype=bool)

    In multiple dimensions, the input can be reconstructed using
    :func:`jax.numpy.take_along_axis`:

    >>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
    >>> jnp.all(jnp.take_along_axis(values, inverse, axis=0) == M)
    Array(True, dtype=bool)

    **Returning counts**

    If you set ``return_counts=True``, then ``unique`` returns the number of occurrences
    within the input for every unique value:

    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> values, counts = jnp.unique(x, return_counts=True)
    >>> print(values)
    [1 3 4]
    >>> print(counts)
    [2 2 1]

    For multi-dimensional arrays, this also returns a 1D array of counts
    indicating number of occurrences along the specified axis:

    >>> values, counts = jnp.unique(M, axis=0, return_counts=True)
    >>> print(values)
    [[1 2]
     [2 3]]
    >>> print(counts)
    [2 1]
  """
  check_arraylike("unique", ar)
  if size is not None:
    size = core.concrete_or_error(operator.index, size,
        "The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
  else:
    # Without a static size the output shape depends on the values in ``ar``.
    ar = core.concrete_or_error(None, ar,
        "The error arose for the first argument of jnp.unique(). " + UNIQUE_SIZE_HINT)

  arr = asarray(ar)
  original_shape = arr.shape
  flatten_input = axis is None
  if flatten_input:
    axis_int: int = 0
    arr = arr.flatten()
  else:
    axis_int = canonicalize_axis(axis, arr.ndim)

  result = _unique(arr, axis_int, return_index, return_inverse, return_counts,
                   equal_nan=equal_nan, size=size, fill_value=fill_value)
  if not (return_inverse and flatten_input):
    return result

  # The inverse was computed on the flattened input; give it the input's shape.
  # It sits after the values and, when requested, the index array.
  inverse_pos = 2 if return_index else 1
  leading = result[:inverse_pos]
  inverse = result[inverse_pos].reshape(original_shape)
  trailing = result[inverse_pos + 1:]
  return (*leading, inverse, *trailing)
class _UniqueAllResult(NamedTuple):
  """Struct returned by :func:`jax.numpy.unique_all`.

  Fields appear in the order ``(values, indices, inverse_indices, counts)``.
  """
  # Unique values found in the input.
  values: Array
  # Index output of ``unique(..., return_index=True)``.
  indices: Array
  # Inverse output of ``unique(..., return_inverse=True)``.
  inverse_indices: Array
  # Counts output of ``unique(..., return_counts=True)``.
  counts: Array
class _UniqueCountsResult(NamedTuple):
  """Struct returned by :func:`jax.numpy.unique_counts`.

  Fields appear in the order ``(values, counts)``.
  """
  # Unique values found in the input.
  values: Array
  # Counts output of ``unique(..., return_counts=True)``.
  counts: Array
class _UniqueInverseResult(NamedTuple):
  """Struct returned by :func:`jax.numpy.unique_inverse`.

  Fields appear in the order ``(values, inverse_indices)``.
  """
  # Unique values found in the input.
  values: Array
  # Indices into ``values``, one per input element -- presumably the inverse
  # output of ``unique(..., return_inverse=True)``; confirm against unique_inverse().
  inverse_indices: Array
def unique_all(x: ArrayLike, /, *, size: int | None = None,
               fill_value: ArrayLike | None = None) -> _UniqueAllResult:
  """Return unique values from x, along with indices, inverse indices, and counts.

  JAX implementation of :func:`numpy.unique_all`; this is equivalent to calling
  :func:`jax.numpy.unique` with ``return_index``, ``return_inverse`` and
  ``return_counts`` set to True, and ``equal_nan`` set to False.

  Because the size of the output of ``unique_all`` is data-dependent, the function
  semantics are not typically compatible with :func:`~jax.jit` and other JAX
  transformations. The JAX version adds the optional ``size`` argument which
  must be specified statically for ``jnp.unique_all`` to be used in such contexts.

  Args:
    x: N-dimensional array from which unique values will be extracted.
    size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
      unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
    fill_value: when ``size`` is specified and there are fewer than the indicated number of
      elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.

  Returns:
    A tuple ``(values, indices, inverse_indices, counts)``, with the following properties:

    - ``values``:
      an array of shape ``(n_unique,)`` containing the unique values from ``x``.
    - ``indices``:
      An array of shape ``(n_unique,)``. Contains the indices of the first occurrence of
      each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivalent to ``values``.
    - ``inverse_indices``:
      An array of shape ``x.shape``. Contains the indices within ``values`` of each value
      in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
    - ``counts``:
      An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
      value in ``x``.

  See also:
    - :func:`jax.numpy.unique`: general function for computing unique values.
    - :func:`jax.numpy.unique_values`: compute only ``values``.
    - :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``.
    - :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``.

  Examples:
    Here we compute the unique values in a 1D array:

    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> result = jnp.unique_all(x)

    The result is a :class:`~typing.NamedTuple` with four named attributes.
    The ``values`` attribute contains the unique values from the array:

    >>> result.values
    Array([1, 3, 4], dtype=int32)

    The ``indices`` attribute contains the indices of the unique ``values`` within
    the input array:

    >>> result.indices
    Array([2, 0, 1], dtype=int32)
    >>> jnp.all(result.values == x[result.indices])
    Array(True, dtype=bool)

    The ``inverse_indices`` attribute contains the indices of the input within ``values``:

    >>> result.inverse_indices
    Array([1, 2, 0, 1, 0], dtype=int32)
    >>> jnp.all(x == result.values[result.inverse_indices])
    Array(True, dtype=bool)

    The ``counts`` attribute contains the counts of each unique value in the input:

    >>> result.counts
    Array([2, 2, 1], dtype=int32)

    For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
  """
  check_arraylike("unique_all", x)
  # unique() yields (values, index, inverse, counts), matching the field order
  # of _UniqueAllResult.
  outputs = unique(
    x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False,
    size=size, fill_value=fill_value)
  return _UniqueAllResult(*outputs)
def unique_counts(x: ArrayLike, /, *, size: int | None = None,
                  fill_value: ArrayLike | None = None) -> _UniqueCountsResult:
  """Return unique values from x, along with counts.

  JAX implementation of :func:`numpy.unique_counts`; this is equivalent to calling
  :func:`jax.numpy.unique` with ``return_counts`` set to True and ``equal_nan``
  set to False.

  Because the size of the output of ``unique_counts`` is data-dependent, the function
  semantics are not typically compatible with :func:`~jax.jit` and other JAX
  transformations. The JAX version adds the optional ``size`` argument which
  must be specified statically for ``jnp.unique_counts`` to be used in such contexts.

  Args:
    x: N-dimensional array from which unique values will be extracted.
    size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
      unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
    fill_value: when ``size`` is specified and there are fewer than the indicated number of
      elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.

  Returns:
    A tuple ``(values, counts)``, with the following properties:

    - ``values``:
      an array of shape ``(n_unique,)`` containing the unique values from ``x``.
    - ``counts``:
      An array of shape ``(n_unique,)``. Contains the number of occurrences of each unique
      value in ``x``.

  See also:
    - :func:`jax.numpy.unique`: general function for computing unique values.
    - :func:`jax.numpy.unique_values`: compute only ``values``.
    - :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``.
    - :func:`jax.numpy.unique_all`: compute ``values``, ``indices``, ``inverse_indices``,
      and ``counts``.

  Examples:
    Here we compute the unique values in a 1D array:

    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> result = jnp.unique_counts(x)

    The result is a :class:`~typing.NamedTuple` with two named attributes.
    The ``values`` attribute contains the unique values from the array:

    >>> result.values
    Array([1, 3, 4], dtype=int32)

    The ``counts`` attribute contains the counts of each unique value in the input:

    >>> result.counts
    Array([2, 2, 1], dtype=int32)

    For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
  """
  check_arraylike("unique_counts", x)
  # unique() yields (values, counts), matching the field order of _UniqueCountsResult.
  outputs = unique(x, return_counts=True, equal_nan=False,
                   size=size, fill_value=fill_value)
  return _UniqueCountsResult(*outputs)
def unique_inverse(x: ArrayLike, /, *, size: int | None = None,
                   fill_value: ArrayLike | None = None) -> _UniqueInverseResult:
  """Return unique values from x, along with inverse indices.

  JAX implementation of :func:`numpy.unique_inverse`; this is equivalent to calling
  :func:`jax.numpy.unique` with ``return_inverse=True`` and ``equal_nan=False``.

  Because the size of the output of ``unique_inverse`` is data-dependent, the function
  semantics are not typically compatible with :func:`~jax.jit` and other JAX
  transformations. The JAX version adds the optional ``size`` argument which
  must be specified statically for ``jnp.unique_inverse`` to be used in such contexts.

  Args:
    x: N-dimensional array from which unique values will be extracted.
    size: if specified, return only the first ``size`` sorted unique elements. If there
      are fewer unique elements than ``size`` indicates, the return value will be padded
      with ``fill_value``.
    fill_value: when ``size`` is specified and there are fewer than the indicated number
      of elements, fill the remaining entries with ``fill_value``. Defaults to the
      minimum unique value.

  Returns:
    A tuple ``(values, inverse_indices)``, with the following properties:

    - ``values``:
        an array of shape ``(n_unique,)`` containing the unique values from ``x``.
    - ``inverse_indices``:
        An array of shape ``x.shape``. Contains the indices within ``values`` of each
        value in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.

  See also:
    - :func:`jax.numpy.unique`: general function for computing unique values.
    - :func:`jax.numpy.unique_values`: compute only ``values``.
    - :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``.
    - :func:`jax.numpy.unique_all`: compute ``values``, ``indices``, ``inverse_indices``,
      and ``counts``.

  Examples:
    Here we compute the unique values in a 1D array:

    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> result = jnp.unique_inverse(x)

    The result is a :class:`~typing.NamedTuple` with two named attributes.
    The ``values`` attribute contains the unique values from the array:

    >>> result.values
    Array([1, 3, 4], dtype=int32)

    The ``inverse_indices`` attribute contains the indices of the input within ``values``:

    >>> result.inverse_indices
    Array([1, 2, 0, 1, 0], dtype=int32)
    >>> jnp.all(x == result.values[result.inverse_indices])
    Array(True, dtype=bool)

    For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
  """
  check_arraylike("unique_inverse", x)
  # equal_nan is pinned to False here; it is not configurable by the caller.
  uniques, inverse = unique(x, size=size, fill_value=fill_value,
                            return_inverse=True, equal_nan=False)
  return _UniqueInverseResult(values=uniques, inverse_indices=inverse)
def unique_values(x: ArrayLike, /, *, size: int | None = None,
                  fill_value: ArrayLike | None = None) -> Array:
  """Return the unique values from x.

  JAX implementation of :func:`numpy.unique_values`; this is equivalent to calling
  :func:`jax.numpy.unique` with ``equal_nan=False``.

  Because the size of the output of ``unique_values`` is data-dependent, the function
  semantics are not typically compatible with :func:`~jax.jit` and other JAX
  transformations. The JAX version adds the optional ``size`` argument which
  must be specified statically for ``jnp.unique_values`` to be used in such contexts.

  Args:
    x: N-dimensional array from which unique values will be extracted.
    size: if specified, return only the first ``size`` sorted unique elements. If there
      are fewer unique elements than ``size`` indicates, the return value will be padded
      with ``fill_value``.
    fill_value: when ``size`` is specified and there are fewer than the indicated number
      of elements, fill the remaining entries with ``fill_value``. Defaults to the
      minimum unique value.

  Returns:
    An array ``values`` of shape ``(n_unique,)`` containing the unique values from ``x``.

  See also:
    - :func:`jax.numpy.unique`: general function for computing unique values.
    - :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``.
    - :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``.
    - :func:`jax.numpy.unique_all`: compute ``values``, ``indices``, ``inverse_indices``,
      and ``counts``.

  Examples:
    Here we compute the unique values in a 1D array:

    >>> x = jnp.array([3, 4, 1, 3, 1])
    >>> jnp.unique_values(x)
    Array([1, 3, 4], dtype=int32)

    For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
  """
  check_arraylike("unique_values", x)
  # No return_* flags are requested, so the result is narrowed to a single
  # Array for the type checker; equal_nan is pinned to False.
  uniques = unique(x, size=size, fill_value=fill_value, equal_nan=False)
  return cast(Array, uniques)