Merge pull request #23995 from jakevdp:trapezoid-doc

PiperOrigin-RevId: 680734292
This commit is contained in:
jax authors 2024-09-30 15:10:16 -07:00
commit c557db0bd8

View File

@ -6692,10 +6692,48 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *,
return take(a, gather_indices, axis=axis)
@util.implements(getattr(np, "trapezoid", getattr(np, "trapz", None)))
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
r"""
Integrate along the given axis using the composite trapezoidal rule.
JAX implementation of :func:`numpy.trapezoid`
The trapezoidal rule approximates the integral under a curve by summing the
areas of trapezoids formed between adjacent data points.
Args:
y: array of data to integrate.
x: optional array of sample points corresponding to the ``y`` values. If not
provided, ``x`` defaults to equally spaced with spacing given by ``dx``.
dx: The spacing between sample points when `x` is None (default: 1.0).
axis: The axis along which to integrate (default: -1)
Returns:
The definite integral approximated by the trapezoidal rule.
Examples:
Integrate over a regular grid, with spacing 1.0:
>>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
>>> jnp.trapezoid(y, dx=1.0)
Array(13., dtype=float32)
Integrate over an irregular grid:
>>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
>>> jnp.trapezoid(y, x)
Array(43., dtype=float32)
Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`:
>>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
>>> y = jnp.sin(x) ** 2
>>> result = jnp.trapezoid(y, x)
>>> jnp.allclose(result, jnp.pi)
Array(True, dtype=bool)
"""
# TODO(phawkins): remove this annotation after fixing jnp types.
dx_array: Array
if x is None: