mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax.lax: improve docs for floor, ceil, round.
This commit is contained in:
parent
7e353913f2
commit
49c4020f0a
@ -367,12 +367,46 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
"""
|
||||
return nextafter_p.bind(x1, x2)
|
||||
|
||||
@export
|
||||
def floor(x: ArrayLike) -> Array:
|
||||
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`."""
|
||||
r"""Elementwise floor: :math:`\left\lfloor x \right\rfloor`.
|
||||
|
||||
This function lowers directly to the `stablehlo.floor`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must be have floating-point type.
|
||||
|
||||
Returns:
|
||||
Array of same shape and dtype as ``x``, containing values rounded
|
||||
to the next integer toward negative infinity.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.ceil`: round to the next integer toward positive infinity
|
||||
- :func:`jax.lax.round`: round to the nearest integer
|
||||
|
||||
.. _stablehlo.floor: https://openxla.org/stablehlo/spec#floor
|
||||
"""
|
||||
return floor_p.bind(x)
|
||||
|
||||
@export
|
||||
def ceil(x: ArrayLike) -> Array:
|
||||
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`."""
|
||||
r"""Elementwise ceiling: :math:`\left\lceil x \right\rceil`.
|
||||
|
||||
This function lowers directly to the `stablehlo.ceil`_ operation.
|
||||
|
||||
Args:
|
||||
x: input array. Must be have floating-point type.
|
||||
|
||||
Returns:
|
||||
Array of same shape and dtype as ``x``, containing values rounded
|
||||
to the next integer toward positive infinity.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.floor`: round to the next integer toward negative infinity
|
||||
- :func:`jax.lax.round`: round to the nearest integer
|
||||
|
||||
.. _stablehlo.ceil: https://openxla.org/stablehlo/spec#ceil
|
||||
"""
|
||||
return ceil_p.bind(x)
|
||||
|
||||
class RoundingMethod(enum.IntEnum):
|
||||
@ -388,20 +422,38 @@ class RoundingMethod(enum.IntEnum):
|
||||
as “banker’s rounding” (e.g., 0.5 -> 0, 1.5 -> 2).
|
||||
"""
|
||||
|
||||
@export
|
||||
def round(x: ArrayLike,
|
||||
rounding_method: RoundingMethod = RoundingMethod.AWAY_FROM_ZERO
|
||||
) -> Array:
|
||||
r"""Elementwise round.
|
||||
|
||||
Rounds values to the nearest integer.
|
||||
Rounds values to the nearest integer. This function lowers directly to the
|
||||
`stablehlo.round`_ operation.
|
||||
|
||||
Args:
|
||||
x: an array or scalar value to round.
|
||||
x: an array or scalar value to round. Must have floating-point type.
|
||||
rounding_method: the method to use when rounding halfway values
|
||||
(e.g., `0.5`). See :class:`jax.lax.RoundingMethod` for possible values.
|
||||
(e.g., ``0.5``). See :class:`jax.lax.RoundingMethod` for possible values.
|
||||
|
||||
Returns:
|
||||
An array containing the elementwise rounding of x.
|
||||
An array of the same shape and dtype as ``x``, containing the elementwise
|
||||
rounding of ``x``.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.floor`: round to the next integer toward negative infinity
|
||||
- :func:`jax.lax.ceil`: round to the next integer toward positive infinity
|
||||
|
||||
Examples:
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax import lax
|
||||
>>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
|
||||
>>> jax.lax.round(x) # defaults method is AWAY_FROM_ZERO
|
||||
Array([-2., -1., -1., 0., 1., 1., 2.], dtype=float32)
|
||||
>>> jax.lax.round(x, rounding_method=jax.lax.RoundingMethod.TO_NEAREST_EVEN)
|
||||
Array([-2., -1., -0., 0., 0., 1., 2.], dtype=float32)
|
||||
|
||||
.. _stablehlo.round: https://openxla.org/stablehlo/spec#round
|
||||
"""
|
||||
rounding_method = RoundingMethod(rounding_method)
|
||||
return round_p.bind(x, rounding_method=rounding_method)
|
||||
|
Loading…
x
Reference in New Issue
Block a user