diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index f0d2d928d..d28a424f3 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -328,9 +328,53 @@ def polyadd(a1: Array, a2: Array) -> Array: return a2.at[-a1.shape[0]:].add(a1) -@implements(np.polyint) @partial(jit, static_argnames=('m',)) def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: + r"""Returns the coefficients of the integration of specified order of a polynomial. + + JAX implementation of :func:`numpy.polyint`. + + Args: + p: An array of polynomial coefficients. + m: Order of integration. Default is 1. It must be specified statically. + k: Scalar or array of ``m`` integration constant (s). + + Returns: + An array of coefficients of integrated polynomial. + + See also: + - :func:`jax.numpy.polyder`: Computes the coefficients of the derivative of + a polynomial. + - :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values. + + Examples: + + The first order integration of the polynomial :math:`12 x^2 + 12 x + 6` is + :math:`4 x^3 + 6 x^2 + 6 x`. + + >>> p = jnp.array([12, 12, 6]) + >>> jnp.polyint(p) + Array([4., 6., 6., 0.], dtype=float32) + + Since the constant ``k`` is not provided, the result included ``0`` at the end. + If the constant ``k`` is provided: + + >>> jnp.polyint(p, k=4) + Array([4., 6., 6., 4.], dtype=float32) + + and the second order integration is :math:`x^4 + 2 x^3 + 3 x`: + + >>> jnp.polyint(p, m=2) + Array([1., 2., 3., 0., 0.], dtype=float32) + + When ``m>=2``, the constants ``k`` should be provided as an array having + ``m`` elements. The second order integration of the polynomial + :math:`12 x^2 + 12 x + 6` with the constants ``k=[4, 5]`` is + :math:`x^4 + 2 x^3 + 3 x^2 + 4 x + 5`: + + >>> jnp.polyint(p, m=2, k=jnp.array([4, 5])) + Array([1., 2., 3., 4., 5.], dtype=float32) + """ m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint") k = 0 if k is None else k check_arraylike("polyint", p, k) @@ -351,9 +395,43 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array: return true_divide(concatenate((p, k_arr)), coeff) -@implements(np.polyder) @partial(jit, static_argnames=('m',)) def polyder(p: Array, m: int = 1) -> Array: + r"""Returns the coefficients of the derivative of specified order of a polynomial. + + JAX implementation of :func:`numpy.polyder`. + + Args: + p: Array of polynomials coefficients. + m: Order of differentiation (positive integer). Default is 1. It must be + specified statically. + + Returns: + An array of polynomial coefficients representing the derivative. + + Note: + :func:`jax.numpy.polyder` differs from :func:`numpy.polyder` when an integer + array is given. NumPy returns the result with dtype ``int`` whereas JAX + returns the result with dtype ``float``. + + See also: + - :func:`jax.numpy.polyint`: Computes the integral of polynomial. + - :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values. + + Examples: + + The first order derivative of the polynomial :math:`2 x^3 - 5 x^2 + 3 x - 1` + is :math:`6 x^2 - 10 x +3`: + + >>> p = jnp.array([2, -5, 3, -1]) + >>> jnp.polyder(p) + Array([ 6., -10., 3.], dtype=float32) + + and its second order derivative is :math:`12 x - 10`: + + >>> jnp.polyder(p, m=2) + Array([ 12., -10.], dtype=float32) + """ check_arraylike("polyder", p) m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder") p, = promote_dtypes_inexact(p)