mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Improve docs for jnp.poly and polyval
This commit is contained in:
parent
ed4958cb3e
commit
c5de7bb92e
@ -175,21 +175,63 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None,
|
||||
return c
|
||||
|
||||
|
||||
_POLY_DOC = """\
|
||||
This differs from np.poly when an integer array is given.
|
||||
np.poly returns a result with dtype float64 in this case.
|
||||
jax returns a result with an inexact type, but not necessarily
|
||||
float64.
|
||||
|
||||
This also differs from np.poly when the input array strictly
|
||||
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
|
||||
np.poly returns an array with a real dtype in such cases.
|
||||
jax returns an array with a complex dtype in such cases.
|
||||
"""
|
||||
|
||||
@implements(np.poly, lax_description=_POLY_DOC)
|
||||
@jit
|
||||
def poly(seq_of_zeros: Array) -> Array:
|
||||
r"""Returns the coefficients of a polynomial for the given sequence of roots.
|
||||
|
||||
JAX implementation of :func:`numpy.poly`.
|
||||
|
||||
Args:
|
||||
seq_of_zeros: A scalar or an array of roots of the polynomial of shape ``(M,)``
|
||||
or ``(M, M)``.
|
||||
|
||||
Returns:
|
||||
An array containing the coefficients of the polynomial. The dtype of the
|
||||
output is always promoted to inexact.
|
||||
|
||||
Note:
|
||||
|
||||
:func:`jax.numpy.poly` differs from :func:`numpy.poly`:
|
||||
|
||||
- When the input is a scalar, ``np.poly`` raises a ``TypeError``, whereas
|
||||
``jnp.poly`` treats scalars the same as length-1 arrays.
|
||||
- For complex-valued or square-shaped inputs, ``jnp.poly`` always returns
|
||||
complex coefficients, whereas ``np.poly`` may return real or complex
|
||||
depending on their values.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polyfit`: Least squares polynomial fit.
|
||||
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
|
||||
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
|
||||
coefficients.
|
||||
|
||||
Example:
|
||||
|
||||
Scalar inputs:
|
||||
|
||||
>>> jnp.poly(1)
|
||||
Array([ 1., -1.], dtype=float32)
|
||||
|
||||
Input array with integer values:
|
||||
|
||||
>>> x = jnp.array([1, 2, 3])
|
||||
>>> jnp.poly(x)
|
||||
Array([ 1., -6., 11., -6.], dtype=float32)
|
||||
|
||||
Input array with complex conjugates:
|
||||
|
||||
>>> x = jnp.array([2, 1+2j, 1-2j])
|
||||
>>> jnp.poly(x)
|
||||
Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64)
|
||||
|
||||
Input array as square matrix with real valued inputs:
|
||||
|
||||
>>> x = jnp.array([[2, 1, 5],
|
||||
... [3, 4, 7],
|
||||
... [1, 3, 5]])
|
||||
>>> jnp.round(jnp.poly(x))
|
||||
Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)
|
||||
"""
|
||||
check_arraylike('poly', seq_of_zeros)
|
||||
seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros)
|
||||
seq_of_zeros = atleast_1d(seq_of_zeros)
|
||||
@ -214,16 +256,60 @@ def poly(seq_of_zeros: Array) -> Array:
|
||||
return a
|
||||
|
||||
|
||||
@implements(np.polyval, lax_description="""\
|
||||
The ``unroll`` parameter is JAX specific. It does not effect correctness but can
|
||||
have a major impact on performance for evaluating high-order polynomials. The
|
||||
parameter controls the number of unrolled steps with ``lax.scan`` inside the
|
||||
``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
|
||||
improve runtime performance on accelerators, at the cost of increased
|
||||
compilation time.
|
||||
""")
|
||||
@partial(jit, static_argnames=['unroll'])
|
||||
def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
|
||||
r"""Evaluates the polynomial at specific values.
|
||||
|
||||
JAX implementations of :func:`numpy.polyval`.
|
||||
|
||||
For the 1D-polynomial coefficients ``p`` of length ``M``, the function returns
|
||||
the value:
|
||||
|
||||
.. math::
|
||||
|
||||
p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}
|
||||
|
||||
Args:
|
||||
p: An array of polynomial coefficients of shape ``(M,)``.
|
||||
x: A number or an array of numbers.
|
||||
unroll: A number used to control the number of unrolled steps with
|
||||
``lax.scan``. It must be specified statically.
|
||||
|
||||
Returns:
|
||||
An array of same shape as ``x``.
|
||||
|
||||
Note:
|
||||
|
||||
The ``unroll`` parameter is JAX specific. It does not affect correctness but
|
||||
can have a major impact on performance for evaluating high-order polynomials.
|
||||
The parameter controls the number of unrolled steps with ``lax.scan`` inside
|
||||
the ``jnp.polyval`` implementation. Consider setting ``unroll=128`` (or even
|
||||
higher) to improve runtime performance on accelerators, at the cost of
|
||||
increased compilation time.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polyfit`: Least squares polynomial fit.
|
||||
- :func:`jax.numpy.poly`: Finds the coefficients of a polynomial with given
|
||||
roots.
|
||||
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
|
||||
coefficients.
|
||||
|
||||
Example:
|
||||
>>> p = jnp.array([2, 5, 1])
|
||||
>>> jnp.polyval(p, 3)
|
||||
Array(34., dtype=float32)
|
||||
|
||||
If ``x`` is a 2D array, ``polyval`` returns 2D-array with same shape as
|
||||
that of ``x``:
|
||||
|
||||
>>> x = jnp.array([[2, 1, 5],
|
||||
... [3, 4, 7],
|
||||
... [1, 3, 5]])
|
||||
>>> jnp.polyval(p, x)
|
||||
Array([[ 19., 8., 76.],
|
||||
[ 34., 53., 134.],
|
||||
[ 8., 34., 76.]], dtype=float32)
|
||||
"""
|
||||
check_arraylike("polyval", p, x)
|
||||
p, x = promote_dtypes_inexact(p, x)
|
||||
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user