mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Merge pull request #21215 from jakevdp:setops-docs
PiperOrigin-RevId: 633400796
This commit is contained in:
commit
e735a00cdc
@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
from textwrap import dedent as _dedent
|
||||
from typing import cast, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
@ -34,7 +33,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
sort, where, zeros)
|
||||
from jax._src.numpy.reductions import any, cumsum
|
||||
from jax._src.numpy.ufuncs import isnan
|
||||
from jax._src.numpy.util import check_arraylike, implements
|
||||
from jax._src.numpy.util import check_arraylike
|
||||
from jax._src.util import canonicalize_axis
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
@ -61,21 +60,73 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
|
||||
else:
|
||||
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)
|
||||
|
||||
@implements(np.setdiff1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.setdiff1d`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` elements of the result will be returned. If there are
|
||||
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
||||
|
||||
def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
|
||||
"""Compute the set difference of two 1D arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.setdiff1d`.
|
||||
|
||||
Because the size of the output of ``setdiff1d`` is data-dependent, the function
|
||||
semantics are not typically compatible with :func:`~jax.jit` and other JAX
|
||||
transformations. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.setdiff1d`` to be used in such contexts.
|
||||
transformations.
|
||||
|
||||
Args:
|
||||
ar1: first array of elements to be differenced.
|
||||
ar2: second array of elements to be differenced.
|
||||
assume_unique: if True, assume the input arrays contain unique values. This allows
|
||||
a more efficient implementation, but if ``assume_unique`` is True and the input
|
||||
arrays contain duplicates, the behavior is undefined. default: False.
|
||||
size: if specified, return only the first ``size`` sorted elements. If there are fewer
|
||||
elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum value.
|
||||
|
||||
Returns:
|
||||
an array containing the set difference of elements in the input array: i.e. the elements
|
||||
in ``ar1`` that are not contained in ``ar2``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays.
|
||||
- :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays.
|
||||
- :func:`jax.numpy.union1d`: the set union of two 1D arrays.
|
||||
|
||||
Examples:
|
||||
Computing the set difference of two arrays:
|
||||
|
||||
>>> ar1 = jnp.array([1, 2, 3, 4])
|
||||
>>> ar2 = jnp.array([3, 4, 5, 6])
|
||||
>>> jnp.setdiff1d(ar1, ar2)
|
||||
Array([1, 2], dtype=int32)
|
||||
|
||||
Because the output shape is dynamic, this will fail under :func:`~jax.jit` and other
|
||||
transformations:
|
||||
|
||||
>>> jax.jit(jnp.setdiff1d)(ar1, ar2) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
|
||||
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
|
||||
|
||||
In order to ensure statically-known output shapes, you can pass a static ``size``
|
||||
argument:
|
||||
|
||||
>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size'])
|
||||
>>> jit_setdiff1d(ar1, ar2, size=2)
|
||||
Array([1, 2], dtype=int32)
|
||||
|
||||
If ``size`` is too small, the difference is truncated:
|
||||
|
||||
>>> jit_setdiff1d(ar1, ar2, size=1)
|
||||
Array([1], dtype=int32)
|
||||
|
||||
If ``size`` is too large, then the output is padded with ``fill_value``:
|
||||
|
||||
>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0)
|
||||
Array([1, 2, 0, 0], dtype=int32)
|
||||
"""
|
||||
check_arraylike("setdiff1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
|
||||
@ -98,22 +149,68 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
|
||||
|
||||
|
||||
@implements(np.union1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``union1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.union1d`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` elements of the result will be returned. If there are
|
||||
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to the minimum
|
||||
value of the union."""))
|
||||
def union1d(ar1: ArrayLike, ar2: ArrayLike,
|
||||
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
|
||||
"""Compute the set union of two 1D arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.union1d`.
|
||||
|
||||
Because the size of the output of ``union1d`` is data-dependent, the function
|
||||
semantics are not typically compatible with :func:`~jax.jit` and other JAX
|
||||
transformations. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.union1d`` to be used in such contexts.
|
||||
transformations.
|
||||
|
||||
Args:
|
||||
ar1: first array of elements to be unioned.
|
||||
ar2: second array of elements to be unioned
|
||||
size: if specified, return only the first ``size`` sorted elements. If there are fewer
|
||||
elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum value.
|
||||
|
||||
Returns:
|
||||
an array containing the union of elements in the input array.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays.
|
||||
- :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays.
|
||||
- :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays.
|
||||
|
||||
Examples:
|
||||
Computing the union of two arrays:
|
||||
|
||||
>>> ar1 = jnp.array([1, 2, 3, 4])
|
||||
>>> ar2 = jnp.array([3, 4, 5, 6])
|
||||
>>> jnp.union1d(ar1, ar2)
|
||||
Array([1, 2, 3, 4, 5, 6], dtype=int32)
|
||||
|
||||
Because the output shape is dynamic, this will fail under :func:`~jax.jit` and other
|
||||
transformations:
|
||||
|
||||
>>> jax.jit(jnp.union1d)(ar1, ar2) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
|
||||
The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
|
||||
|
||||
In order to ensure statically-known output shapes, you can pass a static ``size``
|
||||
argument:
|
||||
|
||||
>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size'])
|
||||
>>> jit_union1d(ar1, ar2, size=6)
|
||||
Array([1, 2, 3, 4, 5, 6], dtype=int32)
|
||||
|
||||
If ``size`` is too small, the union is truncated:
|
||||
|
||||
>>> jit_union1d(ar1, ar2, size=4)
|
||||
Array([1, 2, 3, 4], dtype=int32)
|
||||
|
||||
If ``size`` is too large, then the output is padded with ``fill_value``:
|
||||
|
||||
>>> jit_union1d(ar1, ar2, size=8, fill_value=0)
|
||||
Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)
|
||||
"""
|
||||
check_arraylike("union1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
|
||||
@ -125,11 +222,35 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
|
||||
return cast(Array, out)
|
||||
|
||||
|
||||
@implements(np.setxor1d, lax_description="""
|
||||
In the JAX version, the input arrays are explicitly flattened regardless
|
||||
of assume_unique value.
|
||||
""")
|
||||
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array:
|
||||
"""Compute the set-wise xor of elements in two arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.setxor1d`.
|
||||
|
||||
Because the size of the output of ``setxor1d`` is data-dependent, the function is not
|
||||
compatible with JIT or other JAX transformations.
|
||||
|
||||
Args:
|
||||
ar1: first array of values to intersect.
|
||||
ar2: second array of values to intersect.
|
||||
assume_unique: if True, assume the input arrays contain unique values. This allows
|
||||
a more efficient implementation, but if ``assume_unique`` is True and the input
|
||||
arrays contain duplicates, the behavior is undefined. default: False.
|
||||
|
||||
Returns:
|
||||
An array of values that are found in exactly one of the input arrays.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.intersect1d`: the set intersection of two 1D arrays.
|
||||
- :func:`jax.numpy.union1d`: the set union of two 1D arrays.
|
||||
- :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays.
|
||||
|
||||
Examples:
|
||||
>>> ar1 = jnp.array([1, 2, 3, 4])
|
||||
>>> ar2 = jnp.array([3, 4, 5, 6])
|
||||
>>> jnp.setxor1d(ar1, ar2)
|
||||
Array([1, 2, 5, 6], dtype=int32)
|
||||
"""
|
||||
check_arraylike("setxor1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
|
||||
@ -152,9 +273,7 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr
|
||||
|
||||
@partial(jit, static_argnames=['return_indices'])
|
||||
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> tuple[Array, ...]:
|
||||
"""
|
||||
Helper function for intersect1d which is jit-able
|
||||
"""
|
||||
# JIT-compatible helper function for intersect1d
|
||||
ar = concatenate((ar1, ar2))
|
||||
if return_indices:
|
||||
iota = lax.broadcasted_iota(np.int64, np.shape(ar), dimension=0)
|
||||
@ -169,9 +288,70 @@ def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: boo
|
||||
return aux, mask
|
||||
|
||||
|
||||
@implements(np.intersect1d)
|
||||
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
return_indices: bool = False) -> Array | tuple[Array, Array, Array]:
|
||||
"""Compute the set intersection of two 1D arrays.
|
||||
|
||||
JAX implementation of :func:`numpy.intersect1d`.
|
||||
|
||||
Because the size of the output of ``intersect1d`` is data-dependent, the function is not
|
||||
compatible with JIT or other JAX transformations.
|
||||
|
||||
Args:
|
||||
ar1: first array of values to intersect.
|
||||
ar2: second array of values to intersect.
|
||||
assume_unique: if True, assume the input arrays contain unique values. This allows
|
||||
a more efficient implementation, but if ``assume_unique`` is True and the input
|
||||
arrays contain duplicates, the behavior is undefined. default: False.
|
||||
return_indices: If True, return arrays of indices specifying where the intersected
|
||||
values first appear in the input arrays.
|
||||
|
||||
Returns:
|
||||
An array ``intersection``, or if ``return_indices=True``, a tuple of arrays
|
||||
``(intersection, ar1_indices, ar2_indices)``. Returned values are
|
||||
|
||||
- ``intersection``:
|
||||
A 1D array containing each value that appears in both ``ar1`` and ``ar2``.
|
||||
- ``ar1_indices``:
|
||||
*(returned if return_indices=True)* an array of shape ``intersection.shape`` containing
|
||||
the indices in flattened ``ar1`` of values in ``intersection``. For 1D inputs,
|
||||
``intersection`` is equivalent to ``ar1[ar1_indices]``.
|
||||
- ``ar2_indices``:
|
||||
*(returned if return_indices=True)* an array of shape ``intersection.shape`` containing
|
||||
the indices in flattened ``ar2`` of values in ``intersection``. For 1D inputs,
|
||||
``intersection`` is equivalent to ``ar2[ar2_indices]``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.union1d`: the set union of two 1D arrays.
|
||||
- :func:`jax.numpy.setxor1d`: the set XOR of two 1D arrays.
|
||||
- :func:`jax.numpy.setdiff1d`: the set difference of two 1D arrays.
|
||||
|
||||
Examples:
|
||||
>>> ar1 = jnp.array([1, 2, 3, 4])
|
||||
>>> ar2 = jnp.array([3, 4, 5, 6])
|
||||
>>> jnp.intersect1d(ar1, ar2)
|
||||
Array([3, 4], dtype=int32)
|
||||
|
||||
Computing intersection with indices:
|
||||
|
||||
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
|
||||
>>> intersection
|
||||
Array([3, 4], dtype=int32)
|
||||
|
||||
``ar1_indices`` gives the indices of the intersected values within ``ar1``:
|
||||
|
||||
>>> ar1_indices
|
||||
Array([2, 3], dtype=int32)
|
||||
>>> jnp.all(intersection == ar1[ar1_indices])
|
||||
Array(True, dtype=bool)
|
||||
|
||||
``ar2_indices`` gives the indices of the intersected values within ``ar2``:
|
||||
|
||||
>>> ar2_indices
|
||||
Array([0, 1], dtype=int32)
|
||||
>>> jnp.all(intersection == ar2[ar2_indices])
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
check_arraylike("intersect1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
|
||||
@ -206,11 +386,29 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
return int1d
|
||||
|
||||
|
||||
@implements(np.isin, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
def isin(element: ArrayLike, test_elements: ArrayLike,
|
||||
assume_unique: bool = False, invert: bool = False) -> Array:
|
||||
"""Determine whether elements in ``element`` appear in ``test_elements``.
|
||||
|
||||
JAX implementation of :func:`numpy.isin`.
|
||||
|
||||
Args:
|
||||
element: input array of elements for which membership will be checked.
|
||||
test_elements: N-dimensional array of test values to check for the presence of
|
||||
each element.
|
||||
invert: If True, return ``~isin(element, test_elements)``. Default is False.
|
||||
assume_unique: unused by JAX
|
||||
|
||||
Returns:
|
||||
A boolean array of shape ``element.shape`` that specifies whether each element
|
||||
appears in ``test_elements``.
|
||||
|
||||
Examples:
|
||||
>>> elements = jnp.array([1, 2, 3, 4])
|
||||
>>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
|
||||
>>> jnp.isin(elements, test_elements)
|
||||
Array([ True, False, True, False], dtype=bool)
|
||||
"""
|
||||
del assume_unique # unused
|
||||
check_arraylike("isin", element, test_elements)
|
||||
result = _in1d(element, test_elements, invert=invert)
|
||||
@ -312,23 +510,176 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
||||
ret += (mask.sum(),)
|
||||
return ret[0] if len(ret) == 1 else ret
|
||||
|
||||
@implements(np.unique, skip_params=['axis'],
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``unique`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` unique elements will be returned. If there are fewer unique
|
||||
elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``. The default is the minimum value
|
||||
along the specified axis of the input."""))
|
||||
|
||||
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
|
||||
return_counts: bool = False, axis: int | None = None,
|
||||
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None):
|
||||
"""Return the unique values from an array.
|
||||
|
||||
JAX implementation of :func:`jax.numpy.unique`.
|
||||
|
||||
Because the size of the output of ``unique`` is data-dependent, the function
|
||||
semantics are not typically compatible with :func:`~jax.jit` and other JAX
|
||||
transformations. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used in such contexts.
|
||||
|
||||
Args:
|
||||
ar: N-dimensional array from which unique values will be extracted.
|
||||
return_index: if True, also return the indices in ``ar`` where each value occurs
|
||||
return_inverse: if True, also return the indices that can be used to reconstruct
|
||||
``ar`` from the unique values.
|
||||
return_counts: if True, also return the number of occurances of each unique value.
|
||||
axis: if specified, compute unique values along the specified axis. If None (default),
|
||||
then flatten ``ar`` before computing the unique values.
|
||||
equal_nan: if True, consider NaN values equivalent when determining uniqueness.
|
||||
size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
|
||||
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
|
||||
|
||||
Returns:
|
||||
An array or tuple of arrays, depending on the values of ``return_index``, ``return_inverse``,
|
||||
and ``return_counts``. Returned values are
|
||||
|
||||
- ``unique_values``:
|
||||
if ``axis`` is None, a 1D array of length ``n_unique``, If ``axis`` is
|
||||
specified, shape is ``(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])``.
|
||||
- ``unique_index``:
|
||||
*(returned only if return_index is True)* An array of shape ``(n_unique,)``. Contains
|
||||
the indices of the first occurance of each unique value in ``ar``. For 1D inputs,
|
||||
``ar[unique_index]`` is equivlent to ``unique_values``.
|
||||
- ``unique_inverse``:
|
||||
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
|
||||
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
|
||||
Contains the indices within ``unique_values`` of each value in ``ar``. For 1D inputs,
|
||||
``unique_values[unique_inverse]`` is equivalent to ``ar``.
|
||||
- ``unique_counts``:
|
||||
*(returned only if return_counts is True)* An array of shape ``(n_unique,)``.
|
||||
Contains the number of occurances of each unique value in ``ar``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.unique_counts`: shortcut to ``unique(arr, return_counts=True)``.
|
||||
- :func:`jax.numpy.unique_inverse`: shortcut to ``unique(arr, return_inverse=True)``.
|
||||
- :func:`jax.numpy.unique_all`: shortcut to ``unique`` with all return values.
|
||||
- :func:`jax.numpy.unique_values`: like ``unique``, but no optional return values.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> jnp.unique(x)
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
|
||||
**JIT compilation & the size argument**
|
||||
|
||||
If you try this under :func:`~jax.jit` or another transformation, you will get an
|
||||
error because the output shape is dynamic:
|
||||
|
||||
>>> jax.jit(jnp.unique)(x) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
|
||||
The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.
|
||||
|
||||
The issue is that the output of transformed functions must have static shapes.
|
||||
In order to make this work, you can pass a static ``size`` parameter:
|
||||
|
||||
>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
|
||||
>>> jit_unique(x, size=3)
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
|
||||
If your static size is smaller than the true number of unique values, they will be truncated.
|
||||
|
||||
>>> jit_unique(x, size=2)
|
||||
Array([1, 3], dtype=int32)
|
||||
|
||||
If the static size is larger than the true number of unique values, they will be padded with
|
||||
``fill_value``, which defaults to the minimum unique value:
|
||||
|
||||
>>> jit_unique(x, size=5)
|
||||
Array([1, 3, 4, 1, 1], dtype=int32)
|
||||
>>> jit_unique(x, size=5, fill_value=0)
|
||||
Array([1, 3, 4, 0, 0], dtype=int32)
|
||||
|
||||
**Multi-dimensional unique values**
|
||||
|
||||
If you pass a multi-dimensional array to ``unique``, it will be flattened by default:
|
||||
|
||||
>>> M = jnp.array([[1, 2],
|
||||
... [2, 3],
|
||||
... [1, 2]])
|
||||
>>> jnp.unique(M)
|
||||
Array([1, 2, 3], dtype=int32)
|
||||
|
||||
If you pass an ``axis`` keyword, you can find unique *slices* of the array along
|
||||
that axis:
|
||||
|
||||
>>> jnp.unique(M, axis=0)
|
||||
Array([[1, 2],
|
||||
[2, 3]], dtype=int32)
|
||||
|
||||
**Returning indices**
|
||||
|
||||
If you set ``return_index=True``, then ``unique`` returns the indices of the
|
||||
first occurance of each unique value:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> values, indices = jnp.unique(x, return_index=True)
|
||||
>>> print(values)
|
||||
[1 3 4]
|
||||
>>> print(indices)
|
||||
[2 0 1]
|
||||
>>> jnp.all(values == x[indices])
|
||||
Array(True, dtype=bool)
|
||||
|
||||
In multiple dimensions, the unique values can be extracted with :func:`jax.numpy.take`
|
||||
evaluated along the specified axis:
|
||||
|
||||
>>> values, indices = jnp.unique(M, axis=0, return_index=True)
|
||||
>>> jnp.all(values == jnp.take(M, indices, axis=0))
|
||||
Array(True, dtype=bool)
|
||||
|
||||
**Returning inverse**
|
||||
|
||||
If you set ``return_inverse=True``, then ``unique`` returns the indices within the
|
||||
unique values for every entry in the input array:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> values, inverse = jnp.unique(x, return_inverse=True)
|
||||
>>> print(values)
|
||||
[1 3 4]
|
||||
>>> print(inverse)
|
||||
[1 2 0 1 0]
|
||||
>>> jnp.all(values[inverse] == x)
|
||||
Array(True, dtype=bool)
|
||||
|
||||
In multiple dimensions, the input can be reconstructed using
|
||||
:func:`jax.numpy.take_along_axis`:
|
||||
|
||||
>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
|
||||
>>> jnp.all(jnp.take_along_axis(values, inverse, axis=0) == M)
|
||||
Array(True, dtype=bool)
|
||||
|
||||
**Returning counts**
|
||||
|
||||
If you set ``return_counts=True``, then ``unique`` returns the number of occurances
|
||||
within the input for every unique value:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> values, counts = jnp.unique(x, return_counts=True)
|
||||
>>> print(values)
|
||||
[1 3 4]
|
||||
>>> print(counts)
|
||||
[2 2 1]
|
||||
|
||||
For multi-dimensional arrays, this also returns a 1D array of counts
|
||||
indicating number of occurances along the specified axis:
|
||||
|
||||
>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
|
||||
>>> print(values)
|
||||
[[1 2]
|
||||
[2 3]]
|
||||
>>> print(counts)
|
||||
[2 1]
|
||||
"""
|
||||
check_arraylike("unique", ar)
|
||||
if size is None:
|
||||
ar = core.concrete_or_error(None, ar,
|
||||
@ -352,6 +703,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
|
||||
|
||||
class _UniqueAllResult(NamedTuple):
|
||||
"""Struct returned by :func:`jax.numpy.unique_all`."""
|
||||
values: Array
|
||||
indices: Array
|
||||
inverse_indices: Array
|
||||
@ -359,38 +711,260 @@ class _UniqueAllResult(NamedTuple):
|
||||
|
||||
|
||||
class _UniqueCountsResult(NamedTuple):
|
||||
"""Struct returned by :func:`jax.numpy.unique_counts`."""
|
||||
values: Array
|
||||
counts: Array
|
||||
|
||||
|
||||
class _UniqueInverseResult(NamedTuple):
|
||||
"""Struct returned by :func:`jax.numpy.unique_inverse`."""
|
||||
values: Array
|
||||
inverse_indices: Array
|
||||
|
||||
|
||||
@implements(getattr(np, "unique_all", None))
|
||||
def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
|
||||
def unique_all(x: ArrayLike, /, *, size: int | None = None,
|
||||
fill_value: ArrayLike | None = None) -> _UniqueAllResult:
|
||||
"""Return unique values from x, along with indices, inverse indices, and counts.
|
||||
|
||||
JAX implementation of :func:`numpy.unique_all`; this is equivalent to calling
|
||||
:func:`jax.numpy.unique` with `return_index`, `return_inverse`, `return_counts`,
|
||||
and `equal_nan` set to True.
|
||||
|
||||
Because the size of the output of ``unique_all`` is data-dependent, the function
|
||||
semantics are not typically compatible with :func:`~jax.jit` and other JAX
|
||||
transformations. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used in such contexts.
|
||||
|
||||
Args:
|
||||
x: N-dimensional array from which unique values will be extracted.
|
||||
size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
|
||||
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
|
||||
|
||||
Returns:
|
||||
A tuple ``(values, indices, inverse_indices, counts)``, with the following properties:
|
||||
|
||||
- ``values``:
|
||||
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||
- ``indices``:
|
||||
An array of shape ``(n_unique,)``. Contains the indices of the first occurance of
|
||||
each unique value in ``x``. For 1D inputs, ``x[indices]`` is equivlent to ``values``.
|
||||
- ``inverse_indices``:
|
||||
An array of shape ``x.shape``. Contains the indices within ``values`` of each value
|
||||
in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
|
||||
- ``counts``:
|
||||
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique
|
||||
value in ``x``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.unique`: general function for computing unique values.
|
||||
- :func:`jax.numpy.unique_values`: compute only ``values``.
|
||||
- :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``.
|
||||
- :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``.
|
||||
|
||||
Examples:
|
||||
Here we compute the unique values in a 1D array:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> result = jnp.unique_all(x)
|
||||
|
||||
The result is a :class:`~typing.NamedTuple` with four named attributes.
|
||||
The ``values`` attribue contains the unique values from the array:
|
||||
|
||||
>>> result.values
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
|
||||
The ``indices`` attribute contains the indices of the unique ``values`` within
|
||||
the input array:
|
||||
|
||||
>>> result.indices
|
||||
Array([2, 0, 1], dtype=int32)
|
||||
>>> jnp.all(result.values == x[result.indices])
|
||||
Array(True, dtype=bool)
|
||||
|
||||
The ``inverse_indices`` attribute contains the indices of the input within ``values``:
|
||||
|
||||
>>> result.inverse_indices
|
||||
Array([1, 2, 0, 1, 0], dtype=int32)
|
||||
>>> jnp.all(x == result.values[result.inverse_indices])
|
||||
Array(True, dtype=bool)
|
||||
|
||||
The ``counts`` attribute contains the counts of each unique value in the input:
|
||||
|
||||
>>> result.counts
|
||||
Array([2, 2, 1], dtype=int32)
|
||||
|
||||
For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
|
||||
"""
|
||||
check_arraylike("unique_all", x)
|
||||
values, indices, inverse_indices, counts = unique(
|
||||
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False)
|
||||
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False,
|
||||
size=size, fill_value=fill_value)
|
||||
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
|
||||
|
||||
|
||||
@implements(getattr(np, "unique_counts", None))
|
||||
def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
|
||||
def unique_counts(x: ArrayLike, /, *, size: int | None = None,
|
||||
fill_value: ArrayLike | None = None) -> _UniqueCountsResult:
|
||||
"""Return unique values from x, along with counts.
|
||||
|
||||
JAX implementation of :func:`numpy.unique_counts`; this is equivalent to calling
|
||||
:func:`jax.numpy.unique` with `return_counts` and `equal_nan` set to True.
|
||||
|
||||
Because the size of the output of ``unique_counts`` is data-dependent, the function
|
||||
semantics are not typically compatible with :func:`~jax.jit` and other JAX
|
||||
transformations. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used in such contexts.
|
||||
|
||||
Args:
|
||||
x: N-dimensional array from which unique values will be extracted.
|
||||
size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
|
||||
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
|
||||
|
||||
Returns:
|
||||
A tuple ``(values, counts)``, with the following properties:
|
||||
|
||||
- ``values``:
|
||||
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||
- ``counts``:
|
||||
An array of shape ``(n_unique,)``. Contains the number of occurances of each unique
|
||||
value in ``x``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.unique`: general function for computing unique values.
|
||||
- :func:`jax.numpy.unique_values`: compute only ``values``.
|
||||
- :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``.
|
||||
- :func:`jax.numpy.unique_all`: compute ``values``, ``indices``, ``inverse_indices``,
|
||||
and ``counts``.
|
||||
|
||||
Examples:
|
||||
Here we compute the unique values in a 1D array:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> result = jnp.unique_counts(x)
|
||||
|
||||
The result is a :class:`~typing.NamedTuple` with two named attributes.
|
||||
The ``values`` attribue contains the unique values from the array:
|
||||
|
||||
>>> result.values
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
|
||||
The ``counts`` attribute contains the counts of each unique value in the input:
|
||||
|
||||
>>> result.counts
|
||||
Array([2, 2, 1], dtype=int32)
|
||||
|
||||
For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
|
||||
"""
|
||||
check_arraylike("unique_counts", x)
|
||||
values, counts = unique(x, return_counts=True, equal_nan=False)
|
||||
values, counts = unique(x, return_counts=True, equal_nan=False,
|
||||
size=size, fill_value=fill_value)
|
||||
return _UniqueCountsResult(values=values, counts=counts)
|
||||
|
||||
|
||||
@implements(getattr(np, "unique_inverse", None))
|
||||
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
|
||||
def unique_inverse(x: ArrayLike, /, *, size: int | None = None,
|
||||
fill_value: ArrayLike | None = None) -> _UniqueInverseResult:
|
||||
"""Return unique values from x, along with indices, inverse indices, and counts.
|
||||
|
||||
JAX implementation of :func:`numpy.unique_inverse`; this is equivalent to calling
|
||||
:func:`jax.numpy.unique` with `return_inverse` and `equal_nan` set to True.
|
||||
|
||||
Because the size of the output of ``unique_inverse`` is data-dependent, the function
|
||||
semantics are not typically compatible with :func:`~jax.jit` and other JAX
|
||||
transformations. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used in such contexts.
|
||||
|
||||
Args:
|
||||
x: N-dimensional array from which unique values will be extracted.
|
||||
size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
|
||||
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
|
||||
|
||||
Returns:
|
||||
A tuple ``(values, indices, inverse_indices, counts)``, with the following properties:
|
||||
|
||||
- ``values``:
|
||||
an array of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||
- ``inverse_indices``:
|
||||
An array of shape ``x.shape``. Contains the indices within ``values`` of each value
|
||||
in ``x``. For 1D inputs, ``values[inverse_indices]`` is equivalent to ``x``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.unique`: general function for computing unique values.
|
||||
- :func:`jax.numpy.unique_values`: compute only ``values``.
|
||||
- :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``.
|
||||
- :func:`jax.numpy.unique_all`: compute ``values``, ``indices``, ``inverse_indices``,
|
||||
and ``counts``.
|
||||
|
||||
Examples:
|
||||
Here we compute the unique values in a 1D array:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> result = jnp.unique_inverse(x)
|
||||
|
||||
The result is a :class:`~typing.NamedTuple` with two named attributes.
|
||||
The ``values`` attribue contains the unique values from the array:
|
||||
|
||||
>>> result.values
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
|
||||
The ``indices`` attribute contains the indices of the unique ``values`` within
|
||||
the input array:
|
||||
|
||||
The ``inverse_indices`` attribute contains the indices of the input within ``values``:
|
||||
|
||||
>>> result.inverse_indices
|
||||
Array([1, 2, 0, 1, 0], dtype=int32)
|
||||
>>> jnp.all(x == result.values[result.inverse_indices])
|
||||
Array(True, dtype=bool)
|
||||
|
||||
For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
|
||||
"""
|
||||
check_arraylike("unique_inverse", x)
|
||||
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
|
||||
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False,
|
||||
size=size, fill_value=fill_value)
|
||||
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)
|
||||
|
||||
|
||||
@implements(getattr(np, "unique_values", None))
|
||||
def unique_values(x: ArrayLike, /) -> Array:
|
||||
def unique_values(x: ArrayLike, /, *, size: int | None = None,
|
||||
fill_value: ArrayLike | None = None) -> Array:
|
||||
"""Return unique values from x, along with indices, inverse indices, and counts.
|
||||
|
||||
JAX implementation of :func:`numpy.unique_values`; this is equivalent to calling
|
||||
:func:`jax.numpy.unique` with `equal_nan` set to True.
|
||||
|
||||
Because the size of the output of ``unique_values`` is data-dependent, the function
|
||||
semantics are not typically compatible with :func:`~jax.jit` and other JAX
|
||||
transformations. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used in such contexts.
|
||||
|
||||
Args:
|
||||
x: N-dimensional array from which unique values will be extracted.
|
||||
size: if specified, return only the first ``size`` sorted unique elements. If there are fewer
|
||||
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
|
||||
|
||||
Returns:
|
||||
An array ``values`` of shape ``(n_unique,)`` containing the unique values from ``x``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.unique`: general function for computing unique values.
|
||||
- :func:`jax.numpy.unique_values`: compute only ``values``.
|
||||
- :func:`jax.numpy.unique_counts`: compute only ``values`` and ``counts``.
|
||||
- :func:`jax.numpy.unique_inverse`: compute only ``values`` and ``inverse``.
|
||||
|
||||
Examples:
|
||||
Here we compute the unique values in a 1D array:
|
||||
|
||||
>>> x = jnp.array([3, 4, 1, 3, 1])
|
||||
>>> jnp.unique_values(x)
|
||||
Array([1, 3, 4], dtype=int32)
|
||||
|
||||
For examples of the ``size`` and ``fill_value`` arguments, see :func:`jax.numpy.unique`.
|
||||
"""
|
||||
check_arraylike("unique_values", x)
|
||||
return cast(Array, unique(x, equal_nan=False))
|
||||
return cast(Array, unique(x, equal_nan=False, size=size, fill_value=fill_value))
|
||||
|
@ -870,10 +870,14 @@ def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: bui
|
||||
*, equal_nan: builtins.bool = ..., size: Optional[int] = ...,
|
||||
fill_value: Optional[ArrayLike] = ...
|
||||
): ...
|
||||
def unique_all(x: ArrayLike, /) -> _UniqueAllResult: ...
|
||||
def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: ...
|
||||
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: ...
|
||||
def unique_values(x: ArrayLike, /) -> Array: ...
|
||||
def unique_all(x: ArrayLike, /, *, size: Optional[int] = ...,
|
||||
fill_value: Optional[ArrayLike] = ...) -> _UniqueAllResult: ...
|
||||
def unique_counts(x: ArrayLike, /, *, size: Optional[int] = ...,
|
||||
fill_value: Optional[ArrayLike] = ...) -> _UniqueCountsResult: ...
|
||||
def unique_inverse(x: ArrayLike, /, *, size: Optional[int] = ...,
|
||||
fill_value: Optional[ArrayLike] = ...) -> _UniqueInverseResult: ...
|
||||
def unique_values(x: ArrayLike, /, *, size: Optional[int] = ...,
|
||||
fill_value: Optional[ArrayLike] = ...) -> Array: ...
|
||||
def unpackbits(
|
||||
a: ArrayLike,
|
||||
axis: Optional[int] = ...,
|
||||
|
Loading…
x
Reference in New Issue
Block a user