From 15b489fb6f5dfa5978662ab1914e58e7aa1751e6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 27 Oct 2022 12:38:38 -0700 Subject: [PATCH] [typing] annotate next section of lax_numpy.py --- jax/_src/numpy/index_tricks.py | 2 +- jax/_src/numpy/lax_numpy.py | 272 ++++++++++++++++++--------------- 2 files changed, 154 insertions(+), 120 deletions(-) diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index c3265dccc..c0affbb83 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -47,7 +47,7 @@ class _IndexGrid(abc.ABC): sparse: bool op_name: str - def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Array: + def __getitem__(self, key: Union[slice, Tuple[slice, ...]]) -> Union[Array, List[Array]]: if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name=self.op_name) output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index d10c528e6..94d3bda1e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1938,7 +1938,7 @@ https://jax.readthedocs.io/en/latest/faq.html). @_wraps(np.array, lax_description=_ARRAY_DOC) def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True, - order: str = "K", ndmin: int = 0) -> Array: + order: Optional[str] = "K", ndmin: int = 0) -> Array: if order is not None and order != "K": raise NotImplementedError("Only implemented for order='K'") @@ -2012,7 +2012,7 @@ def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True, return out_array -def _convert_to_array_if_dtype_fails(x): +def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: try: dtypes.dtype(x) except TypeError: @@ -2022,14 +2022,14 @@ def _convert_to_array_if_dtype_fails(x): @_wraps(np.asarray, lax_description=_ARRAY_DOC) -def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array: +def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "asarray") dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype return array(a, dtype=dtype, copy=False, order=order) # type: ignore @_wraps(np.copy, lax_description=_ARRAY_DOC) -def copy(a, order=None): +def copy(a: Any, order: Optional[str] = None) -> Array: return array(a, copy=True, order=order) @@ -2116,13 +2116,13 @@ def empty(shape: Any, dtype: Optional[DTypeLike] = None) -> Array: @_wraps(np.array_equal) -def array_equal(a1, a2, equal_nan=False): +def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan: bool = False) -> Array: try: a1, a2 = asarray(a1), asarray(a2) except Exception: - return False + return bool_(False) if shape(a1) != shape(a2): - return False + return bool_(False) eq = asarray(a1 == a2) if equal_nan: eq = logical_or(eq, logical_and(isnan(a1), isnan(a2))) @@ -2130,16 +2130,16 @@ def array_equal(a1, a2, equal_nan=False): @_wraps(np.array_equiv) -def array_equiv(a1, a2): +def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: try: a1, a2 = asarray(a1), asarray(a2) except Exception: - return False + return bool_(False) try: eq = equal(a1, a2) except ValueError: # shapes are not broadcastable - return False + return bool_(False) return all(eq) @@ -2207,7 +2207,7 @@ def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: s @_wraps(np.eye) -def eye(N: core.DimSize, M: Optional[core.DimSize] = None, k: int = 0, +def eye(N: DimSize, M: Optional[DimSize] = None, k: int = 0, dtype: Optional[DTypeLike] = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "eye") N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()") @@ -2219,14 +2219,14 @@ def eye(N: core.DimSize, M: Optional[core.DimSize] = None, k: int = 0, @_wraps(np.identity) -def identity(n: core.DimSize, dtype: Optional[DTypeLike] = None) -> Array: +def identity(n: DimSize, dtype: Optional[DTypeLike] = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "identity") return eye(n, dtype=dtype) @_wraps(np.arange) -def arange(start: core.DimSize, stop: Optional[core.DimSize] = None, - step: Optional[core.DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array: +def arange(start: DimSize, stop: Optional[DimSize] = None, + step: Optional[DimSize] = None, dtype: Optional[DTypeLike] = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "arange") require = partial(core.concrete_or_error, None) msg = "It arose in jax.numpy.arange argument `{}`.".format @@ -2404,7 +2404,8 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool @_wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) -def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): +def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, + indexing: str = 'xy') -> List[Array]: _check_arraylike("meshgrid", *xi) args = [asarray(x) for x in xi] if not copy: @@ -2426,17 +2427,16 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): @_wraps(np.i0) @jit -def i0(x): - x_orig = x - x, = _promote_args_inexact("i0", x) - if not issubdtype(x.dtype, np.floating): - raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x_orig)}") - x = lax.abs(x) - return lax.mul(lax.exp(x), lax.bessel_i0e(x)) +def i0(x: ArrayLike) -> Array: + x_arr, = _promote_args_inexact("i0", x) + if not issubdtype(x_arr.dtype, np.floating): + raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}") + x_arr = lax.abs(x_arr) + return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr)) @_wraps(np.ix_) -def ix_(*args): +def ix_(*args: ArrayLike) -> Tuple[Array, ...]: _check_arraylike("ix", *args) n = len(args) output = [] @@ -2458,8 +2458,18 @@ def ix_(*args): return tuple(output) +@overload +def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, + sparse: Literal[False] = False) -> Array: ... +@overload +def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, + *, sparse: Literal[True]) -> Tuple[Array, ...]: ... +@overload +def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, + sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: ... @_wraps(np.indices) -def indices(dimensions, dtype=int32, sparse=False): +def indices(dimensions: Sequence[int], dtype: DTypeLike = int32, + sparse: bool = False) -> Union[Array, Tuple[Array, ...]]: dimensions = tuple( core.concrete_or_error(operator.index, d, "dimensions argument of jnp.indices") for d in dimensions) @@ -2487,13 +2497,16 @@ will be repeated. @_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC) -def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None): +def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = None, *, + total_repeat_length: Optional[int] = None) -> Array: _check_arraylike("repeat", a) core.is_special_dim_size(repeats) or _check_arraylike("repeat", repeats) if axis is None: a = ravel(a) axis = 0 + else: + a = asarray(a) axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()") assert isinstance(axis, int) # to appease mypy @@ -2512,28 +2525,28 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None): # Fast path for when repeats is a scalar. if np.ndim(repeats) == 0 and ndim(a) != 0: - input_shape = a.shape + input_shape = shape(a) aux_axis = axis if axis < 0 else axis + 1 a = expand_dims(a, aux_axis) - reps = [1] * len(a.shape) + reps: List[DimSize] = [1] * len(shape(a)) reps[aux_axis] = repeats a = tile(a, reps) - result_shape = list(input_shape) + result_shape: List[DimSize] = list(input_shape) result_shape[axis] *= repeats return reshape(a, result_shape) repeats = np.ravel(repeats) if ndim(a) != 0: - repeats = np.broadcast_to(repeats, [a.shape[axis]]) + repeats = np.broadcast_to(repeats, [shape(a)[axis]]) total_repeat_length = np.sum(repeats) else: repeats = ravel(repeats) if ndim(a) != 0: - repeats = broadcast_to(repeats, [a.shape[axis]]) + repeats = broadcast_to(repeats, [shape(a)[axis]]) # Special case when a is a scalar. if ndim(a) == 0: - if repeats.shape == (1,): + if shape(repeats) == (1,): return full([total_repeat_length], a) else: raise ValueError('`repeat` with a scalar parameter `a` is only ' @@ -2541,13 +2554,13 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None): # Special case if total_repeat_length is zero. if total_repeat_length == 0: - result_shape = list(a.shape) + result_shape = list(shape(a)) result_shape[axis] = 0 - return reshape(array([], dtype=a.dtype), result_shape) + return reshape(array([], dtype=_dtype(a)), result_shape) # If repeats is on a zero sized axis, then return the array. - if a.shape[axis] == 0: - return a + if shape(a)[axis] == 0: + return asarray(a) # This implementation of repeat avoid having to instantiate a large. # intermediate tensor. @@ -2565,7 +2578,7 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None): @_wraps(np.tri) -def tri(N, M=None, k=0, dtype=None): +def tri(N: int, M: Optional[int] = None, k: int = 0, dtype: DTypeLike = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "tri") M = M if M is not None else N dtype = dtype or float32 @@ -2574,29 +2587,32 @@ def tri(N, M=None, k=0, dtype=None): @_wraps(np.tril) @partial(jit, static_argnames=('k',)) -def tril(m, k=0): +def tril(m: ArrayLike, k: int = 0) -> Array: _check_arraylike("tril", m) m_shape = shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.tril must be at least 2D") - mask = tri(*m_shape[-2:], k=k, dtype=bool) + N, M = m_shape[-2:] + mask = tri(N, M, k=k, dtype=bool) return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m)) @_wraps(np.triu, update_doc=False) @partial(jit, static_argnames=('k',)) -def triu(m, k=0): +def triu(m: ArrayLike, k: int = 0) -> Array: _check_arraylike("triu", m) m_shape = shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.triu must be at least 2D") - mask = tri(*m_shape[-2:], k=k - 1, dtype=bool) + N, M = m_shape[-2:] + mask = tri(N, M, k=k - 1, dtype=bool) return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) @_wraps(np.trace, skip_params=['out']) @partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype')) -def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None): +def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1, + dtype: Optional[DTypeLike] = None, out: None = None) -> Array: _check_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") @@ -2636,13 +2652,15 @@ mask_indices = _wrap_indices_function(np.mask_indices) @_wraps(np.triu_indices_from) -def triu_indices_from(arr, k=0): - return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1]) +def triu_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array]: + arr_shape = shape(arr) + return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1]) @_wraps(np.tril_indices_from) -def tril_indices_from(arr, k=0): - return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1]) +def tril_indices_from(arr: ArrayLike, k: int = 0) -> Tuple[Array]: + arr_shape = shape(arr) + return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1]) @_wraps(np.diag_indices) @@ -4250,7 +4268,7 @@ def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'): idx = tuple(idx) + colons return idx -def _static_idx(idx: slice, size: core.DimSize): +def _static_idx(idx: slice, size: DimSize): """Helper function to compute the static slice start/limit/stride values.""" if isinstance(size, int): start, stop, step = idx.indices(size) @@ -4318,11 +4336,11 @@ def kaiser(M: int, beta: ArrayLike) -> Array: return i0(beta * sqrt(1 - ((n - alpha) / alpha) ** 2)) / i0(beta) -def _gcd_cond_fn(xs): +def _gcd_cond_fn(xs: Tuple[Array, Array]) -> Array: x1, x2 = xs return any(x2 != 0) -def _gcd_body_fn(xs): +def _gcd_body_fn(xs: Tuple[Array, Array]) -> Tuple[Array, Array]: x1, x2 = xs x1, x2 = (where(x2 != 0, x2, x1), where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) @@ -4330,7 +4348,7 @@ def _gcd_body_fn(xs): @_wraps(np.gcd, module='numpy') @jit -def gcd(x1, x2): +def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: _check_arraylike("gcd", x1, x2) x1, x2 = _promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), integer): @@ -4342,7 +4360,7 @@ def gcd(x1, x2): @_wraps(np.lcm, module='numpy') @jit -def lcm(x1, x2): +def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: _check_arraylike("lcm", x1, x2) x1, x2 = _promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), integer): @@ -4353,34 +4371,37 @@ def lcm(x1, x2): @_wraps(np.extract) -def extract(condition, arr): +def extract(condition: ArrayLike, arr: ArrayLike) -> Array: return compress(ravel(condition), ravel(arr)) @_wraps(np.compress, skip_params=['out']) -def compress(condition, a, axis: Optional[int] = None, out=None): +def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = None, + out: None = None) -> Array: _check_arraylike("compress", condition, a) + condition_arr = asarray(condition).astype(bool) if out is not None: raise NotImplementedError("The 'out' argument to jnp.compress is not supported.") - if ndim(condition) != 1: + if condition_arr.ndim != 1: raise ValueError("condition must be a 1D array") - condition = asarray(condition).astype(bool) if axis is None: axis = 0 - a = ravel(a) + arr = ravel(a) else: - a = moveaxis(a, axis, 0) - condition, extra = condition[:a.shape[0]], condition[a.shape[0]:] + arr = moveaxis(a, axis, 0) + condition_arr, extra = condition_arr[:arr.shape[0]], condition_arr[arr.shape[0]:] if any(extra): raise ValueError("condition contains entries that are out of bounds") - a = a[:condition.shape[0]] - return moveaxis(a[condition], 0, axis) + arr = arr[:condition_arr.shape[0]] + return moveaxis(arr[condition_arr], 0, axis) @_wraps(np.cov) @partial(jit, static_argnames=('rowvar', 'bias', 'ddof')) -def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, - aweights=None): +def cov(m: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True, + bias: bool = False, ddof: Optional[int] = None, + fweights: Optional[ArrayLike] = None, + aweights: Optional[ArrayLike] = None) -> Array: if y is not None: m, y = _promote_args_inexact("cov", m, y) if y.ndim > 2: @@ -4398,14 +4419,14 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, return array([]).reshape(0, 0) if y is not None: - y = atleast_2d(y) - if not rowvar and y.shape[0] != 1: - y = y.T - X = concatenate((X, y), axis=0) + y_arr = atleast_2d(y) + if not rowvar and y_arr.shape[0] != 1: + y_arr = y_arr.T + X = concatenate((X, y_arr), axis=0) if ddof is None: ddof = 1 if bias == 0 else 0 - w = None + w: Optional[Array] = None if fweights is not None: _check_arraylike("cov", fweights) if ndim(fweights) > 1: @@ -4424,7 +4445,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, raise RuntimeError("incompatible numbers of samples and aweights") # Ensure positive aweights: note that numpy raises an error for negative aweights. aweights = abs(aweights) - w = aweights if w is None else w * aweights + w = asarray(aweights) if w is None else w * asarray(aweights) avg, w_sum = average(X, axis=1, weights=w, returned=True) w_sum = w_sum[0] @@ -4445,7 +4466,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, @_wraps(np.corrcoef) @partial(jit, static_argnames=('rowvar',)) -def corrcoef(x, y=None, rowvar=True): +def corrcoef(x: ArrayLike, y: Optional[ArrayLike] = None, rowvar: bool = True) -> Array: _check_arraylike("corrcoef", x) c = cov(x, y, rowvar) if len(shape(c)) == 0: @@ -4467,9 +4488,9 @@ def corrcoef(x, y=None, rowvar=True): @_wraps(np.quantile, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) -def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - overwrite_input=False, method="linear", keepdims=False, - interpolation=None): +def quantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = False, interpolation: None = None) -> Array: _check_arraylike("quantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.quantile does not support overwrite_input=True or " @@ -4478,14 +4499,14 @@ def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, if interpolation is not None: warnings.warn("The interpolation= argument to 'quantile' is deprecated. " "Use 'method=' instead.", DeprecationWarning) - return _quantile(a, q, axis, interpolation or method, keepdims, False) + return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, False) @_wraps(np.nanquantile, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) -def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, - out=None, overwrite_input=False, method="linear", - keepdims=False, interpolation=None): +def nanquantile(a: ArrayLike, q: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = False, interpolation: None = None) -> Array: _check_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " @@ -4494,9 +4515,10 @@ def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, if interpolation is not None: warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. " "Use 'method=' instead.", DeprecationWarning) - return _quantile(a, q, axis, interpolation or method, keepdims, True) + return _quantile(asarray(a), asarray(q), axis, interpolation or method, keepdims, True) -def _quantile(a, q, axis, interpolation, keepdims, squash_nans): +def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], + interpolation: str, keepdims: bool, squash_nans: bool) -> Array: if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]: raise ValueError("interpolation can only be 'linear', 'lower', 'higher', " "'midpoint', or 'nearest'") @@ -4524,7 +4546,6 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans): do_not_touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx not in axis) touch_shape = tuple(x for idx,x in enumerate(shape(a)) if idx in axis) a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions) - keepdim = tuple(keepdim) axis = _canonicalize_axis(-1, ndim(a)) else: axis = _canonicalize_axis(axis, ndim(a)) @@ -4614,13 +4635,13 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans): raise ValueError(f"interpolation={interpolation!r} not recognized") if keepdims and keepdim: if q_ndim > 0: - keepdim = (shape(q)[0],) + keepdim + keepdim = [shape(q)[0], *keepdim] result = reshape(result, keepdim) return lax.convert_element_type(result, a.dtype) @partial(vectorize, excluded={0, 2, 3}) -def _searchsorted_via_scan(sorted_arr, query, side, dtype): +def _searchsorted_via_scan(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: op = _sort_le_comparator if side == 'left' else _sort_lt_comparator def body_fun(_, state): low, high = state @@ -4632,7 +4653,7 @@ def _searchsorted_via_scan(sorted_arr, query, side, dtype): return lax.fori_loop(0, n_levels, body_fun, init)[1] -def _searchsorted_via_sort(sorted_arr, query, side, dtype): +def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: working_dtype = int32 if sorted_arr.size + query.size < np.iinfo(np.int32).max else int64 def _rank(x): idx = lax.iota(working_dtype, len(x)) @@ -4653,7 +4674,8 @@ def _searchsorted_via_sort(sorted_arr, query, side, dtype): 'sort' is often more performant on accelerator backends like GPU and TPU (particularly when ``v`` is very large).""")) @partial(jit, static_argnames=('side', 'sorter', 'method')) -def searchsorted(a, v, side='left', sorter=None, *, method='scan'): +def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', + sorter: None = None, *, method: str = 'scan') -> Array: _check_arraylike("searchsorted", a, v) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. " @@ -4670,22 +4692,23 @@ def searchsorted(a, v, side='left', sorter=None, *, method='scan'): if len(a) == 0: return zeros_like(v, dtype=dtype) impl = _searchsorted_via_scan if method == 'scan' else _searchsorted_via_sort - return impl(a, v, side, dtype) + return impl(asarray(a), asarray(v), side, dtype) @_wraps(np.digitize) @partial(jit, static_argnames=('right',)) -def digitize(x, bins, right=False): +def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: _check_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") - if ndim(bins) != 1: + bins_arr = asarray(bins) + if bins_arr.ndim != 1: raise ValueError(f"digitize: bins must be a 1-dimensional array; got bins={bins}") - if len(bins) == 0: + if bins_arr.shape[0] == 0: return zeros(x, dtype=dtypes.canonicalize_dtype(int_)) side = 'right' if not right else 'left' return where( - bins[-1] >= bins[0], - searchsorted(bins, x, side=side), - len(bins) - searchsorted(bins[::-1], x, side=side) + bins_arr[-1] >= bins_arr[0], + searchsorted(bins_arr, x, side=side), + len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side) ) _PIECEWISE_DOC = """\ @@ -4695,9 +4718,10 @@ See the :func:`jax.lax.switch` documentation for more information. """ @_wraps(np.piecewise, lax_description=_PIECEWISE_DOC) -def piecewise(x, condlist, funclist, *args, **kw): +def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]], + funclist: List[Union[ArrayLike, Callable[..., Array]]], + *args, **kw) -> Array: _check_arraylike("piecewise", x) - condlist = array(condlist, dtype=bool_) nc, nf = len(condlist), len(funclist) if nf == nc + 1: funclist = funclist[-1:] + funclist[:-1] @@ -4707,14 +4731,16 @@ def piecewise(x, condlist, funclist, *args, **kw): raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}") consts = {i: c for i, c in enumerate(funclist) if not callable(c)} funcs = {i: f for i, f in enumerate(funclist) if callable(f)} - return _piecewise(x, condlist, consts, + return _piecewise(asarray(x), asarray(condlist, dtype=bool_), consts, frozenset(funcs.items()), # dict is not hashable. *args, **kw) @partial(jit, static_argnames=['funcs']) -def _piecewise(x, condlist, consts, funcs, *args, **kw): - funcs = dict(funcs) - funclist = [consts.get(i, funcs.get(i)) for i in range(len(condlist) + 1)] +def _piecewise(x: Array, condlist: Array, consts: Dict[int, ArrayLike], + funcs: FrozenSet[Tuple[int, Callable[..., Array]]], + *args, **kw) -> Array: + funcdict = dict(funcs) + funclist = [consts.get(i, funcdict.get(i)) for i in range(len(condlist) + 1)] indices = argmax(cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0) dtype = _dtype(x) def _call(f): @@ -4728,9 +4754,10 @@ def _piecewise(x, condlist, consts, funcs, *args, **kw): @_wraps(np.percentile, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) -def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, - out=None, overwrite_input=False, method="linear", - keepdims=False, interpolation=None): +def percentile(a: ArrayLike, q: ArrayLike, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = False, interpolation: None = None) -> Array: _check_arraylike("percentile", a, q) q, = _promote_dtypes_inexact(q) return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, @@ -4739,9 +4766,10 @@ def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, @_wraps(np.nanpercentile, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) -def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, - out=None, overwrite_input=False, method="linear", - keepdims=False, interpolation=None): +def nanpercentile(a: ArrayLike, q: ArrayLike, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, method: str = "linear", + keepdims: bool = False, interpolation: None = None) -> Array: _check_arraylike("nanpercentile", a, q) q = true_divide(q, float32(100.0)) return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, @@ -4750,23 +4778,25 @@ def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, @_wraps(np.median, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) -def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - overwrite_input=False, keepdims=False): +def median(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, + keepdims: bool = False) -> Array: _check_arraylike("median", a) return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') @_wraps(np.nanmedian, skip_params=['out', 'overwrite_input']) @partial(jit, static_argnames=('axis', 'overwrite_input', 'keepdims')) -def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - overwrite_input=False, keepdims=False): +def nanmedian(a: ArrayLike, axis: Optional[Union[int, Tuple[int, ...]]] = None, + out: None = None, overwrite_input: bool = False, + keepdims: bool = False) -> Array: _check_arraylike("nanmedian", a) return nanquantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') -def _astype(arr, dtype): +def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: """Copy the array and cast to a specified dtype. This is implemeted via :func:`jax.lax.convert_element_type`, which may @@ -4780,19 +4810,21 @@ def _astype(arr, dtype): return lax.convert_element_type(arr, dtype) -def _nbytes(arr): +def _nbytes(arr: ArrayLike) -> int: return size(arr) * _dtype(arr).itemsize -def _itemsize(arr): +def _itemsize(arr: ArrayLike) -> int: return _dtype(arr).itemsize -def _clip(number, min=None, max=None, out=None): # noqa: F811 +def _clip(number: ArrayLike, + min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None, # noqa: F811 + out: None = None) -> Array: return clip(number, a_min=min, a_max=max, out=out) -def _view(arr, dtype=None, type=None): +def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "view") if type is not None: raise NotImplementedError("`type` argument of array.view()") @@ -4814,7 +4846,8 @@ def _view(arr, dtype=None, type=None): if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0: raise ValueError("When changing to a larger dtype, its size must be a divisor " "of the total size in bytes of the last axis of the array.") - byte_dtypes = {8: uint8, 16: uint16, 32: uint32, 64: uint64} + byte_dtypes: Dict[int, DType] = {8: np.dtype('uint8'), 16: np.dtype('uint16'), + 32: np.dtype('uint32'), 64: np.dtype('uint64')} if nbits_in not in byte_dtypes: raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}") if nbits_out not in byte_dtypes: @@ -4883,15 +4916,15 @@ def _unimplemented_setitem(self, i, x): "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html") raise TypeError(msg.format(type(self))) -def _operator_round(number, ndigits=None): +def _operator_round(number: ArrayLike, ndigits: Optional[int] = None) -> Array: out = round(number, decimals=ndigits or 0) # If `ndigits` is None, for a builtin float round(7.5) returns an integer. return out.astype(int) if ndigits is None else out -def _copy(self): +def _copy(self: Array) -> Array: return self.copy() -def _deepcopy(self, memo): +def _deepcopy(self: Array, memo: Any) -> Array: del memo # unused return self.copy() @@ -4970,22 +5003,23 @@ def __array_module__(self, types): return NotImplemented -def _compress_method(a, condition, axis=None, out=None): +def _compress_method(a: ArrayLike, condition: ArrayLike, + axis: Optional[int] = None, out: None = None) -> Array: return compress(condition, a, axis, out) @core.stash_axis_env() @partial(jit, static_argnums=(1,2,3)) -def _multi_slice(arr, +def _multi_slice(arr: ArrayLike, start_indices: Tuple[Tuple[int, ...]], limit_indices: Tuple[Tuple[int, ...]], - removed_dims: Tuple[Tuple[int, ...]]): + removed_dims: Tuple[Tuple[int, ...]]) -> List[Array]: """Extracts multiple slices from `arr`. This is used to shard DeviceArray arguments to pmap. It's implemented as a DeviceArray method here to avoid circular imports. """ - results = [] + results: List[Array] = [] for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims): sliced = lax.slice(arr, starts, limits) if removed: @@ -4996,7 +5030,7 @@ def _multi_slice(arr, # The next two functions are related to iter(device_array), implemented here to # avoid circular imports. @jit -def _unstack(x): +def _unstack(x: Array) -> List[Array]: return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] setattr(device_array.DeviceArray, "_unstack", _unstack) setattr(ArrayImpl, '_unstack', _unstack)