mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #21991 from rajasekharporeddy:testbranch4
PiperOrigin-RevId: 645770273
This commit is contained in:
commit
348cbba6b2
@ -328,9 +328,53 @@ def polyadd(a1: Array, a2: Array) -> Array:
|
||||
return a2.at[-a1.shape[0]:].add(a1)
|
||||
|
||||
|
||||
@implements(np.polyint)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
|
||||
r"""Returns the coefficients of the integration of specified order of a polynomial.
|
||||
|
||||
JAX implementation of :func:`numpy.polyint`.
|
||||
|
||||
Args:
|
||||
p: An array of polynomial coefficients.
|
||||
m: Order of integration. Default is 1. It must be specified statically.
|
||||
k: Scalar or array of ``m`` integration constant (s).
|
||||
|
||||
Returns:
|
||||
An array of coefficients of integrated polynomial.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polyder`: Computes the coefficients of the derivative of
|
||||
a polynomial.
|
||||
- :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values.
|
||||
|
||||
Examples:
|
||||
|
||||
The first order integration of the polynomial :math:`12 x^2 + 12 x + 6` is
|
||||
:math:`4 x^3 + 6 x^2 + 6 x`.
|
||||
|
||||
>>> p = jnp.array([12, 12, 6])
|
||||
>>> jnp.polyint(p)
|
||||
Array([4., 6., 6., 0.], dtype=float32)
|
||||
|
||||
Since the constant ``k`` is not provided, the result included ``0`` at the end.
|
||||
If the constant ``k`` is provided:
|
||||
|
||||
>>> jnp.polyint(p, k=4)
|
||||
Array([4., 6., 6., 4.], dtype=float32)
|
||||
|
||||
and the second order integration is :math:`x^4 + 2 x^3 + 3 x`:
|
||||
|
||||
>>> jnp.polyint(p, m=2)
|
||||
Array([1., 2., 3., 0., 0.], dtype=float32)
|
||||
|
||||
When ``m>=2``, the constants ``k`` should be provided as an array having
|
||||
``m`` elements. The second order integration of the polynomial
|
||||
:math:`12 x^2 + 12 x + 6` with the constants ``k=[4, 5]`` is
|
||||
:math:`x^4 + 2 x^3 + 3 x^2 + 4 x + 5`:
|
||||
|
||||
>>> jnp.polyint(p, m=2, k=jnp.array([4, 5]))
|
||||
Array([1., 2., 3., 4., 5.], dtype=float32)
|
||||
"""
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
|
||||
k = 0 if k is None else k
|
||||
check_arraylike("polyint", p, k)
|
||||
@ -351,9 +395,43 @@ def polyint(p: Array, m: int = 1, k: int | None = None) -> Array:
|
||||
return true_divide(concatenate((p, k_arr)), coeff)
|
||||
|
||||
|
||||
@implements(np.polyder)
|
||||
@partial(jit, static_argnames=('m',))
|
||||
def polyder(p: Array, m: int = 1) -> Array:
|
||||
r"""Returns the coefficients of the derivative of specified order of a polynomial.
|
||||
|
||||
JAX implementation of :func:`numpy.polyder`.
|
||||
|
||||
Args:
|
||||
p: Array of polynomials coefficients.
|
||||
m: Order of differentiation (positive integer). Default is 1. It must be
|
||||
specified statically.
|
||||
|
||||
Returns:
|
||||
An array of polynomial coefficients representing the derivative.
|
||||
|
||||
Note:
|
||||
:func:`jax.numpy.polyder` differs from :func:`numpy.polyder` when an integer
|
||||
array is given. NumPy returns the result with dtype ``int`` whereas JAX
|
||||
returns the result with dtype ``float``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.polyint`: Computes the integral of polynomial.
|
||||
- :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values.
|
||||
|
||||
Examples:
|
||||
|
||||
The first order derivative of the polynomial :math:`2 x^3 - 5 x^2 + 3 x - 1`
|
||||
is :math:`6 x^2 - 10 x +3`:
|
||||
|
||||
>>> p = jnp.array([2, -5, 3, -1])
|
||||
>>> jnp.polyder(p)
|
||||
Array([ 6., -10., 3.], dtype=float32)
|
||||
|
||||
and its second order derivative is :math:`12 x - 10`:
|
||||
|
||||
>>> jnp.polyder(p, m=2)
|
||||
Array([ 12., -10.], dtype=float32)
|
||||
"""
|
||||
check_arraylike("polyder", p)
|
||||
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
|
||||
p, = promote_dtypes_inexact(p)
|
||||
|
Loading…
x
Reference in New Issue
Block a user