mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
lax_numpy: move poly functions into numpy.polynomial
This commit is contained in:
parent
2d79a6462f
commit
603bb3c5ca
@ -702,76 +702,6 @@ def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
def isrealobj(x):
|
||||
return not iscomplexobj(x)
|
||||
|
||||
_POLYFIT_DOC = """\
|
||||
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
|
||||
Also, it works best on rcond <= 10e-3 values.
|
||||
"""
|
||||
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
|
||||
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
|
||||
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
|
||||
_check_arraylike("polyfit", x, y)
|
||||
deg = core.concrete_or_error(int, deg, "deg must be int")
|
||||
order = deg + 1
|
||||
# check arguments
|
||||
if deg < 0:
|
||||
raise ValueError("expected deg >= 0")
|
||||
if x.ndim != 1:
|
||||
raise TypeError("expected 1D vector for x")
|
||||
if x.size == 0:
|
||||
raise TypeError("expected non-empty vector for x")
|
||||
if y.ndim < 1 or y.ndim > 2:
|
||||
raise TypeError("expected 1D or 2D array for y")
|
||||
if x.shape[0] != y.shape[0]:
|
||||
raise TypeError("expected x and y to have same length")
|
||||
|
||||
# set rcond
|
||||
if rcond is None:
|
||||
rcond = len(x)*finfo(x.dtype).eps
|
||||
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
|
||||
# set up least squares equation for powers of x
|
||||
lhs = vander(x, order)
|
||||
rhs = y
|
||||
|
||||
# apply weighting
|
||||
if w is not None:
|
||||
_check_arraylike("polyfit", w)
|
||||
w, = _promote_dtypes_inexact(w)
|
||||
if w.ndim != 1:
|
||||
raise TypeError("expected a 1-d array for weights")
|
||||
if w.shape[0] != y.shape[0]:
|
||||
raise TypeError("expected w and y to have the same length")
|
||||
lhs *= w[:, newaxis]
|
||||
if rhs.ndim == 2:
|
||||
rhs *= w[:, newaxis]
|
||||
else:
|
||||
rhs *= w
|
||||
|
||||
# scale lhs to improve condition number and solve
|
||||
scale = sqrt((lhs*lhs).sum(axis=0))
|
||||
lhs /= scale[newaxis,:]
|
||||
from jax._src.numpy import linalg
|
||||
c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
|
||||
c = (c.T/scale).T # broadcast scale coefficients
|
||||
|
||||
if full:
|
||||
return c, resids, rank, s, rcond
|
||||
elif cov:
|
||||
Vbase = linalg.inv(dot(lhs.T, lhs))
|
||||
Vbase /= outer(scale, scale)
|
||||
if cov == "unscaled":
|
||||
fac = 1
|
||||
else:
|
||||
if len(x) <= order:
|
||||
raise ValueError("the number of data points must exceed order "
|
||||
"to scale the covariance matrix")
|
||||
fac = resids / (len(x) - order)
|
||||
fac = fac[0] #making np.array() of shape (1,) to int
|
||||
if y.ndim == 1:
|
||||
return c, Vbase * fac
|
||||
else:
|
||||
return c, Vbase[:,:, newaxis] * fac
|
||||
else:
|
||||
return c
|
||||
|
||||
|
||||
@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
|
||||
@ -3242,107 +3172,6 @@ def diagflat(v, k=0):
|
||||
res = res.reshape(adj_length, adj_length)
|
||||
return res
|
||||
|
||||
_POLY_DOC = """\
|
||||
This differs from np.poly when an integer array is given.
|
||||
np.poly returns a result with dtype float64 in this case.
|
||||
jax returns a result with an inexact type, but not necessarily
|
||||
float64.
|
||||
|
||||
This also differs from np.poly when the input array strictly
|
||||
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
|
||||
np.poly returns an array with a real dtype in such cases.
|
||||
jax returns an array with a complex dtype in such cases.
|
||||
"""
|
||||
|
||||
@_wraps(np.poly, lax_description=_POLY_DOC)
|
||||
@jit
|
||||
def poly(seq_of_zeros):
|
||||
_check_arraylike('poly', seq_of_zeros)
|
||||
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
|
||||
seq_of_zeros = atleast_1d(seq_of_zeros)
|
||||
|
||||
sh = seq_of_zeros.shape
|
||||
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
|
||||
# import at runtime to avoid circular import
|
||||
from jax._src.numpy import linalg
|
||||
seq_of_zeros = linalg.eigvals(seq_of_zeros)
|
||||
|
||||
if seq_of_zeros.ndim != 1:
|
||||
raise ValueError("input must be 1d or non-empty square 2d array.")
|
||||
|
||||
dt = seq_of_zeros.dtype
|
||||
if len(seq_of_zeros) == 0:
|
||||
return ones((), dtype=dt)
|
||||
|
||||
a = ones((1,), dtype=dt)
|
||||
for k in range(len(seq_of_zeros)):
|
||||
a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
|
||||
|
||||
return a
|
||||
|
||||
|
||||
@_wraps(np.polyval, lax_description="""\
|
||||
The ``unroll`` parameter is JAX specific. It does not effect correctness but can
|
||||
have a major impact on performance for evaluating high-order polynomials. The
|
||||
parameter controls the number of unrolled steps with ``lax.scan`` inside the
|
||||
``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
|
||||
improve runtime performance on accelerators, at the cost of increased
|
||||
compilation time.
|
||||
""")
|
||||
@partial(jax.jit, static_argnames=['unroll'])
|
||||
def polyval(p, x, *, unroll=16):
|
||||
_check_arraylike("polyval", p, x)
|
||||
p, x = _promote_dtypes_inexact(p, x)
|
||||
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
|
||||
y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
|
||||
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
|
||||
return y
|
||||
|
||||
@_wraps(np.polyadd)
|
||||
@jit
|
||||
def polyadd(a1, a2):
|
||||
_check_arraylike("polyadd", a1, a2)
|
||||
a1, a2 = _promote_dtypes(a1, a2)
|
||||
if a2.shape[0] <= a1.shape[0]:
|
||||
return a1.at[-a2.shape[0]:].add(a2)
|
||||
else:
|
||||
return a2.at[-a1.shape[0]:].add(a1)
|
||||
|
||||
|
||||
@_wraps(np.polyint)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyint(p, m=1, k=None):
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
|
||||
k = 0 if k is None else k
|
||||
_check_arraylike("polyint", p, k)
|
||||
p, k = _promote_dtypes_inexact(p, k)
|
||||
if m < 0:
|
||||
raise ValueError("Order of integral must be positive (see polyder)")
|
||||
k = atleast_1d(k)
|
||||
if len(k) == 1:
|
||||
k = full((m,), k[0])
|
||||
if k.shape != (m,):
|
||||
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
|
||||
if m == 0:
|
||||
return p
|
||||
else:
|
||||
coeff = maximum(1, arange(len(p) + m, 0, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
|
||||
return true_divide(concatenate((p, k)), coeff)
|
||||
|
||||
|
||||
@_wraps(np.polyder)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyder(p, m=1):
|
||||
_check_arraylike("polyder", p)
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
|
||||
p, = _promote_dtypes_inexact(p)
|
||||
if m < 0:
|
||||
raise ValueError("Order of derivative must be positive")
|
||||
if m == 0:
|
||||
return p
|
||||
coeff = (arange(len(p), m, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
|
||||
return p[:-m] * coeff
|
||||
|
||||
|
||||
@_wraps(np.trim_zeros)
|
||||
def trim_zeros(filt, trim='fb'):
|
||||
@ -3356,33 +3185,6 @@ def trim_zeros(filt, trim='fb'):
|
||||
return filt[start:len(filt) - end]
|
||||
|
||||
|
||||
_LEADING_ZEROS_DOC = """\
|
||||
Setting trim_leading_zeros=True makes the output match that of numpy.
|
||||
But prevents the function from being able to be used in compiled code.
|
||||
"""
|
||||
|
||||
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polymul(a1, a2, *, trim_leading_zeros=False):
|
||||
_check_arraylike("polymul", a1, a2)
|
||||
a1, a2 = _promote_dtypes_inexact(a1, a2)
|
||||
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
|
||||
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
|
||||
if len(a1) == 0:
|
||||
a1 = asarray([0.])
|
||||
if len(a2) == 0:
|
||||
a2 = asarray([0.])
|
||||
val = convolve(a1, a2, mode='full')
|
||||
return val
|
||||
|
||||
|
||||
@_wraps(np.polysub)
|
||||
@jit
|
||||
def polysub(a1, a2):
|
||||
_check_arraylike("polysub", a1, a2)
|
||||
a1, a2 = _promote_dtypes(a1, a2)
|
||||
return polyadd(a1, -a2)
|
||||
|
||||
|
||||
@_wraps(np.append)
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def append(arr, values, axis: Optional[int] = None):
|
||||
|
@ -13,32 +13,29 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
from jax import lax
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from functools import partial
|
||||
import operator
|
||||
|
||||
from jax import core
|
||||
from jax import jit
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.linalg import eigvals as _eigvals
|
||||
|
||||
|
||||
def _to_inexact_type(type):
|
||||
return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
|
||||
|
||||
|
||||
def _promote_inexact(arr):
|
||||
return lax.convert_element_type(arr, _to_inexact_type(arr.dtype))
|
||||
from jax import lax
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
all, arange, argmin, array, asarray, atleast_1d, concatenate, convolve, diag, dot, finfo,
|
||||
full, hstack, maximum, ones, outer, sqrt, trim_zeros, true_divide, vander, zeros)
|
||||
from jax._src.numpy import linalg
|
||||
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _wraps
|
||||
import numpy as np
|
||||
|
||||
|
||||
@jit
|
||||
def _roots_no_zeros(p):
|
||||
# assume: p does not have leading zeros and has length > 1
|
||||
p = _promote_inexact(p)
|
||||
p, = _promote_dtypes_inexact(p)
|
||||
|
||||
# build companion matrix and find its eigenvalues (the roots)
|
||||
A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1)
|
||||
A = diag(ones((p.size - 2,), p.dtype), -1)
|
||||
A = A.at[0, :].set(-p[1:] / p[0])
|
||||
roots = _eigvals(A)
|
||||
roots = linalg.eigvals(A)
|
||||
return roots
|
||||
|
||||
|
||||
@ -46,8 +43,8 @@ def _roots_no_zeros(p):
|
||||
def _nonzero_range(arr):
|
||||
# return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros
|
||||
is_zero = arr == 0
|
||||
start = jnp.argmin(is_zero)
|
||||
end = is_zero.size - jnp.argmin(is_zero[::-1])
|
||||
start = argmin(is_zero)
|
||||
end = is_zero.size - argmin(is_zero[::-1])
|
||||
return start, end
|
||||
|
||||
|
||||
@ -73,7 +70,7 @@ DeviceArray([-2.+0.j], dtype=complex64)
|
||||
""")
|
||||
def roots(p, *, strip_zeros=True):
|
||||
# ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
|
||||
p = jnp.atleast_1d(p)
|
||||
p = atleast_1d(p)
|
||||
if p.ndim != 1:
|
||||
raise ValueError("Input must be a rank-1 array.")
|
||||
|
||||
@ -82,10 +79,10 @@ def roots(p, *, strip_zeros=True):
|
||||
if p.size > 1:
|
||||
return _roots_no_zeros(p)
|
||||
else:
|
||||
return jnp.array([])
|
||||
return array([])
|
||||
|
||||
if jnp.all(p == 0):
|
||||
return jnp.array([])
|
||||
if all(p == 0):
|
||||
return array([])
|
||||
|
||||
# factor out trivial roots
|
||||
start, end = _nonzero_range(p)
|
||||
@ -96,9 +93,209 @@ def roots(p, *, strip_zeros=True):
|
||||
p = p[start:end]
|
||||
|
||||
if p.size < 2:
|
||||
return jnp.zeros(trailing_zeros, p.dtype)
|
||||
return zeros(trailing_zeros, p.dtype)
|
||||
else:
|
||||
roots = _roots_no_zeros(p)
|
||||
# combine roots and zero roots
|
||||
roots = jnp.hstack((roots, jnp.zeros(trailing_zeros, p.dtype)))
|
||||
roots = hstack((roots, zeros(trailing_zeros, p.dtype)))
|
||||
return roots
|
||||
|
||||
|
||||
_POLYFIT_DOC = """\
|
||||
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
|
||||
Also, it works best on rcond <= 10e-3 values.
|
||||
"""
|
||||
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
|
||||
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
|
||||
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
|
||||
_check_arraylike("polyfit", x, y)
|
||||
deg = core.concrete_or_error(int, deg, "deg must be int")
|
||||
order = deg + 1
|
||||
# check arguments
|
||||
if deg < 0:
|
||||
raise ValueError("expected deg >= 0")
|
||||
if x.ndim != 1:
|
||||
raise TypeError("expected 1D vector for x")
|
||||
if x.size == 0:
|
||||
raise TypeError("expected non-empty vector for x")
|
||||
if y.ndim < 1 or y.ndim > 2:
|
||||
raise TypeError("expected 1D or 2D array for y")
|
||||
if x.shape[0] != y.shape[0]:
|
||||
raise TypeError("expected x and y to have same length")
|
||||
|
||||
# set rcond
|
||||
if rcond is None:
|
||||
rcond = len(x) * finfo(x.dtype).eps
|
||||
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
|
||||
# set up least squares equation for powers of x
|
||||
lhs = vander(x, order)
|
||||
rhs = y
|
||||
|
||||
# apply weighting
|
||||
if w is not None:
|
||||
_check_arraylike("polyfit", w)
|
||||
w, = _promote_dtypes_inexact(w)
|
||||
if w.ndim != 1:
|
||||
raise TypeError("expected a 1-d array for weights")
|
||||
if w.shape[0] != y.shape[0]:
|
||||
raise TypeError("expected w and y to have the same length")
|
||||
lhs *= w[:, np.newaxis]
|
||||
if rhs.ndim == 2:
|
||||
rhs *= w[:, np.newaxis]
|
||||
else:
|
||||
rhs *= w
|
||||
|
||||
# scale lhs to improve condition number and solve
|
||||
scale = sqrt((lhs*lhs).sum(axis=0))
|
||||
lhs /= scale[np.newaxis,:]
|
||||
c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
|
||||
c = (c.T/scale).T # broadcast scale coefficients
|
||||
|
||||
if full:
|
||||
return c, resids, rank, s, rcond
|
||||
elif cov:
|
||||
Vbase = linalg.inv(dot(lhs.T, lhs))
|
||||
Vbase /= outer(scale, scale)
|
||||
if cov == "unscaled":
|
||||
fac = 1
|
||||
else:
|
||||
if len(x) <= order:
|
||||
raise ValueError("the number of data points must exceed order "
|
||||
"to scale the covariance matrix")
|
||||
fac = resids / (len(x) - order)
|
||||
fac = fac[0] #making np.array() of shape (1,) to int
|
||||
if y.ndim == 1:
|
||||
return c, Vbase * fac
|
||||
else:
|
||||
return c, Vbase[:, :, np.newaxis] * fac
|
||||
else:
|
||||
return c
|
||||
|
||||
|
||||
_POLY_DOC = """\
|
||||
This differs from np.poly when an integer array is given.
|
||||
np.poly returns a result with dtype float64 in this case.
|
||||
jax returns a result with an inexact type, but not necessarily
|
||||
float64.
|
||||
|
||||
This also differs from np.poly when the input array strictly
|
||||
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
|
||||
np.poly returns an array with a real dtype in such cases.
|
||||
jax returns an array with a complex dtype in such cases.
|
||||
"""
|
||||
|
||||
@_wraps(np.poly, lax_description=_POLY_DOC)
|
||||
@jit
|
||||
def poly(seq_of_zeros):
|
||||
_check_arraylike('poly', seq_of_zeros)
|
||||
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
|
||||
seq_of_zeros = atleast_1d(seq_of_zeros)
|
||||
|
||||
sh = seq_of_zeros.shape
|
||||
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
|
||||
# import at runtime to avoid circular import
|
||||
from jax._src.numpy import linalg
|
||||
seq_of_zeros = linalg.eigvals(seq_of_zeros)
|
||||
|
||||
if seq_of_zeros.ndim != 1:
|
||||
raise ValueError("input must be 1d or non-empty square 2d array.")
|
||||
|
||||
dt = seq_of_zeros.dtype
|
||||
if len(seq_of_zeros) == 0:
|
||||
return ones((), dtype=dt)
|
||||
|
||||
a = ones((1,), dtype=dt)
|
||||
for k in range(len(seq_of_zeros)):
|
||||
a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
|
||||
|
||||
return a
|
||||
|
||||
|
||||
@_wraps(np.polyval, lax_description="""\
|
||||
The ``unroll`` parameter is JAX specific. It does not effect correctness but can
|
||||
have a major impact on performance for evaluating high-order polynomials. The
|
||||
parameter controls the number of unrolled steps with ``lax.scan`` inside the
|
||||
``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
|
||||
improve runtime performance on accelerators, at the cost of increased
|
||||
compilation time.
|
||||
""")
|
||||
@partial(jit, static_argnames=['unroll'])
|
||||
def polyval(p, x, *, unroll=16):
|
||||
_check_arraylike("polyval", p, x)
|
||||
p, x = _promote_dtypes_inexact(p, x)
|
||||
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
|
||||
y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
|
||||
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
|
||||
return y
|
||||
|
||||
@_wraps(np.polyadd)
|
||||
@jit
|
||||
def polyadd(a1, a2):
|
||||
_check_arraylike("polyadd", a1, a2)
|
||||
a1, a2 = _promote_dtypes(a1, a2)
|
||||
if a2.shape[0] <= a1.shape[0]:
|
||||
return a1.at[-a2.shape[0]:].add(a2)
|
||||
else:
|
||||
return a2.at[-a1.shape[0]:].add(a1)
|
||||
|
||||
|
||||
@_wraps(np.polyint)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyint(p, m=1, k=None):
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
|
||||
k = 0 if k is None else k
|
||||
_check_arraylike("polyint", p, k)
|
||||
p, k = _promote_dtypes_inexact(p, k)
|
||||
if m < 0:
|
||||
raise ValueError("Order of integral must be positive (see polyder)")
|
||||
k = atleast_1d(k)
|
||||
if len(k) == 1:
|
||||
k = full((m,), k[0])
|
||||
if k.shape != (m,):
|
||||
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
|
||||
if m == 0:
|
||||
return p
|
||||
else:
|
||||
coeff = maximum(1, arange(len(p) + m, 0, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
|
||||
return true_divide(concatenate((p, k)), coeff)
|
||||
|
||||
|
||||
@_wraps(np.polyder)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyder(p, m=1):
|
||||
_check_arraylike("polyder", p)
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
|
||||
p, = _promote_dtypes_inexact(p)
|
||||
if m < 0:
|
||||
raise ValueError("Order of derivative must be positive")
|
||||
if m == 0:
|
||||
return p
|
||||
coeff = (arange(len(p), m, -1)[np.newaxis, :] - 1 - arange(m)[:, np.newaxis]).prod(0)
|
||||
return p[:-m] * coeff
|
||||
|
||||
|
||||
_LEADING_ZEROS_DOC = """\
|
||||
Setting trim_leading_zeros=True makes the output match that of numpy.
|
||||
But prevents the function from being able to be used in compiled code.
|
||||
"""
|
||||
|
||||
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
|
||||
def polymul(a1, a2, *, trim_leading_zeros=False):
|
||||
_check_arraylike("polymul", a1, a2)
|
||||
a1, a2 = _promote_dtypes_inexact(a1, a2)
|
||||
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
|
||||
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
|
||||
if len(a1) == 0:
|
||||
a1 = asarray([0.])
|
||||
if len(a2) == 0:
|
||||
a2 = asarray([0.])
|
||||
val = convolve(a1, a2, mode='full')
|
||||
return val
|
||||
|
||||
|
||||
@_wraps(np.polysub)
|
||||
@jit
|
||||
def polysub(a1, a2):
|
||||
_check_arraylike("polysub", a1, a2)
|
||||
a1, a2 = _promote_dtypes(a1, a2)
|
||||
return polyadd(a1, -a2)
|
||||
|
@ -211,14 +211,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
percentile as percentile,
|
||||
pi as pi,
|
||||
piecewise as piecewise,
|
||||
poly as poly,
|
||||
polyadd as polyadd,
|
||||
polyder as polyder,
|
||||
polyfit as polyfit,
|
||||
polyint as polyint,
|
||||
polymul as polymul,
|
||||
polysub as polysub,
|
||||
polyval as polyval,
|
||||
printoptions as printoptions,
|
||||
prod as prod,
|
||||
product as product,
|
||||
@ -304,6 +296,18 @@ from jax._src.numpy.index_tricks import (
|
||||
s_ as s_,
|
||||
)
|
||||
|
||||
from jax._src.numpy.polynomial import (
|
||||
poly as poly,
|
||||
polyadd as polyadd,
|
||||
polyder as polyder,
|
||||
polyfit as polyfit,
|
||||
polyint as polyint,
|
||||
polymul as polymul,
|
||||
polysub as polysub,
|
||||
polyval as polyval,
|
||||
roots as roots,
|
||||
)
|
||||
|
||||
from jax._src.numpy.ufuncs import (
|
||||
abs as abs,
|
||||
absolute as absolute,
|
||||
@ -395,7 +399,6 @@ from jax._src.numpy.ufuncs import (
|
||||
true_divide as true_divide,
|
||||
)
|
||||
|
||||
from jax._src.numpy.polynomial import roots as roots
|
||||
from jax._src.numpy.vectorize import vectorize as vectorize
|
||||
|
||||
# TODO(phawkins): remove this import after fixing users.
|
||||
|
Loading…
x
Reference in New Issue
Block a user