jax.lax: improve docs for floor, ceil, round.

This commit is contained in:
Jake VanderPlas 2025-02-03 10:19:22 -08:00
parent 7e353913f2
commit 49c4020f0a

View File

@ -367,12 +367,46 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
"""
return nextafter_p.bind(x1, x2)
@export
def floor(x: ArrayLike) -> Array:
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.
This function lowers directly to the `stablehlo.floor`_ operation.
Args:
x: input array. Must be have floating-point type.
Returns:
Array of same shape and dtype as ``x``, containing values rounded
to the next integer toward negative infinity.
See also:
- :func:`jax.lax.ceil`: round to the next integer toward positive infinity
- :func:`jax.lax.round`: round to the nearest integer
.. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor
"""
return floor_p.bind(x)
@export
def ceil(x: ArrayLike) -> Array:
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`."""
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.
This function lowers directly to the `stablehlo.ceil`_ operation.
Args:
x: input array. Must be have floating-point type.
Returns:
Array of same shape and dtype as ``x``, containing values rounded
to the next integer toward positive infinity.
See also:
- :func:`jax.lax.floor`: round to the next integer toward negative infinity
- :func:`jax.lax.round`: round to the nearest integer
.. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil
"""
return ceil_p.bind(x)
class RoundingMethod(enum.IntEnum):
@ -388,20 +422,38 @@ class RoundingMethod(enum.IntEnum):
as bankers rounding (e.g., 0.5 -> 0, 1.5 -> 2).
"""
@export
def round(x: ArrayLike,
rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
) -> Array:
r"""Elementwise round.
Rounds values to the nearest integer.
Rounds values to the nearest integer. This function lowers directly to the
`stablehlo.round`_ operation.
Args:
x: an array or scalar value to round.
x: an array or scalar value to round. Must have floating-point type.
rounding_method: the method to use when rounding halfway values
(e.g., `0.5`). See :class:`jax.lax.RoundingMethod` for possible values.
(e.g., ``0.5``). See :class:`jax.lax.RoundingMethod` for possible values.
Returns:
An array containing the elementwise rounding of x.
An array of the same shape and dtype as ``x``, containing the elementwise
rounding of ``x``.
See also:
- :func:`jax.lax.floor`: round to the next integer toward negative infinity
- :func:`jax.lax.ceil`: round to the next integer toward positive infinity
Examples:
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
>>> jax.lax.round(x) # defaults method is AWAY_FROM_ZERO
Array([-2., -1., -1., 0., 1., 1., 2.], dtype=float32)
>>> jax.lax.round(x, rounding_method=jax.lax.RoundingMethod.TO_NEAREST_EVEN)
Array([-2., -1., -0., 0., 0., 1., 2.], dtype=float32)
.. _stablehlo.round: https://openxla.org/stablehlo/spec#round
"""
rounding_method = RoundingMethod(rounding_method)
return round_p.bind(x, rounding_method=rounding_method)