From c5de7bb92e5b10641de778fcc057c6cdb16369c9 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Sat, 22 Jun 2024 02:49:43 +0530 Subject: [PATCH] Improve docs for jnp.poly and polyval --- jax/_src/numpy/polynomial.py | 128 +++++++++++++++++++++++++++++------ 1 file changed, 107 insertions(+), 21 deletions(-) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 9e82284f7..f0d2d928d 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -175,21 +175,63 @@ def polyfit(x: Array, y: Array, deg: int, rcond: float | None = None, return c -_POLY_DOC = """\ -This differs from np.poly when an integer array is given. -np.poly returns a result with dtype float64 in this case. -jax returns a result with an inexact type, but not necessarily -float64. - -This also differs from np.poly when the input array strictly -contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j]. -np.poly returns an array with a real dtype in such cases. -jax returns an array with a complex dtype in such cases. -""" - -@implements(np.poly, lax_description=_POLY_DOC) @jit def poly(seq_of_zeros: Array) -> Array: + r"""Returns the coefficients of a polynomial for the given sequence of roots. + + JAX implementation of :func:`numpy.poly`. + + Args: + seq_of_zeros: A scalar or an array of roots of the polynomial of shape ``(M,)`` + or ``(M, M)``. + + Returns: + An array containing the coefficients of the polynomial. The dtype of the + output is always promoted to inexact. + + Note: + + :func:`jax.numpy.poly` differs from :func:`numpy.poly`: + + - When the input is a scalar, ``np.poly`` raises a ``TypeError``, whereas + ``jnp.poly`` treats scalars the same as length-1 arrays. + - For complex-valued or square-shaped inputs, ``jnp.poly`` always returns + complex coefficients, whereas ``np.poly`` may return real or complex + depending on their values. + + See also: + - :func:`jax.numpy.polyfit`: Least squares polynomial fit. + - :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values. + - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given + coefficients. + + Example: + + Scalar inputs: + + >>> jnp.poly(1) + Array([ 1., -1.], dtype=float32) + + Input array with integer values: + + >>> x = jnp.array([1, 2, 3]) + >>> jnp.poly(x) + Array([ 1., -6., 11., -6.], dtype=float32) + + Input array with complex conjugates: + + >>> x = jnp.array([2, 1+2j, 1-2j]) + >>> jnp.poly(x) + Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64) + + Input array as square matrix with real valued inputs: + + >>> x = jnp.array([[2, 1, 5], + ... [3, 4, 7], + ... [1, 3, 5]]) + >>> jnp.round(jnp.poly(x)) + Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64) + """ check_arraylike('poly', seq_of_zeros) seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros) seq_of_zeros = atleast_1d(seq_of_zeros) @@ -214,16 +256,60 @@ def poly(seq_of_zeros: Array) -> Array: return a -@implements(np.polyval, lax_description="""\ -The ``unroll`` parameter is JAX specific. It does not effect correctness but can -have a major impact on performance for evaluating high-order polynomials. The -parameter controls the number of unrolled steps with ``lax.scan`` inside the -``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to -improve runtime performance on accelerators, at the cost of increased -compilation time. -""") @partial(jit, static_argnames=['unroll']) def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array: + r"""Evaluates the polynomial at specific values. + + JAX implementations of :func:`numpy.polyval`. + + For the 1D-polynomial coefficients ``p`` of length ``M``, the function returns + the value: + + .. math:: + + p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1} + + Args: + p: An array of polynomial coefficients of shape ``(M,)``. + x: A number or an array of numbers. + unroll: A number used to control the number of unrolled steps with + ``lax.scan``. It must be specified statically. + + Returns: + An array of same shape as ``x``. + + Note: + + The ``unroll`` parameter is JAX specific. It does not affect correctness but + can have a major impact on performance for evaluating high-order polynomials. + The parameter controls the number of unrolled steps with ``lax.scan`` inside + the ``jnp.polyval`` implementation. Consider setting ``unroll=128`` (or even + higher) to improve runtime performance on accelerators, at the cost of + increased compilation time. + + See also: + - :func:`jax.numpy.polyfit`: Least squares polynomial fit. + - :func:`jax.numpy.poly`: Finds the coefficients of a polynomial with given + roots. + - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given + coefficients. + + Example: + >>> p = jnp.array([2, 5, 1]) + >>> jnp.polyval(p, 3) + Array(34., dtype=float32) + + If ``x`` is a 2D array, ``polyval`` returns 2D-array with same shape as + that of ``x``: + + >>> x = jnp.array([[2, 1, 5], + ... [3, 4, 7], + ... [1, 3, 5]]) + >>> jnp.polyval(p, x) + Array([[ 19., 8., 76.], + [ 34., 53., 134.], + [ 8., 34., 76.]], dtype=float32) + """ check_arraylike("polyval", p, x) p, x = promote_dtypes_inexact(p, x) shape = lax.broadcast_shapes(p.shape[1:], x.shape)