Merge pull request #21991 from rajasekharporeddy:testbranch4

PiperOrigin-RevId: 645770273
This commit is contained in:
jax authors 2024-06-22 21:28:10 -07:00
commit 348cbba6b2

View File

@ -328,9 +328,53 @@ def polyadd(a1: Array, a2: Array) -> Array:
return a2.at[-a1.shape[0]:].add(a1)
@implements(np.polyint)
@partial(jit, static_argnames=('m',))
def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
r"""Returns the coefficients of the integration of specified order of a polynomial.
JAX implementation of :func:`numpy.polyint`.
Args:
p: An array of polynomial coefficients.
m: Order of integration. Default is 1. It must be specified statically.
k: Scalar or array of ``m`` integration constant (s).
Returns:
An array of coefficients of integrated polynomial.
See also:
- :func:`jax.numpy.polyder`: Computes the coefficients of the derivative of
a polynomial.
- :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values.
Examples:
The first order integration of the polynomial :math:`12 x^2 + 12 x + 6` is
:math:`4 x^3 + 6 x^2 + 6 x`.
>>> p = jnp.array([12, 12, 6])
>>> jnp.polyint(p)
Array([4., 6., 6., 0.], dtype=float32)
Since the constant ``k`` is not provided, the result included ``0`` at the end.
If the constant ``k`` is provided:
>>> jnp.polyint(p, k=4)
Array([4., 6., 6., 4.], dtype=float32)
and the second order integration is :math:`x^4 + 2 x^3 + 3 x`:
>>> jnp.polyint(p, m=2)
Array([1., 2., 3., 0., 0.], dtype=float32)
When ``m>=2``, the constants ``k`` should be provided as an array having
``m`` elements. The second order integration of the polynomial
:math:`12 x^2 + 12 x + 6` with the constants ``k=[4, 5]`` is
:math:`x^4 + 2 x^3 + 3 x^2 + 4 x + 5`:
>>> jnp.polyint(p, m=2, k=jnp.array([4, 5]))
Array([1., 2., 3., 4., 5.], dtype=float32)
"""
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
check_arraylike("polyint", p, k)
@ -351,9 +395,43 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
return true_divide(concatenate((p, k_arr)), coeff)
@implements(np.polyder)
@partial(jit, static_argnames=('m',))
def polyder(p: Array, m: int = 1) -> Array:
r"""Returns the coefficients of the derivative of specified order of a polynomial.
JAX implementation of :func:`numpy.polyder`.
Args:
p: Array of polynomials coefficients.
m: Order of differentiation (positive integer). Default is 1. It must be
specified statically.
Returns:
An array of polynomial coefficients representing the derivative.
Note:
:func:`jax.numpy.polyder` differs from :func:`numpy.polyder` when an integer
array is given. NumPy returns the result with dtype ``int`` whereas JAX
returns the result with dtype ``float``.
See also:
- :func:`jax.numpy.polyint`: Computes the integral of polynomial.
- :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values.
Examples:
The first order derivative of the polynomial :math:`2 x^3 - 5 x^2 + 3 x - 1`
is :math:`6 x^2 - 10 x +3`:
>>> p = jnp.array([2, -5, 3, -1])
>>> jnp.polyder(p)
Array([ 6., -10., 3.], dtype=float32)
and its second order derivative is :math:`12 x - 10`:
>>> jnp.polyder(p, m=2)
Array([ 12., -10.], dtype=float32)
"""
check_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p, = promote_dtypes_inexact(p)