Improve docs for jnp.roots and jnp.polyfit

This commit is contained in:
rajasekharporeddy 2024-06-24 16:39:55 +05:30
parent 2b728d55b6
commit eba891e3fa

View File

@ -57,33 +57,46 @@ def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan))
@implements(np.roots, lax_description="""\
Unlike the numpy version of this function, the JAX version returns the roots in
a complex array regardless of the values of the roots. Additionally, the jax
version of this function adds the ``strip_zeros`` function which must be set to
False for the function to be compatible with JIT and other JAX transformations.
With ``strip_zeros=False``, if your coefficients have leading zeros, the
roots will be padded with NaN values:
>>> coeffs = jnp.array([0, 1, 2])
# The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)
# With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
""",
extra_params="""
strip_zeros : bool, default=True
If set to True, then leading zeros in the coefficients will be stripped, similar
to :func:`numpy.roots`. If set to False, leading zeros will not be stripped, and
undefined roots will be represented by NaN values in the function output.
``strip_zeros`` must be set to ``False`` for the function to be compatible with
:func:`jax.jit` and other JAX transformations.
""")
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
r"""Returns the roots of a polynomial given the coefficients ``p``.
JAX implementations of :func:`numpy.roots`.
Args:
p: Array of polynomial coefficients having rank-1.
strip_zeros : bool, default=True. If True, then leading zeros in the
coefficients will be stripped, similar to :func:`numpy.roots`. If set to
False, leading zeros will not be stripped, and undefined roots will be
represented by NaN values in the function output. ``strip_zeros`` must be
set to ``False`` for the function to be compatible with :func:`jax.jit` and
other JAX transformations.
Returns:
An array containing the roots of the polynomial.
Note:
Unlike ``np.roots`` of this function, the ``jnp.roots`` returns the roots
in a complex array regardless of the values of the roots.
See Also:
- :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given
sequence of roots.
- :func:`jax.numpy.polyfit`: Least squares polynomial fit to data.
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
Examples:
>>> coeffs = jnp.array([0, 1, 2])
The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)
With ``strip_zeros=False``, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
"""
check_arraylike("roots", p)
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
if p_arr.ndim != 1:
@ -102,15 +115,108 @@ def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
return _roots_with_zeros(p_arr, num_leading_zeros)
_POLYFIT_DOC = """\
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
Also, it works best on rcond <= 10e-3 values.
"""
@implements(np.polyfit, lax_description=_POLYFIT_DOC)
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
full: bool = False, w: Array | None = None, cov: bool = False
) -> Array | tuple[Array, ...]:
r"""Least squares polynomial fit to data.
Jax implementation of :func:`numpy.polyfit`.
Given a set of data points ``(x, y)`` and degree of polynomial ``deg``, the
function finds a polynomial equation of the form:
.. math::
y = p(x) = p[0] x^{deg} + p[1] x^{deg - 1} + ... + p[deg]
Args:
x: Array of data points of shape ``(M,)``.
y: Array of data points of shape ``(M,)`` or ``(M, K)``.
deg: Degree of the polynomials. It must be specified statically.
rcond: Relative condition number of the fit. Default value is ``len(x) * eps``.
It must be specified statically.
full: Switch that controls the return value. Default is ``False`` which
restricts the return value to the array of polynomail coefficients ``p``.
If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``.
It must be specified statically.
w: Array of weights of shape ``(M,)``. If None, all data points are considered
to have equal weight. If not None, the weight :math:`w_i` is applied to the
unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where
:math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None.
cov: Boolean or string. If ``True``, returns the covariance matrix scaled
by ``resids/(M-deg-1)`` along with ploynomial coefficients. If
``cov='unscaled'``, returns the unscaaled version of covariance matrix.
Default is ``False``. ``cov`` is ignored if ``full=True``. It must be
specified statically.
Returns:
- An array polynomial coefficients ``p`` if ``full=False`` and ``cov=False``.
- A tuple of arrays ``(p, resids, rank, s, rcond)`` if ``full=True``. Where
- ``p`` is an array of shape ``(M,)`` or ``(M, K)`` containing the polynomial
coefficients.
- ``resids`` is the sum of squared residual of shape () or (K,).
- ``rank`` is the rank of the matrix ``x``.
- ``s`` is the singular values of the matrix ``x``.
- ``rcond`` as the array.
- A tuple of arrays ``(p, C)`` if ``full=False`` and ``cov=True``. Where
- ``p`` is an array of shape ``(M,)`` or ``(M, K)`` containing the polynomial
coefficients.
- ``C`` is the covariance matrix of polynomial coefficients of shape
``(deg + 1, deg + 1)`` or ``(deg + 1, deg + 1, 1)``.
Note:
Unlike :func:`numpy.polyfit` implementation of polyfit, :func:`jax.numpy.polyfit`
will not warn on rank reduction, which indicates an ill conditioned matrix.
See Also:
- :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given
sequence of roots.
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Examples:
>>> x = jnp.array([3., 6., 9., 4.])
>>> y = jnp.array([[0, 1, 2],
... [2, 5, 7],
... [8, 4, 9],
... [1, 6, 3]])
>>> p = jnp.polyfit(x, y, 2)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(p)
[[ 0.2 -0.35 -0.14]
[-1.17 4.47 2.96]
[ 1.95 -8.21 -5.93]]
If ``full=True``, returns a tuple of arrays as follows:
>>> p, resids, rank, s, rcond = jnp.polyfit(x, y, 2, full=True)
>>> with jnp.printoptions(precision=2, suppress=True):
... print("Polynomial Coefficients:", "\n", p, "\n",
... "Residuals:", resids, "\n",
... "Rank:", rank, "\n",
... "s:", s, "\n",
... "rcond:", rcond)
Polynomial Coefficients:
[[ 0.2 -0.35 -0.14]
[-1.17 4.47 2.96]
[ 1.95 -8.21 -5.93]]
Residuals: [0.37 5.94 0.61]
Rank: 3
s: [1.67 0.47 0.04]
rcond: 4.7683716e-07
If ``cov=True`` and ``full=False``, returns a tuple of arrays having
polynomial coefficients and covariance matrix.
>>> p, C = jnp.polyfit(x, y, 2, cov=True)
>>> p.shape, C.shape
((3, 3), (3, 3, 1))
"""
check_arraylike("polyfit", x, y)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1