mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
lax_numpy: refactor set operations into separate private submodule
This commit is contained in:
parent
1ffa285bd6
commit
466bea1662
@ -953,176 +953,6 @@ def interp(x, xp, fp, left=None, right=None, period=None):
|
||||
return f
|
||||
|
||||
|
||||
@_wraps(np.in1d, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
@partial(jit, static_argnames=('assume_unique', 'invert',))
|
||||
def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811
|
||||
_check_arraylike("in1d", ar1, ar2)
|
||||
ar1 = ravel(ar1)
|
||||
ar2 = ravel(ar2)
|
||||
# Note: an algorithm based on searchsorted has better scaling, but in practice
|
||||
# is very slow on accelerators because it relies on lax control flow. If XLA
|
||||
# ever supports binary search natively, we should switch to this:
|
||||
# ar2 = jnp.sort(ar2)
|
||||
# ind = jnp.searchsorted(ar2, ar1)
|
||||
# if invert:
|
||||
# return ar1 != ar2[ind]
|
||||
# else:
|
||||
# return ar1 == ar2[ind]
|
||||
if invert:
|
||||
return (ar1[:, None] != ar2[None, :]).all(-1)
|
||||
else:
|
||||
return (ar1[:, None] == ar2[None, :]).any(-1)
|
||||
|
||||
@_wraps(np.setdiff1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.setdiff1d`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` elements of the result will be returned. If there are
|
||||
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
||||
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
|
||||
_check_arraylike("setdiff1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
|
||||
ar1 = asarray(ar1)
|
||||
fill_value = asarray(0 if fill_value is None else fill_value, dtype=ar1.dtype)
|
||||
if ar1.size == 0:
|
||||
return full_like(ar1, fill_value, shape=size or 0)
|
||||
if not assume_unique:
|
||||
ar1 = unique(ar1, size=size and ar1.size)
|
||||
mask = in1d(ar1, ar2, invert=True)
|
||||
if size is None:
|
||||
return ar1[mask]
|
||||
else:
|
||||
if not (assume_unique or size is None):
|
||||
# Set mask to zero at locations corresponding to unique() padding.
|
||||
n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
|
||||
mask = where(arange(ar1.size) < n_unique, mask, False)
|
||||
return where(arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)
|
||||
|
||||
|
||||
@_wraps(np.union1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``union1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.union1d`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` elements of the result will be returned. If there are
|
||||
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to the minimum
|
||||
value of the union."""))
|
||||
def union1d(ar1, ar2, *, size=None, fill_value=None):
|
||||
_check_arraylike("union1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in union1d()")
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size, "The error arose in union1d()")
|
||||
return unique(concatenate((ar1, ar2), axis=None), size=size, fill_value=fill_value)
|
||||
|
||||
|
||||
@_wraps(np.setxor1d, lax_description="""
|
||||
In the JAX version, the input arrays are explicitly flattened regardless
|
||||
of assume_unique value.
|
||||
""")
|
||||
def setxor1d(ar1, ar2, assume_unique=False):
|
||||
_check_arraylike("setxor1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
|
||||
|
||||
ar1 = ravel(ar1)
|
||||
ar2 = ravel(ar2)
|
||||
|
||||
if not assume_unique:
|
||||
ar1 = unique(ar1)
|
||||
ar2 = unique(ar2)
|
||||
|
||||
aux = concatenate((ar1, ar2))
|
||||
if aux.size == 0:
|
||||
return aux
|
||||
|
||||
aux = sort(aux)
|
||||
flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
|
||||
return aux[flag[1:] & flag[:-1]]
|
||||
|
||||
|
||||
@partial(jit, static_argnums=2)
|
||||
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
|
||||
"""
|
||||
Helper function for intersect1d which is jit-able
|
||||
"""
|
||||
ar = concatenate((ar1, ar2))
|
||||
if return_indices:
|
||||
iota = lax.broadcasted_iota(np.int64, shape(ar), dimension=0)
|
||||
aux, indices = lax.sort_key_val(ar, iota)
|
||||
else:
|
||||
aux = sort(ar)
|
||||
|
||||
mask = aux[1:] == aux[:-1]
|
||||
if return_indices:
|
||||
return aux, mask, indices
|
||||
else:
|
||||
return aux, mask
|
||||
|
||||
|
||||
@_wraps(np.intersect1d)
|
||||
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
||||
_check_arraylike("intersect1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
|
||||
|
||||
if not assume_unique:
|
||||
if return_indices:
|
||||
ar1, ind1 = unique(ar1, return_index=True)
|
||||
ar2, ind2 = unique(ar2, return_index=True)
|
||||
else:
|
||||
ar1 = unique(ar1)
|
||||
ar2 = unique(ar2)
|
||||
else:
|
||||
ar1 = ravel(ar1)
|
||||
ar2 = ravel(ar2)
|
||||
|
||||
if return_indices:
|
||||
aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices)
|
||||
else:
|
||||
aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)
|
||||
|
||||
int1d = aux[:-1][mask]
|
||||
|
||||
if return_indices:
|
||||
ar1_indices = aux_sort_indices[:-1][mask]
|
||||
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
|
||||
if not assume_unique:
|
||||
ar1_indices = ind1[ar1_indices]
|
||||
ar2_indices = ind2[ar2_indices]
|
||||
|
||||
return int1d, ar1_indices, ar2_indices
|
||||
else:
|
||||
return int1d
|
||||
|
||||
|
||||
@_wraps(np.isin, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
def isin(element, test_elements, assume_unique=False, invert=False): # noqa: F811
|
||||
result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert)
|
||||
return result.reshape(shape(element))
|
||||
|
||||
|
||||
@_wraps(np.where,
|
||||
lax_description=_dedent("""
|
||||
At present, JAX does not support JIT-compilation of the single-argument form
|
||||
@ -3602,116 +3432,6 @@ def take_along_axis(arr, indices, axis: Optional[int]):
|
||||
start_index_map=tuple(start_index_map))
|
||||
return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes))
|
||||
|
||||
|
||||
### SetOps
|
||||
@partial(jit, static_argnums=1)
|
||||
def _unique_sorted_mask(ar, axis):
|
||||
aux = moveaxis(ar, axis, 0)
|
||||
if issubdtype(aux.dtype, np.complexfloating):
|
||||
# Work around issue in sorting of complex numbers with Nan only in the
|
||||
# imaginary component. This can be removed if sorting in this situation
|
||||
# is fixed to match numpy.
|
||||
aux = where(isnan(aux), _lax_const(aux, nan), aux)
|
||||
size, *out_shape = aux.shape
|
||||
if _prod(out_shape) == 0:
|
||||
size = 1
|
||||
perm = zeros(1, dtype=int)
|
||||
else:
|
||||
perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
|
||||
aux = aux[perm]
|
||||
if aux.size:
|
||||
if issubdtype(aux.dtype, inexact):
|
||||
# This is appropriate for both float and complex due to the documented behavior of np.unique:
|
||||
# See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
|
||||
neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
|
||||
else:
|
||||
neq = lax.ne
|
||||
mask = ones(size, dtype=bool).at[1:].set(any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
|
||||
else:
|
||||
mask = zeros(size, dtype=bool)
|
||||
return aux, mask, perm
|
||||
|
||||
def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False,
|
||||
size=None, fill_value=None, return_true_size=False):
|
||||
"""
|
||||
Find the unique elements of an array along a particular axis.
|
||||
"""
|
||||
if ar.shape[axis] == 0 and size and fill_value is None:
|
||||
raise ValueError(
|
||||
"jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
|
||||
|
||||
aux, mask, perm = _unique_sorted_mask(ar, axis)
|
||||
ind = mask if size is None else nonzero(mask, size=size)[0]
|
||||
result = aux[ind] if aux.size else aux
|
||||
if fill_value is not None:
|
||||
fill_value = asarray(fill_value, dtype=result.dtype)
|
||||
if size is not None and fill_value is not None:
|
||||
if result.shape[0]:
|
||||
valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim)))
|
||||
result = where(valid, result, fill_value)
|
||||
else:
|
||||
result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
|
||||
result = moveaxis(result, 0, axis)
|
||||
|
||||
ret = (result,)
|
||||
if return_index:
|
||||
if aux.size:
|
||||
ret += (perm[ind],)
|
||||
else:
|
||||
ret += (perm,)
|
||||
if return_inverse:
|
||||
if aux.size:
|
||||
imask = cumsum(mask) - 1
|
||||
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
|
||||
inv_idx = inv_idx.at[perm].set(imask)
|
||||
else:
|
||||
inv_idx = zeros(ar.shape[axis], dtype=int)
|
||||
ret += (inv_idx,)
|
||||
if return_counts:
|
||||
if aux.size:
|
||||
if size is None:
|
||||
idx = append(nonzero(mask)[0], mask.size)
|
||||
else:
|
||||
idx = nonzero(mask, size=size + 1)[0]
|
||||
idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
|
||||
ret += (diff(idx),)
|
||||
elif ar.shape[axis]:
|
||||
ret += (array([ar.shape[axis]], dtype=dtypes.canonicalize_dtype(int_)),)
|
||||
else:
|
||||
ret += (empty(0, dtype=int),)
|
||||
if return_true_size:
|
||||
# Useful for internal uses of unique().
|
||||
ret += (mask.sum(),)
|
||||
return ret[0] if len(ret) == 1 else ret
|
||||
|
||||
@_wraps(np.unique, skip_params=['axis'],
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``unique`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` unique elements will be returned. If there are fewer unique
|
||||
elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``. The default is the minimum value
|
||||
along the specified axis of the input."""))
|
||||
def unique(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
|
||||
_check_arraylike("unique", ar)
|
||||
if size is None:
|
||||
ar = core.concrete_or_error(None, ar, "The error arose for the first argument of jnp.unique()")
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size, "The error arose for the size argument of jnp.unique()")
|
||||
ar = asarray(ar)
|
||||
if axis is None:
|
||||
axis = 0
|
||||
ar = ar.flatten()
|
||||
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
|
||||
return _unique(ar, axis, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
|
||||
|
||||
### Indexing
|
||||
|
||||
def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
|
||||
|
315
jax/_src/numpy/setops.py
Normal file
315
jax/_src/numpy/setops.py
Normal file
@ -0,0 +1,315 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
from textwrap import dedent as _dedent
|
||||
from typing import Optional
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
any, append, arange, array, asarray, concatenate, cumsum, diff,
|
||||
empty, full_like, isnan, lexsort, moveaxis, nonzero, ones, ravel,
|
||||
sort, where, zeros)
|
||||
from jax._src.numpy.util import _check_arraylike, _wraps
|
||||
from jax._src.util import prod as _prod
|
||||
from jax import core
|
||||
from jax import jit
|
||||
from jax import lax
|
||||
import numpy as np
|
||||
|
||||
|
||||
_lax_const = lax_internal._const
|
||||
|
||||
|
||||
@_wraps(np.in1d, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
@partial(jit, static_argnames=('assume_unique', 'invert',))
|
||||
def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811
|
||||
del assume_unique # unused
|
||||
_check_arraylike("in1d", ar1, ar2)
|
||||
ar1 = ravel(ar1)
|
||||
ar2 = ravel(ar2)
|
||||
# Note: an algorithm based on searchsorted has better scaling, but in practice
|
||||
# is very slow on accelerators because it relies on lax control flow. If XLA
|
||||
# ever supports binary search natively, we should switch to this:
|
||||
# ar2 = jnp.sort(ar2)
|
||||
# ind = jnp.searchsorted(ar2, ar1)
|
||||
# if invert:
|
||||
# return ar1 != ar2[ind]
|
||||
# else:
|
||||
# return ar1 == ar2[ind]
|
||||
if invert:
|
||||
return (ar1[:, None] != ar2[None, :]).all(-1)
|
||||
else:
|
||||
return (ar1[:, None] == ar2[None, :]).any(-1)
|
||||
|
||||
@_wraps(np.setdiff1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.setdiff1d`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` elements of the result will be returned. If there are
|
||||
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
||||
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
|
||||
_check_arraylike("setdiff1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
|
||||
ar1 = asarray(ar1)
|
||||
fill_value = asarray(0 if fill_value is None else fill_value, dtype=ar1.dtype)
|
||||
if ar1.size == 0:
|
||||
return full_like(ar1, fill_value, shape=size or 0)
|
||||
if not assume_unique:
|
||||
ar1 = unique(ar1, size=size and ar1.size)
|
||||
mask = in1d(ar1, ar2, invert=True)
|
||||
if size is None:
|
||||
return ar1[mask]
|
||||
else:
|
||||
if not (assume_unique or size is None):
|
||||
# Set mask to zero at locations corresponding to unique() padding.
|
||||
n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
|
||||
mask = where(arange(ar1.size) < n_unique, mask, False)
|
||||
return where(arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)
|
||||
|
||||
|
||||
@_wraps(np.union1d,
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``union1d`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.union1d`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` elements of the result will be returned. If there are
|
||||
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to the minimum
|
||||
value of the union."""))
|
||||
def union1d(ar1, ar2, *, size=None, fill_value=None):
|
||||
_check_arraylike("union1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in union1d()")
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size, "The error arose in union1d()")
|
||||
return unique(concatenate((ar1, ar2), axis=None), size=size, fill_value=fill_value)
|
||||
|
||||
|
||||
@_wraps(np.setxor1d, lax_description="""
|
||||
In the JAX version, the input arrays are explicitly flattened regardless
|
||||
of assume_unique value.
|
||||
""")
|
||||
def setxor1d(ar1, ar2, assume_unique=False):
|
||||
_check_arraylike("setxor1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
|
||||
|
||||
ar1 = ravel(ar1)
|
||||
ar2 = ravel(ar2)
|
||||
|
||||
if not assume_unique:
|
||||
ar1 = unique(ar1)
|
||||
ar2 = unique(ar2)
|
||||
|
||||
aux = concatenate((ar1, ar2))
|
||||
if aux.size == 0:
|
||||
return aux
|
||||
|
||||
aux = sort(aux)
|
||||
flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
|
||||
return aux[flag[1:] & flag[:-1]]
|
||||
|
||||
|
||||
@partial(jit, static_argnums=2)
|
||||
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
|
||||
"""
|
||||
Helper function for intersect1d which is jit-able
|
||||
"""
|
||||
ar = concatenate((ar1, ar2))
|
||||
if return_indices:
|
||||
iota = lax.broadcasted_iota(np.int64, np.shape(ar), dimension=0)
|
||||
aux, indices = lax.sort_key_val(ar, iota)
|
||||
else:
|
||||
aux = sort(ar)
|
||||
|
||||
mask = aux[1:] == aux[:-1]
|
||||
if return_indices:
|
||||
return aux, mask, indices
|
||||
else:
|
||||
return aux, mask
|
||||
|
||||
|
||||
@_wraps(np.intersect1d)
|
||||
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
||||
_check_arraylike("intersect1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
|
||||
|
||||
if not assume_unique:
|
||||
if return_indices:
|
||||
ar1, ind1 = unique(ar1, return_index=True)
|
||||
ar2, ind2 = unique(ar2, return_index=True)
|
||||
else:
|
||||
ar1 = unique(ar1)
|
||||
ar2 = unique(ar2)
|
||||
else:
|
||||
ar1 = ravel(ar1)
|
||||
ar2 = ravel(ar2)
|
||||
|
||||
if return_indices:
|
||||
aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices)
|
||||
else:
|
||||
aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)
|
||||
|
||||
int1d = aux[:-1][mask]
|
||||
|
||||
if return_indices:
|
||||
ar1_indices = aux_sort_indices[:-1][mask]
|
||||
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
|
||||
if not assume_unique:
|
||||
ar1_indices = ind1[ar1_indices]
|
||||
ar2_indices = ind2[ar2_indices]
|
||||
|
||||
return int1d, ar1_indices, ar2_indices
|
||||
else:
|
||||
return int1d
|
||||
|
||||
|
||||
@_wraps(np.isin, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
def isin(element, test_elements, assume_unique=False, invert=False): # noqa: F811
|
||||
result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert)
|
||||
return result.reshape(np.shape(element))
|
||||
|
||||
|
||||
### SetOps
|
||||
@partial(jit, static_argnums=1)
|
||||
def _unique_sorted_mask(ar, axis):
|
||||
aux = moveaxis(ar, axis, 0)
|
||||
if np.issubdtype(aux.dtype, np.complexfloating):
|
||||
# Work around issue in sorting of complex numbers with Nan only in the
|
||||
# imaginary component. This can be removed if sorting in this situation
|
||||
# is fixed to match numpy.
|
||||
aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
|
||||
size, *out_shape = aux.shape
|
||||
if _prod(out_shape) == 0:
|
||||
size = 1
|
||||
perm = zeros(1, dtype=int)
|
||||
else:
|
||||
perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
|
||||
aux = aux[perm]
|
||||
if aux.size:
|
||||
if dtypes.issubdtype(aux.dtype, np.inexact):
|
||||
# This is appropriate for both float and complex due to the documented behavior of np.unique:
|
||||
# See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
|
||||
neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
|
||||
else:
|
||||
neq = lax.ne
|
||||
mask = ones(size, dtype=bool).at[1:].set(any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
|
||||
else:
|
||||
mask = zeros(size, dtype=bool)
|
||||
return aux, mask, perm
|
||||
|
||||
def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False,
|
||||
size=None, fill_value=None, return_true_size=False):
|
||||
"""
|
||||
Find the unique elements of an array along a particular axis.
|
||||
"""
|
||||
if ar.shape[axis] == 0 and size and fill_value is None:
|
||||
raise ValueError(
|
||||
"jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified")
|
||||
|
||||
aux, mask, perm = _unique_sorted_mask(ar, axis)
|
||||
ind = mask if size is None else nonzero(mask, size=size)[0]
|
||||
result = aux[ind] if aux.size else aux
|
||||
if fill_value is not None:
|
||||
fill_value = asarray(fill_value, dtype=result.dtype)
|
||||
if size is not None and fill_value is not None:
|
||||
if result.shape[0]:
|
||||
valid = lax.expand_dims(arange(size) < mask.sum(), tuple(range(1, result.ndim)))
|
||||
result = where(valid, result, fill_value)
|
||||
else:
|
||||
result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
|
||||
result = moveaxis(result, 0, axis)
|
||||
|
||||
ret = (result,)
|
||||
if return_index:
|
||||
if aux.size:
|
||||
ret += (perm[ind],)
|
||||
else:
|
||||
ret += (perm,)
|
||||
if return_inverse:
|
||||
if aux.size:
|
||||
imask = cumsum(mask) - 1
|
||||
inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(dtypes.int_))
|
||||
inv_idx = inv_idx.at[perm].set(imask)
|
||||
else:
|
||||
inv_idx = zeros(ar.shape[axis], dtype=int)
|
||||
ret += (inv_idx,)
|
||||
if return_counts:
|
||||
if aux.size:
|
||||
if size is None:
|
||||
idx = append(nonzero(mask)[0], mask.size)
|
||||
else:
|
||||
idx = nonzero(mask, size=size + 1)[0]
|
||||
idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
|
||||
ret += (diff(idx),)
|
||||
elif ar.shape[axis]:
|
||||
ret += (array([ar.shape[axis]], dtype=dtypes.canonicalize_dtype(dtypes.int_)),)
|
||||
else:
|
||||
ret += (empty(0, dtype=int),)
|
||||
if return_true_size:
|
||||
# Useful for internal uses of unique().
|
||||
ret += (mask.sum(),)
|
||||
return ret[0] if len(ret) == 1 else ret
|
||||
|
||||
@_wraps(np.unique, skip_params=['axis'],
|
||||
lax_description=_dedent("""
|
||||
Because the size of the output of ``unique`` is data-dependent, the function is not
|
||||
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
|
||||
must be specified statically for ``jnp.unique`` to be used within some of JAX's
|
||||
transformations."""),
|
||||
extra_params=_dedent("""
|
||||
size : int, optional
|
||||
If specified, the first ``size`` unique elements will be returned. If there are fewer unique
|
||||
elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``. The default is the minimum value
|
||||
along the specified axis of the input."""))
|
||||
def unique(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
|
||||
_check_arraylike("unique", ar)
|
||||
if size is None:
|
||||
ar = core.concrete_or_error(None, ar, "The error arose for the first argument of jnp.unique()")
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size, "The error arose for the size argument of jnp.unique()")
|
||||
ar = asarray(ar)
|
||||
if axis is None:
|
||||
axis = 0
|
||||
ar = ar.flatten()
|
||||
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
|
||||
return _unique(ar, axis, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
|
@ -40,7 +40,7 @@ from jax._src.lax.lax import (
|
||||
DotDimensionNumbers)
|
||||
from jax._src.lib import cusparse
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy.lax_numpy import _unique
|
||||
from jax._src.numpy.setops import _unique
|
||||
|
||||
xops = xc._xla.ops
|
||||
|
||||
|
@ -142,7 +142,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
iinfo as iinfo,
|
||||
indices as indices,
|
||||
inexact as inexact,
|
||||
in1d as in1d,
|
||||
inf as inf,
|
||||
inner as inner,
|
||||
insert as insert,
|
||||
@ -153,11 +152,9 @@ from jax._src.numpy.lax_numpy import (
|
||||
int_ as int_,
|
||||
integer as integer,
|
||||
interp as interp,
|
||||
intersect1d as intersect1d,
|
||||
isclose as isclose,
|
||||
iscomplex as iscomplex,
|
||||
iscomplexobj as iscomplexobj,
|
||||
isin as isin,
|
||||
isreal as isreal,
|
||||
isrealobj as isrealobj,
|
||||
isscalar as isscalar,
|
||||
@ -234,8 +231,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
searchsorted as searchsorted,
|
||||
select as select,
|
||||
set_printoptions as set_printoptions,
|
||||
setdiff1d as setdiff1d,
|
||||
setxor1d as setxor1d,
|
||||
shape as shape,
|
||||
signedinteger as signedinteger,
|
||||
single as single,
|
||||
@ -270,8 +265,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
uint32 as uint32,
|
||||
uint64 as uint64,
|
||||
uint8 as uint8,
|
||||
unique as unique,
|
||||
union1d as union1d,
|
||||
unpackbits as unpackbits,
|
||||
unravel_index as unravel_index,
|
||||
unsignedinteger as unsignedinteger,
|
||||
@ -340,6 +333,16 @@ from jax._src.numpy.reductions import (
|
||||
var as var,
|
||||
)
|
||||
|
||||
from jax._src.numpy.setops import (
|
||||
in1d as in1d,
|
||||
intersect1d as intersect1d,
|
||||
isin as isin,
|
||||
setdiff1d as setdiff1d,
|
||||
setxor1d as setxor1d,
|
||||
union1d as union1d,
|
||||
unique as unique,
|
||||
)
|
||||
|
||||
from jax._src.numpy.ufuncs import (
|
||||
abs as abs,
|
||||
absolute as absolute,
|
||||
|
Loading…
x
Reference in New Issue
Block a user