mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[typing] add full annotations for lax_numpy setops
This commit is contained in:
parent
be56a3559d
commit
fd45035b21
@ -1653,7 +1653,7 @@ def tile(A, reps):
|
||||
[k for pair in zip(reps, A_shape) for k in pair])
|
||||
return reshape(result, tuple(np.multiply(A_shape, reps)))
|
||||
|
||||
def _concatenate_array(arr, axis: int, dtype=None):
|
||||
def _concatenate_array(arr, axis: Optional[int], dtype=None):
|
||||
# Fast path for concatenation when the input is an ndarray rather than a list.
|
||||
arr = asarray(arr, dtype=dtype)
|
||||
if arr.ndim == 0 or arr.shape[0] == 0:
|
||||
@ -1668,7 +1668,7 @@ def _concatenate_array(arr, axis: int, dtype=None):
|
||||
return lax.reshape(arr, shape, dimensions)
|
||||
|
||||
@_wraps(np.concatenate)
|
||||
def concatenate(arrays, axis: int = 0, dtype=None):
|
||||
def concatenate(arrays, axis: Optional[int] = 0, dtype=None):
|
||||
if isinstance(arrays, (np.ndarray, ndarray)):
|
||||
return _concatenate_array(arrays, axis, dtype=dtype)
|
||||
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
|
||||
|
@ -15,7 +15,7 @@
|
||||
from functools import partial
|
||||
import operator
|
||||
from textwrap import dedent as _dedent
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
@ -24,6 +24,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
empty, full_like, isnan, lexsort, moveaxis, nonzero, ones, ravel,
|
||||
sort, where, zeros)
|
||||
from jax._src.numpy.util import _check_arraylike, _wraps
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import prod as _prod
|
||||
from jax import core
|
||||
from jax import jit
|
||||
@ -37,25 +38,28 @@ _lax_const = lax_internal._const
|
||||
@_wraps(np.in1d, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
@partial(jit, static_argnames=('assume_unique', 'invert',))
|
||||
def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811
|
||||
def in1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, invert: bool = False) -> Array:
|
||||
del assume_unique # unused
|
||||
return _in1d(ar1, ar2, invert)
|
||||
|
||||
@partial(jit, static_argnames=('invert',))
|
||||
def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
|
||||
_check_arraylike("in1d", ar1, ar2)
|
||||
ar1 = ravel(ar1)
|
||||
ar2 = ravel(ar2)
|
||||
ar1_flat = ravel(ar1)
|
||||
ar2_flat = ravel(ar2)
|
||||
# Note: an algorithm based on searchsorted has better scaling, but in practice
|
||||
# is very slow on accelerators because it relies on lax control flow. If XLA
|
||||
# ever supports binary search natively, we should switch to this:
|
||||
# ar2 = jnp.sort(ar2)
|
||||
# ind = jnp.searchsorted(ar2, ar1)
|
||||
# ar2_flat = jnp.sort(ar2_flat)
|
||||
# ind = jnp.searchsorted(ar2_flat, ar1_flat)
|
||||
# if invert:
|
||||
# return ar1 != ar2[ind]
|
||||
# return ar1_flat != ar2_flat[ind]
|
||||
# else:
|
||||
# return ar1 == ar2[ind]
|
||||
# return ar1_flat == ar2_flat[ind]
|
||||
if invert:
|
||||
return (ar1[:, None] != ar2[None, :]).all(-1)
|
||||
return (ar1_flat[:, None] != ar2_flat[None, :]).all(-1)
|
||||
else:
|
||||
return (ar1[:, None] == ar2[None, :]).any(-1)
|
||||
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)
|
||||
|
||||
@_wraps(np.setdiff1d,
|
||||
lax_description=_dedent("""
|
||||
@ -70,27 +74,28 @@ def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811
|
||||
fill_value : array_like, optional
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
|
||||
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
|
||||
def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array:
|
||||
_check_arraylike("setdiff1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
|
||||
ar1 = asarray(ar1)
|
||||
fill_value = asarray(0 if fill_value is None else fill_value, dtype=ar1.dtype)
|
||||
if ar1.size == 0:
|
||||
return full_like(ar1, fill_value, shape=size or 0)
|
||||
arr1 = asarray(ar1)
|
||||
fill_value = asarray(0 if fill_value is None else fill_value, dtype=arr1.dtype)
|
||||
if arr1.size == 0:
|
||||
return full_like(arr1, fill_value, shape=size or 0)
|
||||
if not assume_unique:
|
||||
ar1 = unique(ar1, size=size and ar1.size)
|
||||
mask = in1d(ar1, ar2, invert=True)
|
||||
arr1 = unique(arr1, size=size and arr1.size)
|
||||
mask = in1d(arr1, ar2, invert=True)
|
||||
if size is None:
|
||||
return ar1[mask]
|
||||
return arr1[mask]
|
||||
else:
|
||||
if not (assume_unique or size is None):
|
||||
# Set mask to zero at locations corresponding to unique() padding.
|
||||
n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
|
||||
mask = where(arange(ar1.size) < n_unique, mask, False)
|
||||
return where(arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)
|
||||
n_unique = arr1.size + 1 - (arr1 == arr1[0]).sum()
|
||||
mask = where(arange(arr1.size) < n_unique, mask, False)
|
||||
return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
|
||||
|
||||
|
||||
@_wraps(np.union1d,
|
||||
@ -107,7 +112,8 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``, which defaults to the minimum
|
||||
value of the union."""))
|
||||
def union1d(ar1, ar2, *, size=None, fill_value=None):
|
||||
def union1d(ar1: ArrayLike, ar2: ArrayLike,
|
||||
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array:
|
||||
_check_arraylike("union1d", ar1, ar2)
|
||||
if size is None:
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
|
||||
@ -121,7 +127,7 @@ def union1d(ar1, ar2, *, size=None, fill_value=None):
|
||||
In the JAX version, the input arrays are explicitly flattened regardless
|
||||
of assume_unique value.
|
||||
""")
|
||||
def setxor1d(ar1, ar2, assume_unique=False):
|
||||
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array:
|
||||
_check_arraylike("setxor1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
|
||||
@ -142,8 +148,8 @@ def setxor1d(ar1, ar2, assume_unique=False):
|
||||
return aux[flag[1:] & flag[:-1]]
|
||||
|
||||
|
||||
@partial(jit, static_argnums=2)
|
||||
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
|
||||
@partial(jit, static_argnames=['return_indices'])
|
||||
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> Tuple[Array, ...]:
|
||||
"""
|
||||
Helper function for intersect1d which is jit-able
|
||||
"""
|
||||
@ -162,7 +168,8 @@ def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
|
||||
|
||||
|
||||
@_wraps(np.intersect1d)
|
||||
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
||||
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
|
||||
return_indices: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
|
||||
_check_arraylike("intersect1d", ar1, ar2)
|
||||
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
|
||||
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
|
||||
@ -187,7 +194,7 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
||||
|
||||
if return_indices:
|
||||
ar1_indices = aux_sort_indices[:-1][mask]
|
||||
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
|
||||
ar2_indices = aux_sort_indices[1:][mask] - np.size(ar1)
|
||||
if not assume_unique:
|
||||
ar1_indices = ind1[ar1_indices]
|
||||
ar2_indices = ind2[ar2_indices]
|
||||
@ -200,7 +207,8 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
|
||||
@_wraps(np.isin, lax_description="""
|
||||
In the JAX version, the `assume_unique` argument is not referenced.
|
||||
""")
|
||||
def isin(element, test_elements, assume_unique=False, invert=False): # noqa: F811
|
||||
def isin(element: ArrayLike, test_elements: ArrayLike,
|
||||
assume_unique: bool = False, invert: bool = False) -> Array:
|
||||
result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert)
|
||||
return result.reshape(np.shape(element))
|
||||
|
||||
@ -212,7 +220,7 @@ UNIQUE_SIZE_HINT = (
|
||||
"a concrete value for the size argument, which will determine the output size.")
|
||||
|
||||
@partial(jit, static_argnums=1)
|
||||
def _unique_sorted_mask(ar, axis):
|
||||
def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
|
||||
aux = moveaxis(ar, axis, 0)
|
||||
if np.issubdtype(aux.dtype, np.complexfloating):
|
||||
# Work around issue in sorting of complex numbers with Nan only in the
|
||||
@ -238,8 +246,10 @@ def _unique_sorted_mask(ar, axis):
|
||||
mask = zeros(size, dtype=bool)
|
||||
return aux, mask, perm
|
||||
|
||||
def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False,
|
||||
size=None, fill_value=None, return_true_size=False):
|
||||
def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bool = False,
|
||||
return_counts: bool = False, size: Optional[int] = None,
|
||||
fill_value: Optional[ArrayLike] = None, return_true_size: bool = False
|
||||
) -> Union[Array, Tuple[Array, ...]]:
|
||||
"""
|
||||
Find the unique elements of an array along a particular axis.
|
||||
"""
|
||||
@ -264,7 +274,7 @@ def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=Fa
|
||||
result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
|
||||
result = moveaxis(result, 0, axis)
|
||||
|
||||
ret = (result,)
|
||||
ret: Tuple[Array, ...] = (result,)
|
||||
if return_index:
|
||||
if aux.size:
|
||||
ret += (perm[ind],)
|
||||
@ -309,8 +319,9 @@ def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=Fa
|
||||
When ``size`` is specified and there are fewer than the indicated number of elements, the
|
||||
remaining elements will be filled with ``fill_value``. The default is the minimum value
|
||||
along the specified axis of the input."""))
|
||||
def unique(ar, return_index=False, return_inverse=False,
|
||||
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
|
||||
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
|
||||
return_counts: bool = False, axis: Optional[int] = None,
|
||||
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None):
|
||||
_check_arraylike("unique", ar)
|
||||
if size is None:
|
||||
ar = core.concrete_or_error(None, ar,
|
||||
@ -318,9 +329,10 @@ def unique(ar, return_index=False, return_inverse=False,
|
||||
else:
|
||||
size = core.concrete_or_error(operator.index, size,
|
||||
"The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
|
||||
ar = asarray(ar)
|
||||
arr = asarray(ar)
|
||||
if axis is None:
|
||||
axis = 0
|
||||
ar = ar.flatten()
|
||||
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
|
||||
return _unique(ar, axis, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
|
||||
arr = arr.flatten()
|
||||
axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
|
||||
return _unique(arr, axis_int, return_index, return_inverse,
|
||||
return_counts, size=size, fill_value=fill_value)
|
||||
|
Loading…
x
Reference in New Issue
Block a user