Improve docs for jnp.poly and polyval

This commit is contained in:
rajasekharporeddy 2024-06-22 02:49:43 +05:30
parent ed4958cb3e
commit c5de7bb92e

View File

@ -175,21 +175,63 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
return c
_POLY_DOC = """\
This differs from np.poly when an integer array is given.
np.poly returns a result with dtype float64 in this case.
jax returns a result with an inexact type, but not necessarily
float64.
This also differs from np.poly when the input array strictly
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
np.poly returns an array with a real dtype in such cases.
jax returns an array with a complex dtype in such cases.
"""
@implements(np.poly, lax_description=_POLY_DOC)
@jit
def poly(seq_of_zeros: Array) -> Array:
r"""Returns the coefficients of a polynomial for the given sequence of roots.
JAX implementation of :func:`numpy.poly`.
Args:
seq_of_zeros: A scalar or an array of roots of the polynomial of shape ``(M,)``
or ``(M, M)``.
Returns:
An array containing the coefficients of the polynomial. The dtype of the
output is always promoted to inexact.
Note:
:func:`jax.numpy.poly` differs from :func:`numpy.poly`:
- When the input is a scalar, ``np.poly`` raises a ``TypeError``, whereas
``jnp.poly`` treats scalars the same as length-1 arrays.
- For complex-valued or square-shaped inputs, ``jnp.poly`` always returns
complex coefficients, whereas ``np.poly`` may return real or complex
depending on their values.
See also:
- :func:`jax.numpy.polyfit`: Least squares polynomial fit.
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
Scalar inputs:
>>> jnp.poly(1)
Array([ 1., -1.], dtype=float32)
Input array with integer values:
>>> x = jnp.array([1, 2, 3])
>>> jnp.poly(x)
Array([ 1., -6., 11., -6.], dtype=float32)
Input array with complex conjugates:
>>> x = jnp.array([2, 1+2j, 1-2j])
>>> jnp.poly(x)
Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64)
Input array as square matrix with real valued inputs:
>>> x = jnp.array([[2, 1, 5],
... [3, 4, 7],
... [1, 3, 5]])
>>> jnp.round(jnp.poly(x))
Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)
"""
check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros = atleast_1d(seq_of_zeros)
@ -214,16 +256,60 @@ def poly(seq_of_zeros: Array) -> Array:
return a
@implements(np.polyval, lax_description="""\
The ``unroll`` parameter is JAX specific. It does not effect correctness but can
have a major impact on performance for evaluating high-order polynomials. The
parameter controls the number of unrolled steps with ``lax.scan`` inside the
``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
improve runtime performance on accelerators, at the cost of increased
compilation time.
""")
@partial(jit, static_argnames=['unroll'])
def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
r"""Evaluates the polynomial at specific values.
JAX implementations of :func:`numpy.polyval`.
For the 1D-polynomial coefficients ``p`` of length ``M``, the function returns
the value:
.. math::
p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}
Args:
p: An array of polynomial coefficients of shape ``(M,)``.
x: A number or an array of numbers.
unroll: A number used to control the number of unrolled steps with
``lax.scan``. It must be specified statically.
Returns:
An array of same shape as ``x``.
Note:
The ``unroll`` parameter is JAX specific. It does not affect correctness but
can have a major impact on performance for evaluating high-order polynomials.
The parameter controls the number of unrolled steps with ``lax.scan`` inside
the ``jnp.polyval`` implementation. Consider setting ``unroll=128`` (or even
higher) to improve runtime performance on accelerators, at the cost of
increased compilation time.
See also:
- :func:`jax.numpy.polyfit`: Least squares polynomial fit.
- :func:`jax.numpy.poly`: Finds the coefficients of a polynomial with given
roots.
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)
If ``x`` is a 2D array, ``polyval`` returns 2D-array with same shape as
that of ``x``:
>>> x = jnp.array([[2, 1, 5],
... [3, 4, 7],
... [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19., 8., 76.],
[ 34., 53., 134.],
[ 8., 34., 76.]], dtype=float32)
"""
check_arraylike("polyval", p, x)
p, x = promote_dtypes_inexact(p, x)
shape = lax.broadcast_shapes(p.shape[1:], x.shape)