Fix type annotations for jnp.poly* functions

This commit is contained in:
Jake VanderPlas 2024-06-22 07:09:20 -07:00
parent e119fe933b
commit a43994d464
2 changed files with 72 additions and 59 deletions

View File

@ -99,6 +99,7 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
"""
check_arraylike("roots", p)
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
del p
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p_arr.size < 2:
@ -116,8 +117,8 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
full: bool = False, w: Array | None = None, cov: bool = False
def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
full: bool = False, w: ArrayLike | None = None, cov: bool = False
) -> Array | tuple[Array, ...]:
r"""Least squares polynomial fit to data.
@ -217,42 +218,47 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
>>> p.shape, C.shape
((3, 3), (3, 3, 1))
"""
if w is None:
check_arraylike("polyfit", x, y)
else:
check_arraylike("polyfit", x, y, w)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
# check arguments
x_arr, y_arr = asarray(x), asarray(y)
del x, y
if deg < 0:
raise ValueError("expected deg >= 0")
if x.ndim != 1:
if x_arr.ndim != 1:
raise TypeError("expected 1D vector for x")
if x.size == 0:
if x_arr.size == 0:
raise TypeError("expected non-empty vector for x")
if y.ndim < 1 or y.ndim > 2:
if y_arr.ndim < 1 or y_arr.ndim > 2:
raise TypeError("expected 1D or 2D array for y")
if x.shape[0] != y.shape[0]:
if x_arr.shape[0] != y_arr.shape[0]:
raise TypeError("expected x and y to have same length")
# set rcond
if rcond is None:
rcond = len(x) * finfo(x.dtype).eps
rcond = len(x_arr) * finfo(x_arr.dtype).eps
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
# set up least squares equation for powers of x
lhs = vander(x, order)
rhs = y
lhs = vander(x_arr, order)
rhs = y_arr
# apply weighting
if w is not None:
check_arraylike("polyfit", w)
w, = promote_dtypes_inexact(w)
if w.ndim != 1:
w_arr = asarray(w)
if w_arr.ndim != 1:
raise TypeError("expected a 1-d array for weights")
if w.shape[0] != y.shape[0]:
if w_arr.shape[0] != y_arr.shape[0]:
raise TypeError("expected w and y to have the same length")
lhs *= w[:, np.newaxis]
lhs *= w_arr[:, np.newaxis]
if rhs.ndim == 2:
rhs *= w[:, np.newaxis]
rhs *= w_arr[:, np.newaxis]
else:
rhs *= w
rhs *= w_arr
# scale lhs to improve condition number and solve
scale = sqrt((lhs*lhs).sum(axis=0))
@ -268,12 +274,12 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
if cov == "unscaled":
fac = 1
else:
if len(x) <= order:
if len(x_arr) <= order:
raise ValueError("the number of data points must exceed order "
"to scale the covariance matrix")
fac = resids / (len(x) - order)
fac = resids / (len(x_arr) - order)
fac = fac[0] #making np.array() of shape (1,) to int
if y.ndim == 1:
if y_arr.ndim == 1:
return c, Vbase * fac
else:
return c, Vbase[:, :, np.newaxis] * fac
@ -282,7 +288,7 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
@jit
def poly(seq_of_zeros: Array) -> Array:
def poly(seq_of_zeros: ArrayLike) -> Array:
r"""Returns the coefficients of a polynomial for the given sequence of roots.
JAX implementation of :func:`numpy.poly`.
@ -340,30 +346,31 @@ def poly(seq_of_zeros: Array) -> Array:
"""
check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros = atleast_1d(seq_of_zeros)
seq_of_zeros_arr = atleast_1d(seq_of_zeros)
del seq_of_zeros
sh = seq_of_zeros.shape
sh = seq_of_zeros_arr.shape
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
# import at runtime to avoid circular import
from jax._src.numpy import linalg
seq_of_zeros = linalg.eigvals(seq_of_zeros)
seq_of_zeros_arr = linalg.eigvals(seq_of_zeros_arr)
if seq_of_zeros.ndim != 1:
if seq_of_zeros_arr.ndim != 1:
raise ValueError("input must be 1d or non-empty square 2d array.")
dt = seq_of_zeros.dtype
if len(seq_of_zeros) == 0:
dt = seq_of_zeros_arr.dtype
if len(seq_of_zeros_arr) == 0:
return ones((), dtype=dt)
a = ones((1,), dtype=dt)
for k in range(len(seq_of_zeros)):
a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
for k in range(len(seq_of_zeros_arr)):
a = convolve(a, array([1, -seq_of_zeros_arr[k]], dtype=dt), mode='full')
return a
@partial(jit, static_argnames=['unroll'])
def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
r"""Evaluates the polynomial at specific values.
JAX implementations of :func:`numpy.polyval`.
@ -417,25 +424,27 @@ def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
[ 8., 34., 76.]], dtype=float32)
"""
check_arraylike("polyval", p, x)
p, x = promote_dtypes_inexact(p, x)
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
p_arr, x_arr = promote_dtypes_inexact(p, x)
del p, x
shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype)
y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
return y
@implements(np.polyadd)
@jit
def polyadd(a1: Array, a2: Array) -> Array:
def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
check_arraylike("polyadd", a1, a2)
a1, a2 = promote_dtypes(a1, a2)
if a2.shape[0] <= a1.shape[0]:
return a1.at[-a2.shape[0]:].add(a2)
a1_arr, a2_arr = promote_dtypes(a1, a2)
del a1, a2
if a2_arr.shape[0] <= a1_arr.shape[0]:
return a1_arr.at[-a2_arr.shape[0]:].add(a2_arr)
else:
return a2.at[-a1.shape[0]:].add(a1)
return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr)
@partial(jit, static_argnames=('m',))
def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array:
r"""Returns the coefficients of the integration of specified order of a polynomial.
JAX implementation of :func:`numpy.polyint`.
@ -484,7 +493,8 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
check_arraylike("polyint", p, k)
p, k_arr = promote_dtypes_inexact(p, k)
p_arr, k_arr = promote_dtypes_inexact(p, k)
del p, k
if m < 0:
raise ValueError("Order of integral must be positive (see polyder)")
k_arr = atleast_1d(k_arr)
@ -493,16 +503,16 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
if k_arr.shape != (m,):
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
if m == 0:
return p
return p_arr
else:
grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis]
- arange(m, dtype=p.dtype)[:, np.newaxis])
grid = (arange(len(p_arr) + m, dtype=p_arr.dtype)[np.newaxis]
- arange(m, dtype=p_arr.dtype)[:, np.newaxis])
coeff = maximum(1, grid).prod(0)[::-1]
return true_divide(concatenate((p, k_arr)), coeff)
return true_divide(concatenate((p_arr, k_arr)), coeff)
@partial(jit, static_argnames=('m',))
def polyder(p: Array, m: int = 1) -> Array:
def polyder(p: ArrayLike, m: int = 1) -> Array:
r"""Returns the coefficients of the derivative of specified order of a polynomial.
JAX implementation of :func:`numpy.polyder`.
@ -540,14 +550,15 @@ def polyder(p: Array, m: int = 1) -> Array:
"""
check_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p, = promote_dtypes_inexact(p)
p_arr, = promote_dtypes_inexact(p)
del p
if m < 0:
raise ValueError("Order of derivative must be positive")
if m == 0:
return p
coeff = (arange(m, len(p), dtype=p.dtype)[np.newaxis]
- arange(m, dtype=p.dtype)[:, np.newaxis]).prod(0)
return p[:-m] * coeff[::-1]
return p_arr
coeff = (arange(m, len(p_arr), dtype=p_arr.dtype)[np.newaxis]
- arange(m, dtype=p_arr.dtype)[:, np.newaxis]).prod(0)
return p_arr[:-m] * coeff[::-1]
_LEADING_ZEROS_DOC = """\
@ -562,6 +573,7 @@ JAX backends. The result may lead to inconsistent output shapes when trim_leadin
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
check_arraylike("polymul", a1, a2)
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
del a1, a2
if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1):
a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f')
if len(a1_arr) == 0:
@ -574,6 +586,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
check_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
del u, v
m = len(u_arr) - 1
n = len(v_arr) - 1
scale = 1. / v_arr[0]
@ -589,7 +602,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) ->
@implements(np.polysub)
@jit
def polysub(a1: Array, a2: Array) -> Array:
def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
check_arraylike("polysub", a1, a2)
a1, a2 = promote_dtypes(a1, a2)
return polyadd(a1, -a2)

View File

@ -669,17 +669,17 @@ def piecewise(x: ArrayLike, condlist: Union[Array, Sequence[ArrayLike]],
*args, **kw) -> Array: ...
def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
inplace: builtins.bool = ...) -> Array: ...
def poly(seq_of_zeros: Array) -> Array: ...
def polyadd(a1: Array, a2: Array) -> Array: ...
def polyder(p: Array, m: int = ...) -> Array: ...
def poly(seq_of_zeros: ArrayLike) -> Array: ...
def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: ...
def polyder(p: ArrayLike, m: int = ...) -> Array: ...
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> tuple[Array, Array]: ...
def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = ...,
full: builtins.bool = ..., w: Optional[Array] = ..., cov: builtins.bool = ...
def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: Optional[float] = ...,
full: builtins.bool = ..., w: Optional[ArrayLike] = ..., cov: builtins.bool = ...
) -> Union[Array, tuple[Array, ...]]: ...
def polyint(p: Array, m: int = ..., k: Optional[int] = ...) -> Array: ...
def polyint(p: ArrayLike, m: int = ..., k: Union[int, ArrayLike, None] = ...) -> Array: ...
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: builtins.bool = ...) -> Array: ...
def polysub(a1: Array, a2: Array) -> Array: ...
def polyval(p: Array, x: Array, *, unroll: int = ...) -> Array: ...
def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: ...
def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = ...) -> Array: ...
def positive(x: ArrayLike, /) -> Array: ...
def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def power(x: ArrayLike, y: ArrayLike, /) -> Array: ...