mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23995 from jakevdp:trapezoid-doc
PiperOrigin-RevId: 680734292
This commit is contained in:
commit
c557db0bd8
@ -6692,10 +6692,48 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
|
||||
return take(a, gather_indices, axis=axis)
|
||||
|
||||
|
||||
@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None)))
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||
axis: int = -1) -> Array:
|
||||
r"""
|
||||
Integrate along the given axis using the composite trapezoidal rule.
|
||||
|
||||
JAX implementation of :func:`numpy.trapezoid`
|
||||
|
||||
The trapezoidal rule approximates the integral under a curve by summing the
|
||||
areas of trapezoids formed between adjacent data points.
|
||||
|
||||
Args:
|
||||
y: array of data to integrate.
|
||||
x: optional array of sample points corresponding to the ``y`` values. If not
|
||||
provided, ``x`` defaults to equally spaced with spacing given by ``dx``.
|
||||
dx: The spacing between sample points when `x` is None (default: 1.0).
|
||||
axis: The axis along which to integrate (default: -1)
|
||||
|
||||
Returns:
|
||||
The definite integral approximated by the trapezoidal rule.
|
||||
|
||||
Examples:
|
||||
Integrate over a regular grid, with spacing 1.0:
|
||||
|
||||
>>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
|
||||
>>> jnp.trapezoid(y, dx=1.0)
|
||||
Array(13., dtype=float32)
|
||||
|
||||
Integrate over an irregular grid:
|
||||
|
||||
>>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
|
||||
>>> jnp.trapezoid(y, x)
|
||||
Array(43., dtype=float32)
|
||||
|
||||
Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`:
|
||||
|
||||
>>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
|
||||
>>> y = jnp.sin(x) ** 2
|
||||
>>> result = jnp.trapezoid(y, x)
|
||||
>>> jnp.allclose(result, jnp.pi)
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
# TODO(phawkins): remove this annotation after fixing jnp types.
|
||||
dx_array: Array
|
||||
if x is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user