jax.lax: improve docs for comparison operators

This commit is contained in:
Jake VanderPlas 2025-02-18 13:48:59 -08:00
parent 1dc58b79bf
commit 7f115fbb64

View File

@ -1148,28 +1148,184 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
"""
return shift_right_logical_p.bind(x, y)
@export
def eq(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise equals: :math:`x = y`."""
r"""Elementwise equals: :math:`x = y`.
This function lowers directly to the `stablehlo.compare`_ operation
with ``comparison_direction=EQ`` and ``compare_type`` set according
to the input dtype.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a
scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
containing the elementwise equal comparison.
See also:
- :func:`jax.numpy.equal`: NumPy wrapper for this API, also accessible
via the ``x == y`` operator on JAX arrays.
- :func:`jax.lax.ne`: elementwise not-equal
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
- :func:`jax.lax.gt`: elementwise greater-than
- :func:`jax.lax.le`: elementwise less-than-or-equal
- :func:`jax.lax.lt`: elementwise less-than
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
return eq_p.bind(x, y)
@export
def ne(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise not-equals: :math:`x \neq y`."""
r"""Elementwise not-equals: :math:`x \neq y`.
This function lowers directly to the `stablehlo.compare`_ operation
with ``comparison_direction=NE`` and ``compare_type`` set according
to the input dtype.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a
scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
containing the elementwise not-equal comparison.
See also:
- :func:`jax.numpy.not_equal`: NumPy wrapper for this API, also accessible
via the ``x != y`` operator on JAX arrays.
- :func:`jax.lax.eq`: elementwise equal
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
- :func:`jax.lax.gt`: elementwise greater-than
- :func:`jax.lax.le`: elementwise less-than-or-equal
- :func:`jax.lax.lt`: elementwise less-than
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
return ne_p.bind(x, y)
@export
def ge(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise greater-than-or-equals: :math:`x \geq y`."""
r"""Elementwise greater-than-or-equals: :math:`x \geq y`.
This function lowers directly to the `stablehlo.compare`_ operation
with ``comparison_direction=GE`` and ``compare_type`` set according
to the input dtype.
Args:
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
a scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
containing the elementwise greater-than-or-equal comparison.
See also:
- :func:`jax.numpy.greater_equal`: NumPy wrapper for this API, also
accessible via the ``x >= y`` operator on JAX arrays.
- :func:`jax.lax.eq`: elementwise equal
- :func:`jax.lax.ne`: elementwise not-equal
- :func:`jax.lax.gt`: elementwise greater-than
- :func:`jax.lax.le`: elementwise less-than-or-equal
- :func:`jax.lax.lt`: elementwise less-than
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
return ge_p.bind(x, y)
@export
def gt(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise greater-than: :math:`x > y`."""
r"""Elementwise greater-than: :math:`x > y`.
This function lowers directly to the `stablehlo.compare`_ operation
with ``comparison_direction=GT`` and ``compare_type`` set according
to the input dtype.
Args:
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
a scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
containing the elementwise greater-than comparison.
See also:
- :func:`jax.numpy.greater`: NumPy wrapper for this API, also accessible
via the ``x > y`` operator on JAX arrays.
- :func:`jax.lax.eq`: elementwise equal
- :func:`jax.lax.ne`: elementwise not-equal
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
- :func:`jax.lax.le`: elementwise less-than-or-equal
- :func:`jax.lax.lt`: elementwise less-than
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
return gt_p.bind(x, y)
@export
def le(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise less-than-or-equals: :math:`x \leq y`."""
r"""Elementwise less-than-or-equals: :math:`x \leq y`.
This function lowers directly to the `stablehlo.compare`_ operation
with ``comparison_direction=LE`` and ``compare_type`` set according
to the input dtype.
Args:
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
a scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
containing the elementwise less-than-or-equal comparison.
See also:
- :func:`jax.numpy.less_equal`: NumPy wrapper for this API, also
accessible via the ``x <= y`` operator on JAX arrays.
- :func:`jax.lax.eq`: elementwise equal
- :func:`jax.lax.ne`: elementwise not-equal
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
- :func:`jax.lax.gt`: elementwise greater-than
- :func:`jax.lax.lt`: elementwise less-than
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
return le_p.bind(x, y)
@export
def lt(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise less-than: :math:`x < y`."""
r"""Elementwise less-than: :math:`x < y`.
This function lowers directly to the `stablehlo.compare`_ operation
with ``comparison_direction=LT`` and ``compare_type`` set according
to the input dtype.
Args:
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
a scalar, ``x`` and ``y`` must have the same number of dimensions and
be broadcast compatible.
Returns:
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
containing the elementwise less-than comparison.
See also:
- :func:`jax.numpy.less`: NumPy wrapper for this API, also accessible
via the ``x < y`` operator on JAX arrays.
- :func:`jax.lax.eq`: elementwise equal
- :func:`jax.lax.ne`: elementwise not-equal
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
- :func:`jax.lax.gt`: elementwise greater-than
- :func:`jax.lax.le`: elementwise less-than-or-equal
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
"""
return lt_p.bind(x, y)
def convert_element_type(operand: ArrayLike,