mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[typing] add types for jax.numpy.polynomial
This commit is contained in:
parent
dd0a455a78
commit
6a348f9666
@ -15,6 +15,7 @@
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from jax import core
|
||||
from jax import jit
|
||||
@ -26,11 +27,12 @@ from jax._src.numpy.lax_numpy import (
|
||||
vander, zeros)
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
import numpy as np
|
||||
|
||||
|
||||
@jit
|
||||
def _roots_no_zeros(p):
|
||||
def _roots_no_zeros(p: Array) -> Array:
|
||||
# build companion matrix and find its eigenvalues (the roots)
|
||||
if p.size < 2:
|
||||
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
|
||||
@ -40,7 +42,7 @@ def _roots_no_zeros(p):
|
||||
|
||||
|
||||
@jit
|
||||
def _roots_with_zeros(p, num_leading_zeros):
|
||||
def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
|
||||
# Avoid lapack errors when p is all zero
|
||||
p = _where(len(p) == num_leading_zeros, 1.0, p)
|
||||
# Roll any leading zeros to the end & compute the roots
|
||||
@ -77,23 +79,23 @@ strip_zeros : bool, default=True
|
||||
``strip_zeros`` must be set to ``False`` for the function to be compatible with
|
||||
:func:`jax.jit` and other JAX transformations.
|
||||
""")
|
||||
def roots(p, *, strip_zeros=True):
|
||||
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
|
||||
_check_arraylike("roots", p)
|
||||
p = atleast_1d(*_promote_dtypes_inexact(p))
|
||||
if p.ndim != 1:
|
||||
p_arr = atleast_1d(*_promote_dtypes_inexact(p))
|
||||
if p_arr.ndim != 1:
|
||||
raise ValueError("Input must be a rank-1 array.")
|
||||
if p.size < 2:
|
||||
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
|
||||
num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))
|
||||
if p_arr.size < 2:
|
||||
return array([], dtype=dtypes.to_complex_dtype(p_arr.dtype))
|
||||
num_leading_zeros = _where(all(p_arr == 0), len(p_arr), argmin(p_arr == 0))
|
||||
|
||||
if strip_zeros:
|
||||
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
|
||||
"The error occurred in the jnp.roots() function. To use this within a "
|
||||
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
|
||||
"will be result in some returned roots being set to NaN.")
|
||||
return _roots_no_zeros(p[num_leading_zeros:])
|
||||
return _roots_no_zeros(p_arr[num_leading_zeros:])
|
||||
else:
|
||||
return _roots_with_zeros(p, num_leading_zeros)
|
||||
return _roots_with_zeros(p_arr, num_leading_zeros)
|
||||
|
||||
|
||||
_POLYFIT_DOC = """\
|
||||
@ -102,7 +104,9 @@ Also, it works best on rcond <= 10e-3 values.
|
||||
"""
|
||||
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
|
||||
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
|
||||
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
|
||||
def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None,
|
||||
full: bool = False, w: Optional[Array] = None, cov: bool = False
|
||||
) -> Union[Array, Tuple[Array, ...]]:
|
||||
_check_arraylike("polyfit", x, y)
|
||||
deg = core.concrete_or_error(int, deg, "deg must be int")
|
||||
order = deg + 1
|
||||
@ -147,7 +151,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
|
||||
c = (c.T/scale).T # broadcast scale coefficients
|
||||
|
||||
if full:
|
||||
return c, resids, rank, s, rcond
|
||||
return c, resids, rank, s, asarray(rcond)
|
||||
elif cov:
|
||||
Vbase = linalg.inv(dot(lhs.T, lhs))
|
||||
Vbase /= outer(scale, scale)
|
||||
@ -181,7 +185,7 @@ jax returns an array with a complex dtype in such cases.
|
||||
|
||||
@_wraps(np.poly, lax_description=_POLY_DOC)
|
||||
@jit
|
||||
def poly(seq_of_zeros):
|
||||
def poly(seq_of_zeros: Array) -> Array:
|
||||
_check_arraylike('poly', seq_of_zeros)
|
||||
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
|
||||
seq_of_zeros = atleast_1d(seq_of_zeros)
|
||||
@ -215,7 +219,7 @@ improve runtime performance on accelerators, at the cost of increased
|
||||
compilation time.
|
||||
""")
|
||||
@partial(jit, static_argnames=['unroll'])
|
||||
def polyval(p, x, *, unroll=16):
|
||||
def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
|
||||
_check_arraylike("polyval", p, x)
|
||||
p, x = _promote_dtypes_inexact(p, x)
|
||||
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
|
||||
@ -225,7 +229,7 @@ def polyval(p, x, *, unroll=16):
|
||||
|
||||
@_wraps(np.polyadd)
|
||||
@jit
|
||||
def polyadd(a1, a2):
|
||||
def polyadd(a1: Array, a2: Array) -> Array:
|
||||
_check_arraylike("polyadd", a1, a2)
|
||||
a1, a2 = _promote_dtypes(a1, a2)
|
||||
if a2.shape[0] <= a1.shape[0]:
|
||||
@ -236,17 +240,17 @@ def polyadd(a1, a2):
|
||||
|
||||
@_wraps(np.polyint)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyint(p, m=1, k=None):
|
||||
def polyint(p: Array, m: int = 1, k: Optional[int] = None) -> Array:
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
|
||||
k = 0 if k is None else k
|
||||
_check_arraylike("polyint", p, k)
|
||||
p, k = _promote_dtypes_inexact(p, k)
|
||||
p, k_arr = _promote_dtypes_inexact(p, k)
|
||||
if m < 0:
|
||||
raise ValueError("Order of integral must be positive (see polyder)")
|
||||
k = atleast_1d(k)
|
||||
if len(k) == 1:
|
||||
k = full((m,), k[0])
|
||||
if k.shape != (m,):
|
||||
k_arr = atleast_1d(k_arr)
|
||||
if len(k_arr) == 1:
|
||||
k_arr = full((m,), k_arr[0])
|
||||
if k_arr.shape != (m,):
|
||||
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
|
||||
if m == 0:
|
||||
return p
|
||||
@ -254,12 +258,12 @@ def polyint(p, m=1, k=None):
|
||||
grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis]
|
||||
- arange(m, dtype=p.dtype)[:, np.newaxis])
|
||||
coeff = maximum(1, grid).prod(0)[::-1]
|
||||
return true_divide(concatenate((p, k)), coeff)
|
||||
return true_divide(concatenate((p, k_arr)), coeff)
|
||||
|
||||
|
||||
@_wraps(np.polyder)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyder(p, m=1):
|
||||
def polyder(p: Array, m: int = 1) -> Array:
|
||||
_check_arraylike("polyder", p)
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
|
||||
p, = _promote_dtypes_inexact(p)
|
||||
@ -281,38 +285,37 @@ JAX backends. The result may lead to inconsistent output shapes when trim_leadin
|
||||
"""
|
||||
|
||||
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polymul(a1, a2, *, trim_leading_zeros=False):
|
||||
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
|
||||
_check_arraylike("polymul", a1, a2)
|
||||
a1, a2 = _promote_dtypes_inexact(a1, a2)
|
||||
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
|
||||
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
|
||||
if len(a1) == 0:
|
||||
a1 = asarray([0], dtype=a2.dtype)
|
||||
if len(a2) == 0:
|
||||
a2 = asarray([0], dtype=a1.dtype)
|
||||
return convolve(a1, a2, mode='full')
|
||||
a1_arr, a2_arr = _promote_dtypes_inexact(a1, a2)
|
||||
if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1):
|
||||
a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f')
|
||||
if len(a1_arr) == 0:
|
||||
a1_arr = asarray([0], dtype=a2_arr.dtype)
|
||||
if len(a2_arr) == 0:
|
||||
a2_arr = asarray([0], dtype=a1_arr.dtype)
|
||||
return convolve(a1_arr, a2_arr, mode='full')
|
||||
|
||||
@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polydiv(u, v, *, trim_leading_zeros=False):
|
||||
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> Tuple[Array, Array]:
|
||||
_check_arraylike("polydiv", u, v)
|
||||
u, v = _promote_dtypes_inexact(u, v)
|
||||
m = len(u) - 1
|
||||
n = len(v) - 1
|
||||
scale = 1. / v[0]
|
||||
q = zeros(max(m - n + 1, 1), dtype = u.dtype) # force same dtype
|
||||
u_arr, v_arr = _promote_dtypes_inexact(u, v)
|
||||
m = len(u_arr) - 1
|
||||
n = len(v_arr) - 1
|
||||
scale = 1. / v_arr[0]
|
||||
q: Array = zeros(max(m - n + 1, 1), dtype = u_arr.dtype) # force same dtype
|
||||
for k in range(0, m-n+1):
|
||||
d = scale * u[k]
|
||||
d = scale * u_arr[k]
|
||||
q = q.at[k].set(d)
|
||||
u = u.at[k:k+n+1].add(-d*v)
|
||||
u_arr = u_arr.at[k:k+n+1].add(-d*v_arr)
|
||||
if trim_leading_zeros:
|
||||
# use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy
|
||||
return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
|
||||
else:
|
||||
return q, u
|
||||
u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f')
|
||||
return q, u_arr
|
||||
|
||||
@_wraps(np.polysub)
|
||||
@jit
|
||||
def polysub(a1, a2):
|
||||
def polysub(a1: Array, a2: Array) -> Array:
|
||||
_check_arraylike("polysub", a1, a2)
|
||||
a1, a2 = _promote_dtypes(a1, a2)
|
||||
return polyadd(a1, -a2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user