lax_numpy: move poly functions into numpy.polynomial

Jake VanderPlas 2022-03-17 13:28:54 -07:00
parent 2d79a6462f
commit 603bb3c5ca
3 changed files with 233 additions and 231 deletions

jax/_src/numpy/lax_numpy.py

@@ -702,76 +702,6 @@ def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None,
 def isrealobj(x):
   return not iscomplexobj(x)
-
-
-_POLYFIT_DOC = """\
-Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn
-on rank reduction, which indicates an ill-conditioned matrix.
-Also, it works best on rcond values <= 10e-3.
-"""
-
-@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
-@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
-def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
-  _check_arraylike("polyfit", x, y)
-  deg = core.concrete_or_error(int, deg, "deg must be int")
-  order = deg + 1
-  # check arguments
-  if deg < 0:
-    raise ValueError("expected deg >= 0")
-  if x.ndim != 1:
-    raise TypeError("expected 1D vector for x")
-  if x.size == 0:
-    raise TypeError("expected non-empty vector for x")
-  if y.ndim < 1 or y.ndim > 2:
-    raise TypeError("expected 1D or 2D array for y")
-  if x.shape[0] != y.shape[0]:
-    raise TypeError("expected x and y to have same length")
-
-  # set rcond
-  if rcond is None:
-    rcond = len(x) * finfo(x.dtype).eps
-  rcond = core.concrete_or_error(float, rcond, "rcond must be float")
-
-  # set up least squares equation for powers of x
-  lhs = vander(x, order)
-  rhs = y
-
-  # apply weighting
-  if w is not None:
-    _check_arraylike("polyfit", w)
-    w, = _promote_dtypes_inexact(w)
-    if w.ndim != 1:
-      raise TypeError("expected a 1-d array for weights")
-    if w.shape[0] != y.shape[0]:
-      raise TypeError("expected w and y to have the same length")
-    lhs *= w[:, newaxis]
-    if rhs.ndim == 2:
-      rhs *= w[:, newaxis]
-    else:
-      rhs *= w
-
-  # scale lhs to improve condition number and solve
-  scale = sqrt((lhs * lhs).sum(axis=0))
-  lhs /= scale[newaxis, :]
-  from jax._src.numpy import linalg  # import at runtime to avoid circular import
-  c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
-  c = (c.T / scale).T  # broadcast scale coefficients
-
-  if full:
-    return c, resids, rank, s, rcond
-  elif cov:
-    Vbase = linalg.inv(dot(lhs.T, lhs))
-    Vbase /= outer(scale, scale)
-    if cov == "unscaled":
-      fac = 1
-    else:
-      if len(x) <= order:
-        raise ValueError("the number of data points must exceed order "
-                         "to scale the covariance matrix")
-      fac = resids / (len(x) - order)
-      fac = fac[0]  # convert array of shape (1,) to a scalar
-    if y.ndim == 1:
-      return c, Vbase * fac
-    else:
-      return c, Vbase[:, :, newaxis] * fac
-  else:
-    return c
-
-
 @_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
@@ -3242,107 +3172,6 @@ def diagflat(v, k=0):
   res = res.reshape(adj_length, adj_length)
   return res
-
-
-_POLY_DOC = """\
-This differs from np.poly when an integer array is given.
-np.poly returns a result with dtype float64 in this case.
-jax returns a result with an inexact type, but not necessarily float64.
-
-This also differs from np.poly when the input array strictly
-contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
-np.poly returns an array with a real dtype in such cases.
-jax returns an array with a complex dtype in such cases.
-"""
-
-@_wraps(np.poly, lax_description=_POLY_DOC)
-@jit
-def poly(seq_of_zeros):
-  _check_arraylike('poly', seq_of_zeros)
-  seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
-  seq_of_zeros = atleast_1d(seq_of_zeros)
-
-  sh = seq_of_zeros.shape
-  if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
-    # import at runtime to avoid circular import
-    from jax._src.numpy import linalg
-    seq_of_zeros = linalg.eigvals(seq_of_zeros)
-
-  if seq_of_zeros.ndim != 1:
-    raise ValueError("input must be 1d or non-empty square 2d array.")
-
-  dt = seq_of_zeros.dtype
-  if len(seq_of_zeros) == 0:
-    return ones((), dtype=dt)
-
-  a = ones((1,), dtype=dt)
-  for k in range(len(seq_of_zeros)):
-    a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
-  return a
-
-
-@_wraps(np.polyval, lax_description="""\
-The ``unroll`` parameter is JAX specific. It does not affect correctness but can
-have a major impact on performance for evaluating high-order polynomials. The
-parameter controls the number of unrolled steps with ``lax.scan`` inside the
-``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
-improve runtime performance on accelerators, at the cost of increased
-compilation time.
-""")
-@partial(jax.jit, static_argnames=['unroll'])
-def polyval(p, x, *, unroll=16):
-  _check_arraylike("polyval", p, x)
-  p, x = _promote_dtypes_inexact(p, x)
-  shape = lax.broadcast_shapes(p.shape[1:], x.shape)
-  y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
-  y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
-  return y
-
-
-@_wraps(np.polyadd)
-@jit
-def polyadd(a1, a2):
-  _check_arraylike("polyadd", a1, a2)
-  a1, a2 = _promote_dtypes(a1, a2)
-  if a2.shape[0] <= a1.shape[0]:
-    return a1.at[-a2.shape[0]:].add(a2)
-  else:
-    return a2.at[-a1.shape[0]:].add(a1)
-
-
-@_wraps(np.polyint)
-@partial(jit, static_argnames=('m',))
-def polyint(p, m=1, k=None):
-  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
-  k = 0 if k is None else k
-  _check_arraylike("polyint", p, k)
-  p, k = _promote_dtypes_inexact(p, k)
-  if m < 0:
-    raise ValueError("Order of integral must be positive (see polyder)")
-  k = atleast_1d(k)
-  if len(k) == 1:
-    k = full((m,), k[0])
-  if k.shape != (m,):
-    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
-  if m == 0:
-    return p
-  else:
-    coeff = maximum(1, arange(len(p) + m, 0, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
-    return true_divide(concatenate((p, k)), coeff)
-
-
-@_wraps(np.polyder)
-@partial(jit, static_argnames=('m',))
-def polyder(p, m=1):
-  _check_arraylike("polyder", p)
-  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
-  p, = _promote_dtypes_inexact(p)
-  if m < 0:
-    raise ValueError("Order of derivative must be positive")
-  if m == 0:
-    return p
-  coeff = (arange(len(p), m, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
-  return p[:-m] * coeff
-
-
 @_wraps(np.trim_zeros)
 def trim_zeros(filt, trim='fb'):
@@ -3356,33 +3185,6 @@ def trim_zeros(filt, trim='fb'):
   return filt[start:len(filt) - end]
-
-
-_LEADING_ZEROS_DOC = """\
-Setting trim_leading_zeros=True makes the output match that of numpy,
-but prevents the function from being used in compiled code.
-"""
-
-@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
-def polymul(a1, a2, *, trim_leading_zeros=False):
-  _check_arraylike("polymul", a1, a2)
-  a1, a2 = _promote_dtypes_inexact(a1, a2)
-  if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
-    a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
-  if len(a1) == 0:
-    a1 = asarray([0.])
-  if len(a2) == 0:
-    a2 = asarray([0.])
-  val = convolve(a1, a2, mode='full')
-  return val
-
-
-@_wraps(np.polysub)
-@jit
-def polysub(a1, a2):
-  _check_arraylike("polysub", a1, a2)
-  a1, a2 = _promote_dtypes(a1, a2)
-  return polyadd(a1, -a2)
-
-
 @_wraps(np.append)
 @partial(jit, static_argnames=('axis',))
 def append(arr, values, axis: Optional[int] = None):

jax/_src/numpy/polynomial.py

@@ -13,32 +13,29 @@
 # limitations under the License.

-import numpy as np
-from jax import lax
-from jax._src.numpy import lax_numpy as jnp
+from functools import partial
+import operator
+
+from jax import core
+from jax import jit
-from jax._src.numpy.util import _wraps
-from jax._src.numpy.linalg import eigvals as _eigvals
-
-
-def _to_inexact_type(type):
-  return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
-
-
-def _promote_inexact(arr):
-  return lax.convert_element_type(arr, _to_inexact_type(arr.dtype))
+from jax import lax
+from jax._src.numpy.lax_numpy import (
+  all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo,
+  full, hstack, maximum, ones, outer, sqrt, trim_zeros, true_divide, vander, zeros)
+from jax._src.numpy import linalg
+from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps
+
+import numpy as np


 @jit
 def _roots_no_zeros(p):
   # assume: p does not have leading zeros and has length > 1
-  p = _promote_inexact(p)
+  p, = _promote_dtypes_inexact(p)

   # build companion matrix and find its eigenvalues (the roots)
-  A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1)
+  A = diag(ones((p.size - 2,), p.dtype), -1)
   A = A.at[0, :].set(-p[1:] / p[0])
-  roots = _eigvals(A)
+  roots = linalg.eigvals(A)
   return roots
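For intuition, here is the companion-matrix construction this helper performs, with illustrative values (root order and the exact repr vary by version and backend):

>>> import jax.numpy as jnp
>>> p = jnp.array([1., -3., 2.])          # x**2 - 3x + 2 = (x - 1)(x - 2)
>>> A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1)
>>> A = A.at[0, :].set(-p[1:] / p[0])     # companion matrix of p
>>> jnp.linalg.eigvals(A)                 # its eigenvalues are the roots
DeviceArray([2.+0.j, 1.+0.j], dtype=complex64)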
@@ -46,8 +43,8 @@ def _roots_no_zeros(p):
 def _nonzero_range(arr):
   # return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros
   is_zero = arr == 0
-  start = jnp.argmin(is_zero)
-  end = is_zero.size - jnp.argmin(is_zero[::-1])
+  start = argmin(is_zero)
+  end = is_zero.size - argmin(is_zero[::-1])
   return start, end
@@ -73,7 +70,7 @@ DeviceArray([-2.+0.j], dtype=complex64)
 """)
 def roots(p, *, strip_zeros=True):
   # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
-  p = jnp.atleast_1d(p)
+  p = atleast_1d(p)
   if p.ndim != 1:
     raise ValueError("Input must be a rank-1 array.")
@@ -82,10 +79,10 @@ def roots(p, *, strip_zeros=True):
     if p.size > 1:
       return _roots_no_zeros(p)
     else:
-      return jnp.array([])
+      return array([])

-  if jnp.all(p == 0):
-    return jnp.array([])
+  if all(p == 0):
+    return array([])

   # factor out trivial roots
   start, end = _nonzero_range(p)
@@ -96,9 +93,209 @@ def roots(p, *, strip_zeros=True):
   p = p[start:end]

   if p.size < 2:
-    return jnp.zeros(trailing_zeros, p.dtype)
+    return zeros(trailing_zeros, p.dtype)
   else:
     roots = _roots_no_zeros(p)

   # combine roots and zero roots
-  roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype)))
+  roots = hstack((roots, zeros(trailing_zeros, p.dtype)))
   return roots
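A usage sketch for roots (illustrative values; root order may vary). As the branching on concrete values above suggests, the default strip_zeros=True is not usable under jit; strip_zeros=False avoids that, assuming the input has no leading zeros:

>>> import jax.numpy as jnp
>>> jnp.roots(jnp.array([1., -3., 2., 0.]))   # x**3 - 3x**2 + 2x = x(x - 1)(x - 2)
DeviceArray([2.+0.j, 1.+0.j, 0.+0.j], dtype=complex64)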
+
+
+_POLYFIT_DOC = """\
+Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn
+on rank reduction, which indicates an ill-conditioned matrix.
+Also, it works best on rcond values <= 10e-3.
+"""
+
+@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
+@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
+def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
+  _check_arraylike("polyfit", x, y)
+  deg = core.concrete_or_error(int, deg, "deg must be int")
+  order = deg + 1
+  # check arguments
+  if deg < 0:
+    raise ValueError("expected deg >= 0")
+  if x.ndim != 1:
+    raise TypeError("expected 1D vector for x")
+  if x.size == 0:
+    raise TypeError("expected non-empty vector for x")
+  if y.ndim < 1 or y.ndim > 2:
+    raise TypeError("expected 1D or 2D array for y")
+  if x.shape[0] != y.shape[0]:
+    raise TypeError("expected x and y to have same length")
+
+  # set rcond
+  if rcond is None:
+    rcond = len(x) * finfo(x.dtype).eps
+  rcond = core.concrete_or_error(float, rcond, "rcond must be float")
+
+  # set up least squares equation for powers of x
+  lhs = vander(x, order)
+  rhs = y
+
+  # apply weighting
+  if w is not None:
+    _check_arraylike("polyfit", w)
+    w, = _promote_dtypes_inexact(w)
+    if w.ndim != 1:
+      raise TypeError("expected a 1-d array for weights")
+    if w.shape[0] != y.shape[0]:
+      raise TypeError("expected w and y to have the same length")
+    lhs *= w[:, np.newaxis]
+    if rhs.ndim == 2:
+      rhs *= w[:, np.newaxis]
+    else:
+      rhs *= w
+
+  # scale lhs to improve condition number and solve
+  scale = sqrt((lhs * lhs).sum(axis=0))
+  lhs /= scale[np.newaxis, :]
+  c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
+  c = (c.T / scale).T  # broadcast scale coefficients
+
+  if full:
+    return c, resids, rank, s, rcond
+  elif cov:
+    Vbase = linalg.inv(dot(lhs.T, lhs))
+    Vbase /= outer(scale, scale)
+    if cov == "unscaled":
+      fac = 1
+    else:
+      if len(x) <= order:
+        raise ValueError("the number of data points must exceed order "
+                         "to scale the covariance matrix")
+      fac = resids / (len(x) - order)
+      fac = fac[0]  # convert array of shape (1,) to a scalar
+    if y.ndim == 1:
+      return c, Vbase * fac
+    else:
+      return c, Vbase[:, :, np.newaxis] * fac
+  else:
+    return c
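A minimal usage sketch on synthetic data (values illustrative; the result is approximate and ordered highest degree first):

>>> import jax.numpy as jnp
>>> x = jnp.array([0., 1., 2., 3.])
>>> y = 2. * x + 1.
>>> jnp.polyfit(x, y, 1)   # ≈ DeviceArray([2., 1.], dtype=float32): slope 2, intercept 1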
+
+
+_POLY_DOC = """\
+This differs from np.poly when an integer array is given.
+np.poly returns a result with dtype float64 in this case.
+jax returns a result with an inexact type, but not necessarily float64.
+
+This also differs from np.poly when the input array strictly
+contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
+np.poly returns an array with a real dtype in such cases.
+jax returns an array with a complex dtype in such cases.
+"""
+
+@_wraps(np.poly, lax_description=_POLY_DOC)
+@jit
+def poly(seq_of_zeros):
+  _check_arraylike('poly', seq_of_zeros)
+  seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
+  seq_of_zeros = atleast_1d(seq_of_zeros)
+
+  sh = seq_of_zeros.shape
+  if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
+    # import at runtime to avoid circular import
+    from jax._src.numpy import linalg
+    seq_of_zeros = linalg.eigvals(seq_of_zeros)
+
+  if seq_of_zeros.ndim != 1:
+    raise ValueError("input must be 1d or non-empty square 2d array.")
+
+  dt = seq_of_zeros.dtype
+  if len(seq_of_zeros) == 0:
+    return ones((), dtype=dt)
+
+  a = ones((1,), dtype=dt)
+  for k in range(len(seq_of_zeros)):
+    a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
+  return a
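A usage sketch illustrating the dtype note in _POLY_DOC: with the default X64 mode disabled, the integer input promotes to float32 rather than NumPy's float64 (illustrative values):

>>> import jax.numpy as jnp
>>> jnp.poly(jnp.array([1, 2]))   # coefficients of (x - 1)(x - 2)
DeviceArray([ 1., -3.,  2.], dtype=float32)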
+
+
+@_wraps(np.polyval, lax_description="""\
+The ``unroll`` parameter is JAX specific. It does not affect correctness but can
+have a major impact on performance for evaluating high-order polynomials. The
+parameter controls the number of unrolled steps with ``lax.scan`` inside the
+``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
+improve runtime performance on accelerators, at the cost of increased
+compilation time.
+""")
+@partial(jit, static_argnames=['unroll'])
+def polyval(p, x, *, unroll=16):
+  _check_arraylike("polyval", p, x)
+  p, x = _promote_dtypes_inexact(p, x)
+  shape = lax.broadcast_shapes(p.shape[1:], x.shape)
+  y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
+  y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
+  return y
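A sketch of the Horner-style evaluation and the unroll knob (illustrative values; unroll trades compile time for runtime and never changes results):

>>> import jax.numpy as jnp
>>> p = jnp.array([1., 2., 3.])               # x**2 + 2x + 3
>>> jnp.polyval(p, jnp.array([0., 1., 2.]))
DeviceArray([ 3.,  6., 11.], dtype=float32)
>>> jnp.polyval(p, jnp.array([0., 1., 2.]), unroll=128)   # same values
DeviceArray([ 3.,  6., 11.], dtype=float32)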
+
+
+@_wraps(np.polyadd)
+@jit
+def polyadd(a1, a2):
+  _check_arraylike("polyadd", a1, a2)
+  a1, a2 = _promote_dtypes(a1, a2)
+  if a2.shape[0] <= a1.shape[0]:
+    return a1.at[-a2.shape[0]:].add(a2)
+  else:
+    return a2.at[-a1.shape[0]:].add(a1)
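As the indexing above shows, coefficients are aligned at the trailing (constant) end; a small sketch with illustrative values:

>>> import jax.numpy as jnp
>>> jnp.polyadd(jnp.array([1., 2., 3.]), jnp.array([10., 20.]))   # (x**2 + 2x + 3) + (10x + 20)
DeviceArray([ 1., 12., 23.], dtype=float32)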
+
+
+@_wraps(np.polyint)
+@partial(jit, static_argnames=('m',))
+def polyint(p, m=1, k=None):
+  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
+  k = 0 if k is None else k
+  _check_arraylike("polyint", p, k)
+  p, k = _promote_dtypes_inexact(p, k)
+  if m < 0:
+    raise ValueError("Order of integral must be positive (see polyder)")
+  k = atleast_1d(k)
+  if len(k) == 1:
+    k = full((m,), k[0])
+  if k.shape != (m,):
+    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
+  if m == 0:
+    return p
+  else:
+    coeff = maximum(1, arange(len(p) + m, 0, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
+    return true_divide(concatenate((p, k)), coeff)
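A usage sketch, including the integration constant k (illustrative values):

>>> import jax.numpy as jnp
>>> jnp.polyint(jnp.array([3., 2., 1.]))        # integral of 3x**2 + 2x + 1
DeviceArray([1., 1., 1., 0.], dtype=float32)
>>> jnp.polyint(jnp.array([3., 2., 1.]), k=5.)  # k sets the trailing constant
DeviceArray([1., 1., 1., 5.], dtype=float32)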
+
+
+@_wraps(np.polyder)
+@partial(jit, static_argnames=('m',))
+def polyder(p, m=1):
+  _check_arraylike("polyder", p)
+  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
+  p, = _promote_dtypes_inexact(p)
+  if m < 0:
+    raise ValueError("Order of derivative must be positive")
+  if m == 0:
+    return p
+  coeff = (arange(len(p), m, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
+  return p[:-m] * coeff
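A usage sketch (illustrative values):

>>> import jax.numpy as jnp
>>> jnp.polyder(jnp.array([1., 1., 1., 0.]))        # d/dx (x**3 + x**2 + x)
DeviceArray([3., 2., 1.], dtype=float32)
>>> jnp.polyder(jnp.array([1., 1., 1., 0.]), m=2)   # second derivative: 6x + 2
DeviceArray([6., 2.], dtype=float32)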
+
+
+_LEADING_ZEROS_DOC = """\
+Setting trim_leading_zeros=True makes the output match that of numpy,
+but prevents the function from being used in compiled code.
+"""
+
+@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
+def polymul(a1, a2, *, trim_leading_zeros=False):
+  _check_arraylike("polymul", a1, a2)
+  a1, a2 = _promote_dtypes_inexact(a1, a2)
+  if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
+    a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
+  if len(a1) == 0:
+    a1 = asarray([0.])
+  if len(a2) == 0:
+    a2 = asarray([0.])
+  val = convolve(a1, a2, mode='full')
+  return val
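The flag only matters when an input carries leading zeros; a sketch of the difference (illustrative values):

>>> import jax.numpy as jnp
>>> a = jnp.array([0., 1., 2.])    # leading zero: effectively x + 2
>>> b = jnp.array([1., 3.])
>>> jnp.polymul(a, b)              # jit-compatible, keeps the padded length
DeviceArray([0., 1., 5., 6.], dtype=float32)
>>> jnp.polymul(a, b, trim_leading_zeros=True)   # matches np.polymul
DeviceArray([1., 5., 6.], dtype=float32)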
+
+
+@_wraps(np.polysub)
+@jit
+def polysub(a1, a2):
+  _check_arraylike("polysub", a1, a2)
+  a1, a2 = _promote_dtypes(a1, a2)
+  return polyadd(a1, -a2)
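polysub simply negates the second argument and reuses polyadd's trailing alignment (illustrative values):

>>> import jax.numpy as jnp
>>> jnp.polysub(jnp.array([1., 2., 3.]), jnp.array([10., 20.]))   # (x**2 + 2x + 3) - (10x + 20)
DeviceArray([  1.,  -8., -17.], dtype=float32)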

jax/numpy/__init__.py

@@ -211,14 +211,6 @@ from jax._src.numpy.lax_numpy import (
   percentile as percentile,
   pi as pi,
   piecewise as piecewise,
-  poly as poly,
-  polyadd as polyadd,
-  polyder as polyder,
-  polyfit as polyfit,
-  polyint as polyint,
-  polymul as polymul,
-  polysub as polysub,
-  polyval as polyval,
   printoptions as printoptions,
   prod as prod,
   product as product,
@@ -304,6 +296,18 @@ from jax._src.numpy.index_tricks import (
   s_ as s_,
 )
+from jax._src.numpy.polynomial import (
+  poly as poly,
+  polyadd as polyadd,
+  polyder as polyder,
+  polyfit as polyfit,
+  polyint as polyint,
+  polymul as polymul,
+  polysub as polysub,
+  polyval as polyval,
+  roots as roots,
+)
 from jax._src.numpy.ufuncs import (
   abs as abs,
   absolute as absolute,
@@ -395,7 +399,6 @@ from jax._src.numpy.ufuncs import (
   true_divide as true_divide,
 )
-from jax._src.numpy.polynomial import roots as roots
 from jax._src.numpy.vectorize import vectorize as vectorize

 # TODO(phawkins): remove this import after fixing users.
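Net effect of the __init__.py changes: the poly* functions are now re-exported from jax._src.numpy.polynomial instead of lax_numpy, so the public jax.numpy API is unchanged. A quick round-trip check (illustrative values):

>>> import jax.numpy as jnp
>>> jnp.polyval(jnp.poly(jnp.array([1., 2.])), jnp.array([1., 2., 3.]))   # zero at the roots
DeviceArray([0., 0., 2.], dtype=float32)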