mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Improve docs for jnp.roots and jnp.polyfit
This commit is contained in:
parent
2b728d55b6
commit
eba891e3fa
@ -57,33 +57,46 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
|
||||
return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan))
|
||||
|
||||
|
||||
@implements(np.roots, lax_description="""\
|
||||
Unlike the numpy version of this function, the JAX version returns the roots in
|
||||
a complex array regardless of the values of the roots. Additionally, the jax
|
||||
version of this function adds the ``strip_zeros`` function which must be set to
|
||||
False for the function to be compatible with JIT and other JAX transformations.
|
||||
With ``strip_zeros=False``, if your coefficients have leading zeros, the
|
||||
roots will be padded with NaN values:
|
||||
|
||||
>>> coeffs = jnp.array([0, 1, 2])
|
||||
|
||||
# The default behavior matches numpy and strips leading zeros:
|
||||
>>> jnp.roots(coeffs)
|
||||
Array([-2.+0.j], dtype=complex64)
|
||||
|
||||
# With strip_zeros=False, extra roots are set to NaN:
|
||||
>>> jnp.roots(coeffs, strip_zeros=False)
|
||||
Array([-2. +0.j, nan+nanj], dtype=complex64)
|
||||
""",
|
||||
extra_params="""
|
||||
strip_zeros : bool, default=True
|
||||
If set to True, then leading zeros in the coefficients will be stripped, similar
|
||||
to :func:`numpy.roots`. If set to False, leading zeros will not be stripped, and
|
||||
undefined roots will be represented by NaN values in the function output.
|
||||
``strip_zeros`` must be set to ``False`` for the function to be compatible with
|
||||
:func:`jax.jit` and other JAX transformations.
|
||||
""")
|
||||
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
|
||||
r"""Returns the roots of a polynomial given the coefficients ``p``.
|
||||
|
||||
JAX implementations of :func:`numpy.roots`.
|
||||
|
||||
Args:
|
||||
p: Array of polynomial coefficients having rank-1.
|
||||
strip_zeros : bool, default=True. If True, then leading zeros in the
|
||||
coefficients will be stripped, similar to :func:`numpy.roots`. If set to
|
||||
False, leading zeros will not be stripped, and undefined roots will be
|
||||
represented by NaN values in the function output. ``strip_zeros`` must be
|
||||
set to ``False`` for the function to be compatible with :func:`jax.jit` and
|
||||
other JAX transformations.
|
||||
|
||||
Returns:
|
||||
An array containing the roots of the polynomial.
|
||||
|
||||
Note:
|
||||
Unlike ``np.roots`` of this function, the ``jnp.roots`` returns the roots
|
||||
in a complex array regardless of the values of the roots.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given
|
||||
sequence of roots.
|
||||
- :func:`jax.numpy.polyfit`: Least squares polynomial fit to data.
|
||||
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
|
||||
|
||||
Examples:
|
||||
>>> coeffs = jnp.array([0, 1, 2])
|
||||
|
||||
The default behavior matches numpy and strips leading zeros:
|
||||
|
||||
>>> jnp.roots(coeffs)
|
||||
Array([-2.+0.j], dtype=complex64)
|
||||
|
||||
With ``strip_zeros=False``, extra roots are set to NaN:
|
||||
|
||||
>>> jnp.roots(coeffs, strip_zeros=False)
|
||||
Array([-2. +0.j, nan+nanj], dtype=complex64)
|
||||
"""
|
||||
check_arraylike("roots", p)
|
||||
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
|
||||
if p_arr.ndim != 1:
|
||||
@ -102,15 +115,108 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
|
||||
return _roots_with_zeros(p_arr, num_leading_zeros)
|
||||
|
||||
|
||||
_POLYFIT_DOC = """\
|
||||
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
|
||||
Also, it works best on rcond <= 10e-3 values.
|
||||
"""
|
||||
@implements(np.polyfit, lax_description=_POLYFIT_DOC)
|
||||
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
|
||||
def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
|
||||
full: bool = False, w: Array | None = None, cov: bool = False
|
||||
) -> Array | tuple[Array, ...]:
|
||||
r"""Least squares polynomial fit to data.
|
||||
|
||||
Jax implementation of :func:`numpy.polyfit`.
|
||||
|
||||
Given a set of data points ``(x, y)`` and degree of polynomial ``deg``, the
|
||||
function finds a polynomial equation of the form:
|
||||
|
||||
.. math::
|
||||
|
||||
y = p(x) = p[0] x^{deg} + p[1] x^{deg - 1} + ... + p[deg]
|
||||
|
||||
Args:
|
||||
x: Array of data points of shape ``(M,)``.
|
||||
y: Array of data points of shape ``(M,)`` or ``(M, K)``.
|
||||
deg: Degree of the polynomials. It must be specified statically.
|
||||
rcond: Relative condition number of the fit. Default value is ``len(x) * eps``.
|
||||
It must be specified statically.
|
||||
full: Switch that controls the return value. Default is ``False`` which
|
||||
restricts the return value to the array of polynomail coefficients ``p``.
|
||||
If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``.
|
||||
It must be specified statically.
|
||||
w: Array of weights of shape ``(M,)``. If None, all data points are considered
|
||||
to have equal weight. If not None, the weight :math:`w_i` is applied to the
|
||||
unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where
|
||||
:math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None.
|
||||
cov: Boolean or string. If ``True``, returns the covariance matrix scaled
|
||||
by ``resids/(M-deg-1)`` along with ploynomial coefficients. If
|
||||
``cov='unscaled'``, returns the unscaaled version of covariance matrix.
|
||||
Default is ``False``. ``cov`` is ignored if ``full=True``. It must be
|
||||
specified statically.
|
||||
|
||||
Returns:
|
||||
- An array polynomial coefficients ``p`` if ``full=False`` and ``cov=False``.
|
||||
|
||||
- A tuple of arrays ``(p, resids, rank, s, rcond)`` if ``full=True``. Where
|
||||
|
||||
- ``p`` is an array of shape ``(M,)`` or ``(M, K)`` containing the polynomial
|
||||
coefficients.
|
||||
- ``resids`` is the sum of squared residual of shape () or (K,).
|
||||
- ``rank`` is the rank of the matrix ``x``.
|
||||
- ``s`` is the singular values of the matrix ``x``.
|
||||
- ``rcond`` as the array.
|
||||
- A tuple of arrays ``(p, C)`` if ``full=False`` and ``cov=True``. Where
|
||||
|
||||
- ``p`` is an array of shape ``(M,)`` or ``(M, K)`` containing the polynomial
|
||||
coefficients.
|
||||
- ``C`` is the covariance matrix of polynomial coefficients of shape
|
||||
``(deg + 1, deg + 1)`` or ``(deg + 1, deg + 1, 1)``.
|
||||
|
||||
Note:
|
||||
Unlike :func:`numpy.polyfit` implementation of polyfit, :func:`jax.numpy.polyfit`
|
||||
will not warn on rank reduction, which indicates an ill conditioned matrix.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given
|
||||
sequence of roots.
|
||||
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
|
||||
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
|
||||
coefficients.
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([3., 6., 9., 4.])
|
||||
>>> y = jnp.array([[0, 1, 2],
|
||||
... [2, 5, 7],
|
||||
... [8, 4, 9],
|
||||
... [1, 6, 3]])
|
||||
>>> p = jnp.polyfit(x, y, 2)
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(p)
|
||||
[[ 0.2 -0.35 -0.14]
|
||||
[-1.17 4.47 2.96]
|
||||
[ 1.95 -8.21 -5.93]]
|
||||
|
||||
If ``full=True``, returns a tuple of arrays as follows:
|
||||
|
||||
>>> p, resids, rank, s, rcond = jnp.polyfit(x, y, 2, full=True)
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print("Polynomial Coefficients:", "\n", p, "\n",
|
||||
... "Residuals:", resids, "\n",
|
||||
... "Rank:", rank, "\n",
|
||||
... "s:", s, "\n",
|
||||
... "rcond:", rcond)
|
||||
Polynomial Coefficients:
|
||||
[[ 0.2 -0.35 -0.14]
|
||||
[-1.17 4.47 2.96]
|
||||
[ 1.95 -8.21 -5.93]]
|
||||
Residuals: [0.37 5.94 0.61]
|
||||
Rank: 3
|
||||
s: [1.67 0.47 0.04]
|
||||
rcond: 4.7683716e-07
|
||||
|
||||
If ``cov=True`` and ``full=False``, returns a tuple of arrays having
|
||||
polynomial coefficients and covariance matrix.
|
||||
|
||||
>>> p, C = jnp.polyfit(x, y, 2, cov=True)
|
||||
>>> p.shape, C.shape
|
||||
((3, 3), (3, 3, 1))
|
||||
"""
|
||||
check_arraylike("polyfit", x, y)
|
||||
deg = core.concrete_or_error(int, deg, "deg must be int")
|
||||
order = deg + 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user