diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index ed483a889..56be67932 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -15,6 +15,7 @@ from functools import partial import operator +from typing import Optional, Tuple, Union from jax import core from jax import jit @@ -26,11 +27,12 @@ from jax._src.numpy.lax_numpy import ( vander, zeros) from jax._src.numpy import linalg from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps +from jax._src.typing import Array, ArrayLike import numpy as np @jit -def _roots_no_zeros(p): +def _roots_no_zeros(p: Array) -> Array: # build companion matrix and find its eigenvalues (the roots) if p.size < 2: return array([], dtype=dtypes.to_complex_dtype(p.dtype)) @@ -40,7 +42,7 @@ def _roots_no_zeros(p): @jit -def _roots_with_zeros(p, num_leading_zeros): +def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array: # Avoid lapack errors when p is all zero p = _where(len(p) == num_leading_zeros, 1.0, p) # Roll any leading zeros to the end & compute the roots @@ -77,23 +79,23 @@ strip_zeros : bool, default=True ``strip_zeros`` must be set to ``False`` for the function to be compatible with :func:`jax.jit` and other JAX transformations. """) -def roots(p, *, strip_zeros=True): +def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: _check_arraylike("roots", p) - p = atleast_1d(*_promote_dtypes_inexact(p)) - if p.ndim != 1: + p_arr = atleast_1d(*_promote_dtypes_inexact(p)) + if p_arr.ndim != 1: raise ValueError("Input must be a rank-1 array.") - if p.size < 2: - return array([], dtype=dtypes.to_complex_dtype(p.dtype)) - num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0)) + if p_arr.size < 2: + return array([], dtype=dtypes.to_complex_dtype(p_arr.dtype)) + num_leading_zeros = _where(all(p_arr == 0), len(p_arr), argmin(p_arr == 0)) if strip_zeros: num_leading_zeros = core.concrete_or_error(int, num_leading_zeros, "The error occurred in the jnp.roots() function. To use this within a " "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros " "will be result in some returned roots being set to NaN.") - return _roots_no_zeros(p[num_leading_zeros:]) + return _roots_no_zeros(p_arr[num_leading_zeros:]) else: - return _roots_with_zeros(p, num_leading_zeros) + return _roots_with_zeros(p_arr, num_leading_zeros) _POLYFIT_DOC = """\ @@ -102,7 +104,9 @@ Also, it works best on rcond <= 10e-3 values. """ @_wraps(np.polyfit, lax_description=_POLYFIT_DOC) @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) -def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): +def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None, + full: bool = False, w: Optional[Array] = None, cov: bool = False + ) -> Union[Array, Tuple[Array, ...]]: _check_arraylike("polyfit", x, y) deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 @@ -147,7 +151,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): c = (c.T/scale).T # broadcast scale coefficients if full: - return c, resids, rank, s, rcond + return c, resids, rank, s, asarray(rcond) elif cov: Vbase = linalg.inv(dot(lhs.T, lhs)) Vbase /= outer(scale, scale) @@ -181,7 +185,7 @@ jax returns an array with a complex dtype in such cases. @_wraps(np.poly, lax_description=_POLY_DOC) @jit -def poly(seq_of_zeros): +def poly(seq_of_zeros: Array) -> Array: _check_arraylike('poly', seq_of_zeros) seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros) seq_of_zeros = atleast_1d(seq_of_zeros) @@ -215,7 +219,7 @@ improve runtime performance on accelerators, at the cost of increased compilation time. """) @partial(jit, static_argnames=['unroll']) -def polyval(p, x, *, unroll=16): +def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: _check_arraylike("polyval", p, x) p, x = _promote_dtypes_inexact(p, x) shape = lax.broadcast_shapes(p.shape[1:], x.shape) @@ -225,7 +229,7 @@ def polyval(p, x, *, unroll=16): @_wraps(np.polyadd) @jit -def polyadd(a1, a2): +def polyadd(a1: Array, a2: Array) -> Array: _check_arraylike("polyadd", a1, a2) a1, a2 = _promote_dtypes(a1, a2) if a2.shape[0] <= a1.shape[0]: @@ -236,17 +240,17 @@ def polyadd(a1, a2): @_wraps(np.polyint) @partial(jit, static_argnames=('m',)) -def polyint(p, m=1, k=None): +def polyint(p: Array, m: int = 1, k: Optional[int] = None) -> Array: m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k _check_arraylike("polyint", p, k) - p, k = _promote_dtypes_inexact(p, k) + p, k_arr = _promote_dtypes_inexact(p, k) if m < 0: raise ValueError("Order of integral must be positive (see polyder)") - k = atleast_1d(k) - if len(k) == 1: - k = full((m,), k[0]) - if k.shape != (m,): + k_arr = atleast_1d(k_arr) + if len(k_arr) == 1: + k_arr = full((m,), k_arr[0]) + if k_arr.shape != (m,): raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.") if m == 0: return p @@ -254,12 +258,12 @@ def polyint(p, m=1, k=None): grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis] - arange(m, dtype=p.dtype)[:, np.newaxis]) coeff = maximum(1, grid).prod(0)[::-1] - return true_divide(concatenate((p, k)), coeff) + return true_divide(concatenate((p, k_arr)), coeff) @_wraps(np.polyder) @partial(jit, static_argnames=('m',)) -def polyder(p, m=1): +def polyder(p: Array, m: int = 1) -> Array: _check_arraylike("polyder", p) m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder") p, = _promote_dtypes_inexact(p) @@ -281,38 +285,37 @@ JAX backends. The result may lead to inconsistent output shapes when trim_leadin """ @_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC) -def polymul(a1, a2, *, trim_leading_zeros=False): +def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: _check_arraylike("polymul", a1, a2) - a1, a2 = _promote_dtypes_inexact(a1, a2) - if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1): - a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f') - if len(a1) == 0: - a1 = asarray([0], dtype=a2.dtype) - if len(a2) == 0: - a2 = asarray([0], dtype=a1.dtype) - return convolve(a1, a2, mode='full') + a1_arr, a2_arr = _promote_dtypes_inexact(a1, a2) + if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1): + a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f') + if len(a1_arr) == 0: + a1_arr = asarray([0], dtype=a2_arr.dtype) + if len(a2_arr) == 0: + a2_arr = asarray([0], dtype=a1_arr.dtype) + return convolve(a1_arr, a2_arr, mode='full') @_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC) -def polydiv(u, v, *, trim_leading_zeros=False): +def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> Tuple[Array, Array]: _check_arraylike("polydiv", u, v) - u, v = _promote_dtypes_inexact(u, v) - m = len(u) - 1 - n = len(v) - 1 - scale = 1. / v[0] - q = zeros(max(m - n + 1, 1), dtype = u.dtype) # force same dtype + u_arr, v_arr = _promote_dtypes_inexact(u, v) + m = len(u_arr) - 1 + n = len(v_arr) - 1 + scale = 1. / v_arr[0] + q: Array = zeros(max(m - n + 1, 1), dtype = u_arr.dtype) # force same dtype for k in range(0, m-n+1): - d = scale * u[k] + d = scale * u_arr[k] q = q.at[k].set(d) - u = u.at[k:k+n+1].add(-d*v) + u_arr = u_arr.at[k:k+n+1].add(-d*v_arr) if trim_leading_zeros: # use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy - return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f') - else: - return q, u + u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f') + return q, u_arr @_wraps(np.polysub) @jit -def polysub(a1, a2): +def polysub(a1: Array, a2: Array) -> Array: _check_arraylike("polysub", a1, a2) a1, a2 = _promote_dtypes(a1, a2) return polyadd(a1, -a2)