mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax.lax: improve docs for comparison operators
This commit is contained in:
parent
1dc58b79bf
commit
7f115fbb64
@ -1148,28 +1148,184 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return shift_right_logical_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def eq(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise equals: :math:`x = y`."""
|
||||
r"""Elementwise equals: :math:`x = y`.
|
||||
|
||||
This function lowers directly to the `stablehlo.compare`_ operation
|
||||
with ``comparison_direction=EQ`` and ``compare_type`` set according
|
||||
to the input dtype.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching dtypes. If neither is a
|
||||
scalar, ``x`` and ``y`` must have the same number of dimensions and
|
||||
be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
|
||||
containing the elementwise equal comparison.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.equal`: NumPy wrapper for this API, also accessible
|
||||
via the ``x == y`` operator on JAX arrays.
|
||||
- :func:`jax.lax.ne`: elementwise not-equal
|
||||
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
|
||||
- :func:`jax.lax.gt`: elementwise greater-than
|
||||
- :func:`jax.lax.le`: elementwise less-than-or-equal
|
||||
- :func:`jax.lax.lt`: elementwise less-than
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
return eq_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def ne(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise not-equals: :math:`x \neq y`."""
|
||||
r"""Elementwise not-equals: :math:`x \neq y`.
|
||||
|
||||
This function lowers directly to the `stablehlo.compare`_ operation
|
||||
with ``comparison_direction=NE`` and ``compare_type`` set according
|
||||
to the input dtype.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching dtypes. If neither is a
|
||||
scalar, ``x`` and ``y`` must have the same number of dimensions and
|
||||
be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
|
||||
containing the elementwise not-equal comparison.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.not_equal`: NumPy wrapper for this API, also accessible
|
||||
via the ``x != y`` operator on JAX arrays.
|
||||
- :func:`jax.lax.eq`: elementwise equal
|
||||
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
|
||||
- :func:`jax.lax.gt`: elementwise greater-than
|
||||
- :func:`jax.lax.le`: elementwise less-than-or-equal
|
||||
- :func:`jax.lax.lt`: elementwise less-than
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
return ne_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def ge(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise greater-than-or-equals: :math:`x \geq y`."""
|
||||
r"""Elementwise greater-than-or-equals: :math:`x \geq y`.
|
||||
|
||||
This function lowers directly to the `stablehlo.compare`_ operation
|
||||
with ``comparison_direction=GE`` and ``compare_type`` set according
|
||||
to the input dtype.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
|
||||
a scalar, ``x`` and ``y`` must have the same number of dimensions and
|
||||
be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
|
||||
containing the elementwise greater-than-or-equal comparison.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.greater_equal`: NumPy wrapper for this API, also
|
||||
accessible via the ``x >= y`` operator on JAX arrays.
|
||||
- :func:`jax.lax.eq`: elementwise equal
|
||||
- :func:`jax.lax.ne`: elementwise not-equal
|
||||
- :func:`jax.lax.gt`: elementwise greater-than
|
||||
- :func:`jax.lax.le`: elementwise less-than-or-equal
|
||||
- :func:`jax.lax.lt`: elementwise less-than
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
return ge_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def gt(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise greater-than: :math:`x > y`."""
|
||||
r"""Elementwise greater-than: :math:`x > y`.
|
||||
|
||||
This function lowers directly to the `stablehlo.compare`_ operation
|
||||
with ``comparison_direction=GT`` and ``compare_type`` set according
|
||||
to the input dtype.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
|
||||
a scalar, ``x`` and ``y`` must have the same number of dimensions and
|
||||
be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
|
||||
containing the elementwise greater-than comparison.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.greater`: NumPy wrapper for this API, also accessible
|
||||
via the ``x > y`` operator on JAX arrays.
|
||||
- :func:`jax.lax.eq`: elementwise equal
|
||||
- :func:`jax.lax.ne`: elementwise not-equal
|
||||
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
|
||||
- :func:`jax.lax.le`: elementwise less-than-or-equal
|
||||
- :func:`jax.lax.lt`: elementwise less-than
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
return gt_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def le(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise less-than-or-equals: :math:`x \leq y`."""
|
||||
r"""Elementwise less-than-or-equals: :math:`x \leq y`.
|
||||
|
||||
This function lowers directly to the `stablehlo.compare`_ operation
|
||||
with ``comparison_direction=LE`` and ``compare_type`` set according
|
||||
to the input dtype.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
|
||||
a scalar, ``x`` and ``y`` must have the same number of dimensions and
|
||||
be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
|
||||
containing the elementwise less-than-or-equal comparison.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.less_equal`: NumPy wrapper for this API, also
|
||||
accessible via the ``x <= y`` operator on JAX arrays.
|
||||
- :func:`jax.lax.eq`: elementwise equal
|
||||
- :func:`jax.lax.ne`: elementwise not-equal
|
||||
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
|
||||
- :func:`jax.lax.gt`: elementwise greater-than
|
||||
- :func:`jax.lax.lt`: elementwise less-than
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
return le_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def lt(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise less-than: :math:`x < y`."""
|
||||
r"""Elementwise less-than: :math:`x < y`.
|
||||
|
||||
This function lowers directly to the `stablehlo.compare`_ operation
|
||||
with ``comparison_direction=LT`` and ``compare_type`` set according
|
||||
to the input dtype.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching non-complex dtypes. If neither is
|
||||
a scalar, ``x`` and ``y`` must have the same number of dimensions and
|
||||
be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
A boolean array of shape ``lax.broadcast_shapes(x.shape, y.shape)``
|
||||
containing the elementwise less-than comparison.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.less`: NumPy wrapper for this API, also accessible
|
||||
via the ``x < y`` operator on JAX arrays.
|
||||
- :func:`jax.lax.eq`: elementwise equal
|
||||
- :func:`jax.lax.ne`: elementwise not-equal
|
||||
- :func:`jax.lax.ge`: elementwise greater-than-or-equal
|
||||
- :func:`jax.lax.gt`: elementwise greater-than
|
||||
- :func:`jax.lax.le`: elementwise less-than-or-equal
|
||||
|
||||
.. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare
|
||||
"""
|
||||
return lt_p.bind(x, y)
|
||||
|
||||
def convert_element_type(operand: ArrayLike,
|
||||
|
Loading…
x
Reference in New Issue
Block a user