jax.lax: improve docs for several APIs

This commit is contained in:
Jake VanderPlas 2025-03-18 08:55:38 -07:00
parent 1e36cbe597
commit 8b46e53a4f

View File

@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
""" """
return tanh_p.bind(x) return tanh_p.bind(x)
@export
def logistic(x: ArrayLike) -> Array: def logistic(x: ArrayLike) -> Array:
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.""" r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
of HLO arithmetic operations.
Args:
x: input array. Must have floating point or complex dtype.
Returns:
Array of the same shape and dtype as ``x`` containing the element-wise
logistic/sigmoid function.
See also:
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
"""
return logistic_p.bind(x) return logistic_p.bind(x)
@export @export
@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
""" """
return xor_p.bind(x, y) return xor_p.bind(x, y)
@export
def population_count(x: ArrayLike) -> Array: def population_count(x: ArrayLike) -> Array:
r"""Elementwise popcount, count the number of set bits in each element.""" r"""Elementwise popcount, count the number of set bits in each element.
This function lowers directly to the `stablehlo.popcnt`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
set bits in the input.
See also:
- :func:`jax.lax.clz`: Elementwise count leading zeros.
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
"""
return population_count_p.bind(x) return population_count_p.bind(x)
@export
def clz(x: ArrayLike) -> Array: def clz(x: ArrayLike) -> Array:
r"""Elementwise count-leading-zeros.""" r"""Elementwise count-leading-zeros.
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
Args:
x: Input array. Must have integer dtype.
Returns:
An array of the same shape and dtype as ``x``, containing the number of
set bits in the input.
See also:
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
"""
return clz_p.bind(x) return clz_p.bind(x)
@export @export
@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
""" """
return div_p.bind(x, y) return div_p.bind(x, y)
@export
def rem(x: ArrayLike, y: ArrayLike) -> Array: def rem(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise remainder: :math:`x \bmod y`. r"""Elementwise remainder: :math:`x \bmod y`.
The sign of the result is taken from the dividend, This function lowers directly to the `stablehlo.remainder`_ operation.
and the absolute value of the result is always The sign of the result is taken from the dividend, and the absolute value
less than the divisor's absolute value. of the result is always less than the divisor's absolute value.
Integer division overflow Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
(remainder by zero or remainder of INT_SMIN with -1)
produces an implementation defined value. produces an implementation defined value.
Args:
x, y: Input arrays. Must have matching int or float dtypes. If neither
is a scalar, ``x`` and ``y`` must have the same number of dimensions
and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the remainder.
See also:
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
sign semantics.
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
""" """
return rem_p.bind(x, y) return rem_p.bind(x, y)
@export
def max(x: ArrayLike, y: ArrayLike) -> Array: def max(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)` r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
For complex numbers, uses a lexicographic comparison on the This function lowers directly to the `stablehlo.maximum`_ operation for
`(real, imaginary)` pairs.""" non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
maximum.
See also:
- :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum.
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
- :func:`jax.lax.min`: elementwise minimum.
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
"""
return max_p.bind(x, y) return max_p.bind(x, y)
@export
def min(x: ArrayLike, y: ArrayLike) -> Array: def min(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
For complex numbers, uses a lexicographic comparison on the This function lowers directly to the `stablehlo.minimum`_ operation for
`(real, imaginary)` pairs.""" non-complex inputs. For complex numbers, this uses a lexicographic
comparison on the `(real, imaginary)` pairs.
Args:
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
``x`` and ``y`` must have the same rank and be broadcast compatible.
Returns:
An array of the same dtype as ``x`` and ``y`` containing the elementwise
minimum.
See also:
- :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum.
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
- :func:`jax.lax.max`: elementwise maximum.
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
"""
return min_p.bind(x, y) return min_p.bind(x, y)
@export @export
@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
""" """
return lt_p.bind(x, y) return lt_p.bind(x, y)
@export
def convert_element_type(operand: ArrayLike, def convert_element_type(operand: ArrayLike,
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array: new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
"""Elementwise cast. """Elementwise cast.
Wraps XLA's `ConvertElementType This function lowers directly to the `stablehlo.convert`_ operation, which
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_ performs an elementwise conversion from one type to another, similar to a
operator, which performs an elementwise conversion from one type to another. C++ ``static_cast``.
Similar to a C++ `static_cast`.
Args: Args:
operand: an array or scalar value to be cast. operand: an array or scalar value to be cast.
new_dtype: a NumPy dtype representing the target type. new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
or a valid dtype name) representing the target dtype.
Returns: Returns:
An array with the same shape as `operand`, cast elementwise to `new_dtype`. An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
.. note::
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
the appropriate 32-bit type will be used in its place.
If the input is a JAX array and the input dtype and output dtype match, then
the input array will be returned unmodified.
See also:
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
- :meth:`jax.Array.astype`: dtype casting as an array method.
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
""" """
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type] return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
@ -1500,12 +1615,11 @@ def _convert_element_type(
operand, new_dtype=new_dtype, weak_type=bool(weak_type), operand, new_dtype=new_dtype, weak_type=bool(weak_type),
sharding=sharding) sharding=sharding)
@export
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise bitcast. """Elementwise bitcast.
Wraps XLA's `BitcastConvertType This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
operator, which performs a bit cast from one type to another.
The output shape depends on the size of the input and output dtypes with The output shape depends on the size of the input and output dtypes with
the following logic:: the following logic::
@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
Returns: Returns:
An array of shape `output_shape` (see above) and type `new_dtype`, An array of shape `output_shape` (see above) and type `new_dtype`,
constructed from the same bits as operand. constructed from the same bits as operand.
See also:
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
""" """
new_dtype = dtypes.canonicalize_dtype(new_dtype) new_dtype = dtypes.canonicalize_dtype(new_dtype)
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)