[typing] add full annotations for lax_numpy setops

This commit is contained in:
Jake VanderPlas 2022-10-03 12:52:28 -07:00
parent be56a3559d
commit fd45035b21
2 changed files with 53 additions and 41 deletions

View File

@ -1653,7 +1653,7 @@ def tile(A, reps):
[k for pair in zip(reps, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps)))
def _concatenate_array(arr, axis: int, dtype=None):
def _concatenate_array(arr, axis: Optional[int], dtype=None):
# Fast path for concatenation when the input is an ndarray rather than a list.
arr = asarray(arr, dtype=dtype)
if arr.ndim == 0 or arr.shape[0] == 0:
@ -1668,7 +1668,7 @@ def _concatenate_array(arr, axis: int, dtype=None):
return lax.reshape(arr, shape, dimensions)
@_wraps(np.concatenate)
def concatenate(arrays, axis: int = 0, dtype=None):
def concatenate(arrays, axis: Optional[int] = 0, dtype=None):
if isinstance(arrays, (np.ndarray, ndarray)):
return _concatenate_array(arrays, axis, dtype=dtype)
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)

View File

@ -15,7 +15,7 @@
from functools import partial
import operator
from textwrap import dedent as _dedent
from typing import Optional
from typing import Optional, Tuple, Union
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
@ -24,6 +24,7 @@ from jax._src.numpy.lax_numpy import (
empty, full_like, isnan, lexsort, moveaxis, nonzero, ones, ravel,
sort, where, zeros)
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.typing import Array, ArrayLike
from jax._src.util import prod as _prod
from jax import core
from jax import jit
@ -37,25 +38,28 @@ _lax_const = lax_internal._const
@_wraps(np.in1d, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
@partial(jit, static_argnames=('assume_unique', 'invert',))
def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811
def in1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, invert: bool = False) -> Array:
del assume_unique # unused
return _in1d(ar1, ar2, invert)
@partial(jit, static_argnames=('invert',))
def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
_check_arraylike("in1d", ar1, ar2)
ar1 = ravel(ar1)
ar2 = ravel(ar2)
ar1_flat = ravel(ar1)
ar2_flat = ravel(ar2)
# Note: an algorithm based on searchsorted has better scaling, but in practice
# is very slow on accelerators because it relies on lax control flow. If XLA
# ever supports binary search natively, we should switch to this:
# ar2 = jnp.sort(ar2)
# ind = jnp.searchsorted(ar2, ar1)
# ar2_flat = jnp.sort(ar2_flat)
# ind = jnp.searchsorted(ar2_flat, ar1_flat)
# if invert:
# return ar1 != ar2[ind]
# return ar1_flat != ar2_flat[ind]
# else:
# return ar1 == ar2[ind]
# return ar1_flat == ar2_flat[ind]
if invert:
return (ar1[:, None] != ar2[None, :]).all(-1)
return (ar1_flat[:, None] != ar2_flat[None, :]).all(-1)
else:
return (ar1[:, None] == ar2[None, :]).any(-1)
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)
@_wraps(np.setdiff1d,
lax_description=_dedent("""
@ -70,27 +74,28 @@ def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array:
_check_arraylike("setdiff1d", ar1, ar2)
if size is None:
ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
else:
size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
ar1 = asarray(ar1)
fill_value = asarray(0 if fill_value is None else fill_value, dtype=ar1.dtype)
if ar1.size == 0:
return full_like(ar1, fill_value, shape=size or 0)
arr1 = asarray(ar1)
fill_value = asarray(0 if fill_value is None else fill_value, dtype=arr1.dtype)
if arr1.size == 0:
return full_like(arr1, fill_value, shape=size or 0)
if not assume_unique:
ar1 = unique(ar1, size=size and ar1.size)
mask = in1d(ar1, ar2, invert=True)
arr1 = unique(arr1, size=size and arr1.size)
mask = in1d(arr1, ar2, invert=True)
if size is None:
return ar1[mask]
return arr1[mask]
else:
if not (assume_unique or size is None):
# Set mask to zero at locations corresponding to unique() padding.
n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
mask = where(arange(ar1.size) < n_unique, mask, False)
return where(arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)
n_unique = arr1.size + 1 - (arr1 == arr1[0]).sum()
mask = where(arange(arr1.size) < n_unique, mask, False)
return where(arange(size) < mask.sum(), arr1[where(mask, size=size)], fill_value)
@_wraps(np.union1d,
@ -107,7 +112,8 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to the minimum
value of the union."""))
def union1d(ar1, ar2, *, size=None, fill_value=None):
def union1d(ar1: ArrayLike, ar2: ArrayLike,
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None) -> Array:
_check_arraylike("union1d", ar1, ar2)
if size is None:
ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
@ -121,7 +127,7 @@ def union1d(ar1, ar2, *, size=None, fill_value=None):
In the JAX version, the input arrays are explicitly flattened regardless
of assume_unique value.
""")
def setxor1d(ar1, ar2, assume_unique=False):
def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array:
_check_arraylike("setxor1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
@ -142,8 +148,8 @@ def setxor1d(ar1, ar2, assume_unique=False):
return aux[flag[1:] & flag[:-1]]
@partial(jit, static_argnums=2)
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
@partial(jit, static_argnames=['return_indices'])
def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> Tuple[Array, ...]:
"""
Helper function for intersect1d which is jit-able
"""
@ -162,7 +168,8 @@ def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
@_wraps(np.intersect1d)
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
return_indices: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
_check_arraylike("intersect1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
@ -187,7 +194,7 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
if return_indices:
ar1_indices = aux_sort_indices[:-1][mask]
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
ar2_indices = aux_sort_indices[1:][mask] - np.size(ar1)
if not assume_unique:
ar1_indices = ind1[ar1_indices]
ar2_indices = ind2[ar2_indices]
@ -200,7 +207,8 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
@_wraps(np.isin, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def isin(element, test_elements, assume_unique=False, invert=False): # noqa: F811
def isin(element: ArrayLike, test_elements: ArrayLike,
assume_unique: bool = False, invert: bool = False) -> Array:
result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert)
return result.reshape(np.shape(element))
@ -212,7 +220,7 @@ UNIQUE_SIZE_HINT = (
"a concrete value for the size argument, which will determine the output size.")
@partial(jit, static_argnums=1)
def _unique_sorted_mask(ar, axis):
def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
aux = moveaxis(ar, axis, 0)
if np.issubdtype(aux.dtype, np.complexfloating):
# Work around issue in sorting of complex numbers with Nan only in the
@ -238,8 +246,10 @@ def _unique_sorted_mask(ar, axis):
mask = zeros(size, dtype=bool)
return aux, mask, perm
def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False,
size=None, fill_value=None, return_true_size=False):
def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, size: Optional[int] = None,
fill_value: Optional[ArrayLike] = None, return_true_size: bool = False
) -> Union[Array, Tuple[Array, ...]]:
"""
Find the unique elements of an array along a particular axis.
"""
@ -264,7 +274,7 @@ def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=Fa
result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
result = moveaxis(result, 0, axis)
ret = (result,)
ret: Tuple[Array, ...] = (result,)
if return_index:
if aux.size:
ret += (perm[ind],)
@ -309,8 +319,9 @@ def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=Fa
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``. The default is the minimum value
along the specified axis of the input."""))
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, axis: Optional[int] = None,
*, size: Optional[int] = None, fill_value: Optional[ArrayLike] = None):
_check_arraylike("unique", ar)
if size is None:
ar = core.concrete_or_error(None, ar,
@ -318,9 +329,10 @@ def unique(ar, return_index=False, return_inverse=False,
else:
size = core.concrete_or_error(operator.index, size,
"The error arose for the size argument of jnp.unique(). " + UNIQUE_SIZE_HINT)
ar = asarray(ar)
arr = asarray(ar)
if axis is None:
axis = 0
ar = ar.flatten()
axis = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
return _unique(ar, axis, return_index, return_inverse, return_counts, size=size, fill_value=fill_value)
arr = arr.flatten()
axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
return _unique(arr, axis_int, return_index, return_inverse,
return_counts, size=size, fill_value=fill_value)