Merge pull request #13010 from jakevdp:annotate-lax-numpy

PiperOrigin-RevId: 484572239
This commit is contained in:
jax authors 2022-10-28 10:57:48 -07:00
commit f9e7629c3f
2 changed files with 154 additions and 120 deletions

View File

@ -47,7 +47,7 @@ class _IndexGrid(abc.ABC):
sparse: bool
op_name: str
def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Array:
def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Union[Array, List[Array]]:
if isinstance(key, slice):
return _make_1d_grid_from_slice(key, op_name=self.op_name)
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key)

View File

@ -1938,7 +1938,7 @@ https://jax.readthedocs.io/en/latest/faq.html).
@_wraps(np.array, lax_description=_ARRAY_DOC)
def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
order: str = "K", ndmin: int = 0) -> Array:
order: Optional[str] = "K", ndmin: int = 0) -> Array:
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")
@ -2012,7 +2012,7 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
return out_array
def _convert_to_array_if_dtype_fails(x):
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
try:
dtypes.dtype(x)
except TypeError:
@ -2022,14 +2022,14 @@ def _convert_to_array_if_dtype_fails(x):
@_wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "asarray")
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
return array(a, dtype=dtype, copy=False, order=order) # type: ignore
@_wraps(np.copy, lax_description=_ARRAY_DOC)
def copy(a, order=None):
def copy(a: Any, order: Optional[str] = None) -> Array:
return array(a, copy=True, order=order)
@ -2116,13 +2116,13 @@ def empty(shape: Any, dtype: Optional[DTypeLike] = None) -> Array:
@_wraps(np.array_equal)
def array_equal(a1, a2, equal_nan=False):
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array:
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
return bool_(False)
if shape(a1) != shape(a2):
return False
return bool_(False)
eq = asarray(a1 == a2)
if equal_nan:
eq = logical_or(eq, logical_and(isnan(a1), isnan(a2)))
@ -2130,16 +2130,16 @@ def array_equal(a1, a2, equal_nan=False):
@_wraps(np.array_equiv)
def array_equiv(a1, a2):
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
return bool_(False)
try:
eq = equal(a1, a2)
except ValueError:
# shapes are not broadcastable
return False
return bool_(False)
return all(eq)
@ -2207,7 +2207,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s
@_wraps(np.eye)
def eye(N: core.DimSize, M: Optional[core.DimSize] = None, k: int = 0,
def eye(N: DimSize, M: Optional[DimSize] = None, k: int = 0,
dtype: Optional[DTypeLike] = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "eye")
N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
@ -2219,14 +2219,14 @@ def eye(N: core.DimSize, M: Optional[core.DimSize] = None, k: int = 0,
@_wraps(np.identity)
def identity(n: core.DimSize, dtype: Optional[DTypeLike] = None) -> Array:
def identity(n: DimSize, dtype: Optional[DTypeLike] = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "identity")
return eye(n, dtype=dtype)
@_wraps(np.arange)
def arange(start: core.DimSize, stop: Optional[core.DimSize] = None,
step: Optional[core.DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array:
def arange(start: DimSize, stop: Optional[DimSize] = None,
step: Optional[DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "arange")
require = partial(core.concrete_or_error, None)
msg = "It arose in jax.numpy.arange argument `{}`.".format
@ -2404,7 +2404,8 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool
@_wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
indexing: str = 'xy') -> List[Array]:
_check_arraylike("meshgrid", *xi)
args = [asarray(x) for x in xi]
if not copy:
@ -2426,17 +2427,16 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
@_wraps(np.i0)
@jit
def i0(x):
x_orig = x
x, = _promote_args_inexact("i0", x)
if not issubdtype(x.dtype, np.floating):
raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x_orig)}")
x = lax.abs(x)
return lax.mul(lax.exp(x), lax.bessel_i0e(x))
def i0(x: ArrayLike) -> Array:
x_arr, = _promote_args_inexact("i0", x)
if not issubdtype(x_arr.dtype, np.floating):
raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}")
x_arr = lax.abs(x_arr)
return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr))
@_wraps(np.ix_)
def ix_(*args):
def ix_(*args: ArrayLike) -> Tuple[Array, ...]:
_check_arraylike("ix", *args)
n = len(args)
output = []
@ -2458,8 +2458,18 @@ def ix_(*args):
return tuple(output)
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: Literal[False] = False) -> Array: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
*, sparse: Literal[True]) -> Tuple[Array, ...]: ...
@overload
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: ...
@_wraps(np.indices)
def indices(dimensions, dtype=int32, sparse=False):
def indices(dimensions: Sequence[int], dtype: DTypeLike = int32,
sparse: bool = False) -> Union[Array, Tuple[Array, ...]]:
dimensions = tuple(
core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices")
for d in dimensions)
@ -2487,13 +2497,16 @@ will be repeated.
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *,
total_repeat_length: Optional[int] = None) -> Array:
_check_arraylike("repeat", a)
core.is_special_dim_size(repeats) or _check_arraylike("repeat", repeats)
if axis is None:
a = ravel(a)
axis = 0
else:
a = asarray(a)
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()")
assert isinstance(axis, int) # to appease mypy
@ -2512,28 +2525,28 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
# Fast path for when repeats is a scalar.
if np.ndim(repeats) == 0 and ndim(a) != 0:
input_shape = a.shape
input_shape = shape(a)
aux_axis = axis if axis < 0 else axis + 1
a = expand_dims(a, aux_axis)
reps = [1] * len(a.shape)
reps: List[DimSize] = [1] * len(shape(a))
reps[aux_axis] = repeats
a = tile(a, reps)
result_shape = list(input_shape)
result_shape: List[DimSize] = list(input_shape)
result_shape[axis] *= repeats
return reshape(a, result_shape)
repeats = np.ravel(repeats)
if ndim(a) != 0:
repeats = np.broadcast_to(repeats, [a.shape[axis]])
repeats = np.broadcast_to(repeats, [shape(a)[axis]])
total_repeat_length = np.sum(repeats)
else:
repeats = ravel(repeats)
if ndim(a) != 0:
repeats = broadcast_to(repeats, [a.shape[axis]])
repeats = broadcast_to(repeats, [shape(a)[axis]])
# Special case when a is a scalar.
if ndim(a) == 0:
if repeats.shape == (1,):
if shape(repeats) == (1,):
return full([total_repeat_length], a)
else:
raise ValueError('`repeat` with a scalar parameter `a` is only '
@ -2541,13 +2554,13 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
# Special case if total_repeat_length is zero.
if total_repeat_length == 0:
result_shape = list(a.shape)
result_shape = list(shape(a))
result_shape[axis] = 0
return reshape(array([], dtype=a.dtype), result_shape)
return reshape(array([], dtype=_dtype(a)), result_shape)
# If repeats is on a zero sized axis, then return the array.
if a.shape[axis] == 0:
return a
if shape(a)[axis] == 0:
return asarray(a)
# This implementation of repeat avoid having to instantiate a large.
# intermediate tensor.
@ -2565,7 +2578,7 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
@_wraps(np.tri)
def tri(N, M=None, k=0, dtype=None):
def tri(N: int, M: Optional[int] = None, k: int = 0, dtype: DTypeLike = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32
@ -2574,29 +2587,32 @@ def tri(N, M=None, k=0, dtype=None):
@_wraps(np.tril)
@partial(jit, static_argnames=('k',))
def tril(m, k=0):
def tril(m: ArrayLike, k: int = 0) -> Array:
_check_arraylike("tril", m)
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
mask = tri(*m_shape[-2:], k=k, dtype=bool)
N, M = m_shape[-2:]
mask = tri(N, M, k=k, dtype=bool)
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
@_wraps(np.triu, update_doc=False)
@partial(jit, static_argnames=('k',))
def triu(m, k=0):
def triu(m: ArrayLike, k: int = 0) -> Array:
_check_arraylike("triu", m)
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
N, M = m_shape[-2:]
mask = tri(N, M, k=k - 1, dtype=bool)
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
@_wraps(np.trace, skip_params=['out'])
@partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype'))
def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
dtype: Optional[DTypeLike] = None, out: None = None) -> Array:
_check_arraylike("trace", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
@ -2636,13 +2652,15 @@ mask_indices = _wrap_indices_function(np.mask_indices)
@_wraps(np.triu_indices_from)
def triu_indices_from(arr, k=0):
return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])
def triu_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array]:
arr_shape = shape(arr)
return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1])
@_wraps(np.tril_indices_from)
def tril_indices_from(arr, k=0):
return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])
def tril_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array]:
arr_shape = shape(arr)
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
@_wraps(np.diag_indices)
@ -4250,7 +4268,7 @@ def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
idx = tuple(idx) + colons
return idx
def _static_idx(idx: slice, size: core.DimSize):
def _static_idx(idx: slice, size: DimSize):
"""Helper function to compute the static slice start/limit/stride values."""
if isinstance(size, int):
start, stop, step = idx.indices(size)
@ -4318,11 +4336,11 @@ def kaiser(M: int, beta: ArrayLike) -> Array:
return i0(beta * sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta)
def _gcd_cond_fn(xs):
def _gcd_cond_fn(xs: Tuple[Array, Array]) -> Array:
x1, x2 = xs
return any(x2 != 0)
def _gcd_body_fn(xs):
def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]:
x1, x2 = xs
x1, x2 = (where(x2 != 0, x2, x1),
where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0)))
@ -4330,7 +4348,7 @@ def _gcd_body_fn(xs):
@_wraps(np.gcd, module='numpy')
@jit
def gcd(x1, x2):
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
_check_arraylike("gcd", x1, x2)
x1, x2 = _promote_dtypes(x1, x2)
if not issubdtype(_dtype(x1), integer):
@ -4342,7 +4360,7 @@ def gcd(x1, x2):
@_wraps(np.lcm, module='numpy')
@jit
def lcm(x1, x2):
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
_check_arraylike("lcm", x1, x2)
x1, x2 = _promote_dtypes(x1, x2)
if not issubdtype(_dtype(x1), integer):
@ -4353,34 +4371,37 @@ def lcm(x1, x2):
@_wraps(np.extract)
def extract(condition, arr):
def extract(condition: ArrayLike, arr: ArrayLike) -> Array:
return compress(ravel(condition), ravel(arr))
@_wraps(np.compress, skip_params=['out'])
def compress(condition, a, axis: Optional[int] = None, out=None):
def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = None,
out: None = None) -> Array:
_check_arraylike("compress", condition, a)
condition_arr = asarray(condition).astype(bool)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
if ndim(condition) != 1:
if condition_arr.ndim != 1:
raise ValueError("condition must be a 1D array")
condition = asarray(condition).astype(bool)
if axis is None:
axis = 0
a = ravel(a)
arr = ravel(a)
else:
a = moveaxis(a, axis, 0)
condition, extra = condition[:a.shape[0]], condition[a.shape[0]:]
arr = moveaxis(a, axis, 0)
condition_arr, extra = condition_arr[:arr.shape[0]], condition_arr[arr.shape[0]:]
if any(extra):
raise ValueError("condition contains entries that are out of bounds")
a = a[:condition.shape[0]]
return moveaxis(a[condition], 0, axis)
arr = arr[:condition_arr.shape[0]]
return moveaxis(arr[condition_arr], 0, axis)
@_wraps(np.cov)
@partial(jit, static_argnames=('rowvar', 'bias', 'ddof'))
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
aweights=None):
def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True,
bias: bool = False, ddof: Optional[int] = None,
fweights: Optional[ArrayLike] = None,
aweights: Optional[ArrayLike] = None) -> Array:
if y is not None:
m, y = _promote_args_inexact("cov", m, y)
if y.ndim > 2:
@ -4398,14 +4419,14 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
return array([]).reshape(0, 0)
if y is not None:
y = atleast_2d(y)
if not rowvar and y.shape[0] != 1:
y = y.T
X = concatenate((X, y), axis=0)
y_arr = atleast_2d(y)
if not rowvar and y_arr.shape[0] != 1:
y_arr = y_arr.T
X = concatenate((X, y_arr), axis=0)
if ddof is None:
ddof = 1 if bias == 0 else 0
w = None
w: Optional[Array] = None
if fweights is not None:
_check_arraylike("cov", fweights)
if ndim(fweights) > 1:
@ -4424,7 +4445,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
raise RuntimeError("incompatible numbers of samples and aweights")
# Ensure positive aweights: note that numpy raises an error for negative aweights.
aweights = abs(aweights)
w = aweights if w is None else w * aweights
w = asarray(aweights) if w is None else w * asarray(aweights)
avg, w_sum = average(X, axis=1, weights=w, returned=True)
w_sum = w_sum[0]
@ -4445,7 +4466,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
@_wraps(np.corrcoef)
@partial(jit, static_argnames=('rowvar',))
def corrcoef(x, y=None, rowvar=True):
def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -> Array:
_check_arraylike("corrcoef", x)
c = cov(x, y, rowvar)
if len(shape(c)) == 0:
@ -4467,9 +4488,9 @@ def corrcoef(x, y=None, rowvar=True):
@_wraps(np.quantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, method="linear", keepdims=False,
interpolation=None):
def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
_check_arraylike("quantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.quantile does not support overwrite_input=True or "
@ -4478,14 +4499,14 @@ def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
if interpolation is not None:
warnings.warn("The interpolation= argument to 'quantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(a, q, axis, interpolation or method, keepdims, False)
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False)
@_wraps(np.nanquantile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, method="linear",
keepdims=False, interpolation=None):
def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
_check_arraylike("nanquantile", a, q)
if overwrite_input or out is not None:
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
@ -4494,9 +4515,10 @@ def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
if interpolation is not None:
warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. "
"Use 'method=' instead.", DeprecationWarning)
return _quantile(a, q, axis, interpolation or method, keepdims, True)
return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, True)
def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
interpolation: str, keepdims: bool, squash_nans: bool) -> Array:
if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
"'midpoint', or 'nearest'")
@ -4524,7 +4546,6 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis)
a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
keepdim = tuple(keepdim)
axis = _canonicalize_axis(-1, ndim(a))
else:
axis = _canonicalize_axis(axis, ndim(a))
@ -4614,13 +4635,13 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
raise ValueError(f"interpolation={interpolation!r} not recognized")
if keepdims and keepdim:
if q_ndim > 0:
keepdim = (shape(q)[0],) + keepdim
keepdim = [shape(q)[0], *keepdim]
result = reshape(result, keepdim)
return lax.convert_element_type(result, a.dtype)
@partial(vectorize, excluded={0, 2, 3})
def _searchsorted_via_scan(sorted_arr, query, side, dtype):
def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
op = _sort_le_comparator if side == 'left' else _sort_lt_comparator
def body_fun(_, state):
low, high = state
@ -4632,7 +4653,7 @@ def _searchsorted_via_scan(sorted_arr, query, side, dtype):
return lax.fori_loop(0, n_levels, body_fun, init)[1]
def _searchsorted_via_sort(sorted_arr, query, side, dtype):
def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array:
working_dtype = int32 if sorted_arr.size + query.size < np.iinfo(np.int32).max else int64
def _rank(x):
idx = lax.iota(working_dtype, len(x))
@ -4653,7 +4674,8 @@ def _searchsorted_via_sort(sorted_arr, query, side, dtype):
'sort' is often more performant on accelerator backends like GPU and TPU (particularly
when ``v`` is very large)."""))
@partial(jit, static_argnames=('side', 'sorter', 'method'))
def searchsorted(a, v, side='left', sorter=None, *, method='scan'):
def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left',
sorter: None = None, *, method: str = 'scan') -> Array:
_check_arraylike("searchsorted", a, v)
if side not in ['left', 'right']:
raise ValueError(f"{side!r} is an invalid value for keyword 'side'. "
@ -4670,22 +4692,23 @@ def searchsorted(a, v, side='left', sorter=None, *, method='scan'):
if len(a) == 0:
return zeros_like(v, dtype=dtype)
impl = _searchsorted_via_scan if method == 'scan' else _searchsorted_via_sort
return impl(a, v, side, dtype)
return impl(asarray(a), asarray(v), side, dtype)
@_wraps(np.digitize)
@partial(jit, static_argnames=('right',))
def digitize(x, bins, right=False):
def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array:
_check_arraylike("digitize", x, bins)
right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()")
if ndim(bins) != 1:
bins_arr = asarray(bins)
if bins_arr.ndim != 1:
raise ValueError(f"digitize: bins must be a 1-dimensional array; got bins={bins}")
if len(bins) == 0:
if bins_arr.shape[0] == 0:
return zeros(x, dtype=dtypes.canonicalize_dtype(int_))
side = 'right' if not right else 'left'
return where(
bins[-1] >= bins[0],
searchsorted(bins, x, side=side),
len(bins) - searchsorted(bins[::-1], x, side=side)
bins_arr[-1] >= bins_arr[0],
searchsorted(bins_arr, x, side=side),
len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side)
)
_PIECEWISE_DOC = """\
@ -4695,9 +4718,10 @@ See the :func:`jax.lax.switch` documentation for more information.
"""
@_wraps(np.piecewise, lax_description=_PIECEWISE_DOC)
def piecewise(x, condlist, funclist, *args, **kw):
def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
funclist: List[Union[ArrayLike, Callable[..., Array]]],
*args, **kw) -> Array:
_check_arraylike("piecewise", x)
condlist = array(condlist, dtype=bool_)
nc, nf = len(condlist), len(funclist)
if nf == nc + 1:
funclist = funclist[-1:] + funclist[:-1]
@ -4707,14 +4731,16 @@ def piecewise(x, condlist, funclist, *args, **kw):
raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}")
consts = {i: c for i, c in enumerate(funclist) if not callable(c)}
funcs = {i: f for i, f in enumerate(funclist) if callable(f)}
return _piecewise(x, condlist, consts,
return _piecewise(asarray(x), asarray(condlist, dtype=bool_), consts,
frozenset(funcs.items()), # dict is not hashable.
*args, **kw)
@partial(jit, static_argnames=['funcs'])
def _piecewise(x, condlist, consts, funcs, *args, **kw):
funcs = dict(funcs)
funclist = [consts.get(i, funcs.get(i)) for i in range(len(condlist) + 1)]
def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike],
funcs: FrozenSet[Tuple[int, Callable[..., Array]]],
*args, **kw) -> Array:
funcdict = dict(funcs)
funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)]
indices = argmax(cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0)
dtype = _dtype(x)
def _call(f):
@ -4728,9 +4754,10 @@ def _piecewise(x, condlist, consts, funcs, *args, **kw):
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, method="linear",
keepdims=False, interpolation=None):
def percentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
_check_arraylike("percentile", a, q)
q, = _promote_dtypes_inexact(q)
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
@ -4739,9 +4766,10 @@ def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
@_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
'keepdims', 'method'))
def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out=None, overwrite_input=False, method="linear",
keepdims=False, interpolation=None):
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, interpolation: None = None) -> Array:
_check_arraylike("nanpercentile", a, q)
q = true_divide(q, float32(100.0))
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
@ -4750,23 +4778,25 @@ def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
@_wraps(np.median, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, keepdims=False):
def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
_check_arraylike("median", a)
return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input,
keepdims=keepdims, method='midpoint')
@_wraps(np.nanmedian, skip_params=['out', 'overwrite_input'])
@partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims'))
def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
overwrite_input=False, keepdims=False):
def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None,
out: None = None, overwrite_input: bool = False,
keepdims: bool = False) -> Array:
_check_arraylike("nanmedian", a)
return nanquantile(a, 0.5, axis=axis, out=out,
overwrite_input=overwrite_input, keepdims=keepdims,
method='midpoint')
def _astype(arr, dtype):
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
"""Copy the array and cast to a specified dtype.
This is implemeted via :func:`jax.lax.convert_element_type`, which may
@ -4780,19 +4810,21 @@ def _astype(arr, dtype):
return lax.convert_element_type(arr, dtype)
def _nbytes(arr):
def _nbytes(arr: ArrayLike) -> int:
return size(arr) * _dtype(arr).itemsize
def _itemsize(arr):
def _itemsize(arr: ArrayLike) -> int:
return _dtype(arr).itemsize
def _clip(number, min=None, max=None, out=None): # noqa: F811
def _clip(number: ArrayLike,
min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None, # noqa: F811
out: None = None) -> Array:
return clip(number, a_min=min, a_max=max, out=out)
def _view(arr, dtype=None, type=None):
def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "view")
if type is not None:
raise NotImplementedError("`type` argument of array.view()")
@ -4814,7 +4846,8 @@ def _view(arr, dtype=None, type=None):
if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
raise ValueError("When changing to a larger dtype, its size must be a divisor "
"of the total size in bytes of the last axis of the array.")
byte_dtypes = {8: uint8, 16: uint16, 32: uint32, 64: uint64}
byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'),
32: np.dtype('uint32'), 64: np.dtype('uint64')}
if nbits_in not in byte_dtypes:
raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
if nbits_out not in byte_dtypes:
@ -4883,15 +4916,15 @@ def _unimplemented_setitem(self, i, x):
"https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html")
raise TypeError(msg.format(type(self)))
def _operator_round(number, ndigits=None):
def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array:
out = round(number, decimals=ndigits or 0)
# If `ndigits` is None, for a builtin float round(7.5) returns an integer.
return out.astype(int) if ndigits is None else out
def _copy(self):
def _copy(self: Array) -> Array:
return self.copy()
def _deepcopy(self, memo):
def _deepcopy(self: Array, memo: Any) -> Array:
del memo # unused
return self.copy()
@ -4970,22 +5003,23 @@ def __array_module__(self, types):
return NotImplemented
def _compress_method(a, condition, axis=None, out=None):
def _compress_method(a: ArrayLike, condition: ArrayLike,
axis: Optional[int] = None, out: None = None) -> Array:
return compress(condition, a, axis, out)
@core.stash_axis_env()
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr,
def _multi_slice(arr: ArrayLike,
start_indices: Tuple[Tuple[int, ...]],
limit_indices: Tuple[Tuple[int, ...]],
removed_dims: Tuple[Tuple[int, ...]]):
removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]:
"""Extracts multiple slices from `arr`.
This is used to shard DeviceArray arguments to pmap. It's implemented as a
DeviceArray method here to avoid circular imports.
"""
results = []
results: List[Array] = []
for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
sliced = lax.slice(arr, starts, limits)
if removed:
@ -4996,7 +5030,7 @@ def _multi_slice(arr,
# The next two functions are related to iter(device_array), implemented here to
# avoid circular imports.
@jit
def _unstack(x):
def _unstack(x: Array) -> List[Array]:
return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]
setattr(device_array.DeviceArray, "_unstack", _unstack)
setattr(ArrayImpl, '_unstack', _unstack)