From a43994d4645684cf67d645a44063b49b0917f9c3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 22 Jun 2024 07:09:20 -0700 Subject: [PATCH] Fix type annotations for jnp.poly* functions --- jax/_src/numpy/polynomial.py | 115 +++++++++++++++++++---------------- jax/numpy/__init__.pyi | 16 ++--- 2 files changed, 72 insertions(+), 59 deletions(-) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 29c36ab3b..45595c438 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -99,6 +99,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: """ check_arraylike("roots", p) p_arr = atleast_1d(promote_dtypes_inexact(p)[0]) + del p if p_arr.ndim != 1: raise ValueError("Input must be a rank-1 array.") if p_arr.size < 2: @@ -116,8 +117,8 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array: @partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov')) -def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, - full: bool = False, w: Array | None = None, cov: bool = False +def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, + full: bool = False, w: ArrayLike | None = None, cov: bool = False ) -> Array | tuple[Array, ...]: r"""Least squares polynomial fit to data. @@ -217,42 +218,47 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, >>> p.shape, C.shape ((3, 3), (3, 3, 1)) """ - check_arraylike("polyfit", x, y) + if w is None: + check_arraylike("polyfit", x, y) + else: + check_arraylike("polyfit", x, y, w) deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 # check arguments + x_arr, y_arr = asarray(x), asarray(y) + del x, y if deg < 0: raise ValueError("expected deg >= 0") - if x.ndim != 1: + if x_arr.ndim != 1: raise TypeError("expected 1D vector for x") - if x.size == 0: + if x_arr.size == 0: raise TypeError("expected non-empty vector for x") - if y.ndim < 1 or y.ndim > 2: + if y_arr.ndim < 1 or y_arr.ndim > 2: raise TypeError("expected 1D or 2D array for y") - if x.shape[0] != y.shape[0]: + if x_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected x and y to have same length") # set rcond if rcond is None: - rcond = len(x) * finfo(x.dtype).eps + rcond = len(x_arr) * finfo(x_arr.dtype).eps rcond = core.concrete_or_error(float, rcond, "rcond must be float") # set up least squares equation for powers of x - lhs = vander(x, order) - rhs = y + lhs = vander(x_arr, order) + rhs = y_arr # apply weighting if w is not None: - check_arraylike("polyfit", w) w, = promote_dtypes_inexact(w) - if w.ndim != 1: + w_arr = asarray(w) + if w_arr.ndim != 1: raise TypeError("expected a 1-d array for weights") - if w.shape[0] != y.shape[0]: + if w_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected w and y to have the same length") - lhs *= w[:, np.newaxis] + lhs *= w_arr[:, np.newaxis] if rhs.ndim == 2: - rhs *= w[:, np.newaxis] + rhs *= w_arr[:, np.newaxis] else: - rhs *= w + rhs *= w_arr # scale lhs to improve condition number and solve scale = sqrt((lhs*lhs).sum(axis=0)) @@ -268,12 +274,12 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, if cov == "unscaled": fac = 1 else: - if len(x) <= order: + if len(x_arr) <= order: raise ValueError("the number of data points must exceed order " "to scale the covariance matrix") - fac = resids / (len(x) - order) + fac = resids / (len(x_arr) - order) fac = fac[0] #making np.array() of shape (1,) to int - if y.ndim == 1: + if y_arr.ndim == 1: return c, Vbase * fac else: return c, Vbase[:, :, np.newaxis] * fac @@ -282,7 +288,7 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, @jit -def poly(seq_of_zeros: Array) -> Array: +def poly(seq_of_zeros: ArrayLike) -> Array: r"""Returns the coefficients of a polynomial for the given sequence of roots. JAX implementation of :func:`numpy.poly`. @@ -340,30 +346,31 @@ def poly(seq_of_zeros: Array) -> Array: """ check_arraylike('poly', seq_of_zeros) seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros) - seq_of_zeros = atleast_1d(seq_of_zeros) + seq_of_zeros_arr = atleast_1d(seq_of_zeros) + del seq_of_zeros - sh = seq_of_zeros.shape + sh = seq_of_zeros_arr.shape if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0: # import at runtime to avoid circular import from jax._src.numpy import linalg - seq_of_zeros = linalg.eigvals(seq_of_zeros) + seq_of_zeros_arr = linalg.eigvals(seq_of_zeros_arr) - if seq_of_zeros.ndim != 1: + if seq_of_zeros_arr.ndim != 1: raise ValueError("input must be 1d or non-empty square 2d array.") - dt = seq_of_zeros.dtype - if len(seq_of_zeros) == 0: + dt = seq_of_zeros_arr.dtype + if len(seq_of_zeros_arr) == 0: return ones((), dtype=dt) a = ones((1,), dtype=dt) - for k in range(len(seq_of_zeros)): - a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full') + for k in range(len(seq_of_zeros_arr)): + a = convolve(a, array([1, -seq_of_zeros_arr[k]], dtype=dt), mode='full') return a @partial(jit, static_argnames=['unroll']) -def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: r"""Evaluates the polynomial at specific values. JAX implementations of :func:`numpy.polyval`. @@ -417,25 +424,27 @@ def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: [ 8., 34., 76.]], dtype=float32) """ check_arraylike("polyval", p, x) - p, x = promote_dtypes_inexact(p, x) - shape = lax.broadcast_shapes(p.shape[1:], x.shape) - y = lax.full_like(x, 0, shape=shape, dtype=x.dtype) - y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll) + p_arr, x_arr = promote_dtypes_inexact(p, x) + del p, x + shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) + y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype) + y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) return y @implements(np.polyadd) @jit -def polyadd(a1: Array, a2: Array) -> Array: +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polyadd", a1, a2) - a1, a2 = promote_dtypes(a1, a2) - if a2.shape[0] <= a1.shape[0]: - return a1.at[-a2.shape[0]:].add(a2) + a1_arr, a2_arr = promote_dtypes(a1, a2) + del a1, a2 + if a2_arr.shape[0] <= a1_arr.shape[0]: + return a1_arr.at[-a2_arr.shape[0]:].add(a2_arr) else: - return a2.at[-a1.shape[0]:].add(a1) + return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr) @partial(jit, static_argnames=('m',)) -def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: +def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array: r"""Returns the coefficients of the integration of specified order of a polynomial. JAX implementation of :func:`numpy.polyint`. @@ -484,7 +493,8 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k check_arraylike("polyint", p, k) - p, k_arr = promote_dtypes_inexact(p, k) + p_arr, k_arr = promote_dtypes_inexact(p, k) + del p, k if m < 0: raise ValueError("Order of integral must be positive (see polyder)") k_arr = atleast_1d(k_arr) @@ -493,16 +503,16 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: if k_arr.shape != (m,): raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.") if m == 0: - return p + return p_arr else: - grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]) + grid = (arange(len(p_arr) + m, dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]) coeff = maximum(1, grid).prod(0)[::-1] - return true_divide(concatenate((p, k_arr)), coeff) + return true_divide(concatenate((p_arr, k_arr)), coeff) @partial(jit, static_argnames=('m',)) -def polyder(p: Array, m: int = 1) -> Array: +def polyder(p: ArrayLike, m: int = 1) -> Array: r"""Returns the coefficients of the derivative of specified order of a polynomial. JAX implementation of :func:`numpy.polyder`. @@ -540,14 +550,15 @@ def polyder(p: Array, m: int = 1) -> Array: """ check_arraylike("polyder", p) m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder") - p, = promote_dtypes_inexact(p) + p_arr, = promote_dtypes_inexact(p) + del p if m < 0: raise ValueError("Order of derivative must be positive") if m == 0: - return p - coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis] - - arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0) - return p[:-m] * coeff[::-1] + return p_arr + coeff = (arange(m, len(p_arr), dtype=p_arr.dtype)[np.newaxis] + - arange(m, dtype=p_arr.dtype)[:, np.newaxis]).prod(0) + return p_arr[:-m] * coeff[::-1] _LEADING_ZEROS_DOC = """\ @@ -562,6 +573,7 @@ JAX backends. The result may lead to inconsistent output shapes when trim_leadin def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array: check_arraylike("polymul", a1, a2) a1_arr, a2_arr = promote_dtypes_inexact(a1, a2) + del a1, a2 if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1): a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f') if len(a1_arr) == 0: @@ -574,6 +586,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]: check_arraylike("polydiv", u, v) u_arr, v_arr = promote_dtypes_inexact(u, v) + del u, v m = len(u_arr) - 1 n = len(v_arr) - 1 scale = 1. / v_arr[0] @@ -589,7 +602,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> @implements(np.polysub) @jit -def polysub(a1: Array, a2: Array) -> Array: +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: check_arraylike("polysub", a1, a2) a1, a2 = promote_dtypes(a1, a2) return polyadd(a1, -a2) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 3e306a24b..0e29b2892 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -669,17 +669,17 @@ def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]], *args, **kw) -> Array: ... def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: builtins.bool = ...) -> Array: ... -def poly(seq_of_zeros: Array) -> Array: ... -def polyadd(a1: Array, a2: Array) -> Array: ... -def polyder(p: Array, m: int = ...) -> Array: ... +def poly(seq_of_zeros: ArrayLike) -> Array: ... +def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyder(p: ArrayLike, m: int = ...) -> Array: ... def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> tuple[Array, Array]: ... -def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = ..., - full: builtins.bool = ..., w: Optional[Array] = ..., cov: builtins.bool = ... +def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: Optional[float] = ..., + full: builtins.bool = ..., w: Optional[ArrayLike] = ..., cov: builtins.bool = ... ) -> Union[Array, tuple[Array, ...]]: ... -def polyint(p: Array, m: int = ..., k: Optional[int] = ...) -> Array: ... +def polyint(p: ArrayLike, m: int = ..., k: Union[int, ArrayLike, None] = ...) -> Array: ... def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> Array: ... -def polysub(a1: Array, a2: Array) -> Array: ... -def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ... +def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: ... +def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = ...) -> Array: ... def positive(x: ArrayLike, /) -> Array: ... def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ...