From 8b46e53a4f8af705fc7218ec135acf95df9152b0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 18 Mar 2025 08:55:38 -0700 Subject: [PATCH] jax.lax: improve docs for several APIs --- jax/_src/lax/lax.py | 166 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 143 insertions(+), 23 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 186a915e0..86a75ada6 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array: """ return tanh_p.bind(x) +@export def logistic(x: ArrayLike) -> Array: - r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.""" + r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`. + + There is no HLO logistic/sigmoid primitive, so this lowers to a sequence + of HLO arithmetic operations. + + Args: + x: input array. Must have floating point or complex dtype. + + Returns: + Array of the same shape and dtype as ``x`` containing the element-wise + logistic/sigmoid function. + + See also: + - :func:`jax.nn.sigmoid`: an alternative API for this functionality. + """ return logistic_p.bind(x) @export @@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: """ return xor_p.bind(x, y) +@export def population_count(x: ArrayLike) -> Array: - r"""Elementwise popcount, count the number of set bits in each element.""" + r"""Elementwise popcount, count the number of set bits in each element. + + This function lowers directly to the `stablehlo.popcnt`_ operation. + + Args: + x: Input array. Must have integer dtype. + + Returns: + An array of the same shape and dtype as ``x``, containing the number of + set bits in the input. + + See also: + - :func:`jax.lax.clz`: Elementwise count leading zeros. + - :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts. + + .. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt + """ return population_count_p.bind(x) +@export def clz(x: ArrayLike) -> Array: - r"""Elementwise count-leading-zeros.""" + r"""Elementwise count-leading-zeros. + + This function lowers directly to the `stablehlo.count_leading_zeros`_ operation. + + Args: + x: Input array. Must have integer dtype. + + Returns: + An array of the same shape and dtype as ``x``, containing the number of + set bits in the input. + + See also: + - :func:`jax.lax.population_count`: Count the number of set bits in each element. + + .. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros + """ return clz_p.bind(x) @export @@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: """ return div_p.bind(x, y) +@export def rem(x: ArrayLike, y: ArrayLike) -> Array: r"""Elementwise remainder: :math:`x \bmod y`. - The sign of the result is taken from the dividend, - and the absolute value of the result is always - less than the divisor's absolute value. + This function lowers directly to the `stablehlo.remainder`_ operation. + The sign of the result is taken from the dividend, and the absolute value + of the result is always less than the divisor's absolute value. - Integer division overflow - (remainder by zero or remainder of INT_SMIN with -1) + Integer division overflow (remainder by zero or remainder of INT_SMIN with -1) produces an implementation defined value. + + Args: + x, y: Input arrays. Must have matching int or float dtypes. If neither + is a scalar, ``x`` and ``y`` must have the same number of dimensions + and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the remainder. + + See also: + - :func:`jax.numpy.remainder`: NumPy-style remainder with different + sign semantics. + + .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ return rem_p.bind(x, y) +@export def max(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise maximum: :math:`\mathrm{max}(x, y)` + r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`. - For complex numbers, uses a lexicographic comparison on the - `(real, imaginary)` pairs.""" + This function lowers directly to the `stablehlo.maximum`_ operation for + non-complex inputs. For complex numbers, this uses a lexicographic + comparison on the `(real, imaginary)` pairs. + + Args: + x, y: Input arrays. Must have matching dtypes. If neither is a scalar, + ``x`` and ``y`` must have the same rank and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the elementwise + maximum. + + See also: + - :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum. + - :func:`jax.lax.reduce_max`: maximum along an axis of an array. + - :func:`jax.lax.min`: elementwise minimum. + + .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum + """ return max_p.bind(x, y) +@export def min(x: ArrayLike, y: ArrayLike) -> Array: - r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` + r"""Elementwise minimum: :math:`\mathrm{min}(x, y)` - For complex numbers, uses a lexicographic comparison on the - `(real, imaginary)` pairs.""" + This function lowers directly to the `stablehlo.minimum`_ operation for + non-complex inputs. For complex numbers, this uses a lexicographic + comparison on the `(real, imaginary)` pairs. + + Args: + x, y: Input arrays. Must have matching dtypes. If neither is a scalar, + ``x`` and ``y`` must have the same rank and be broadcast compatible. + + Returns: + An array of the same dtype as ``x`` and ``y`` containing the elementwise + minimum. + + See also: + - :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum. + - :func:`jax.lax.reduce_min`: minimum along an axis of an array. + - :func:`jax.lax.max`: elementwise maximum. + + .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum + """ return min_p.bind(x, y) @export @@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: """ return lt_p.bind(x, y) +@export def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array: """Elementwise cast. - Wraps XLA's `ConvertElementType - `_ - operator, which performs an elementwise conversion from one type to another. - Similar to a C++ `static_cast`. + This function lowers directly to the `stablehlo.convert`_ operation, which + performs an elementwise conversion from one type to another, similar to a + C++ ``static_cast``. Args: operand: an array or scalar value to be cast. - new_dtype: a NumPy dtype representing the target type. + new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type, + or a valid dtype name) representing the target dtype. Returns: - An array with the same shape as `operand`, cast elementwise to `new_dtype`. + An array with the same shape as ``operand``, cast elementwise to ``new_dtype``. + + .. note:: + + If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled, + the appropriate 32-bit type will be used in its place. + + If the input is a JAX array and the input dtype and output dtype match, then + the input array will be returned unmodified. + + See also: + - :func:`jax.numpy.astype`: NumPy-style dtype casting API. + - :meth:`jax.Array.astype`: dtype casting as an array method. + - :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype. + + .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert + .. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision """ return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type] @@ -1500,12 +1615,11 @@ def _convert_element_type( operand, new_dtype=new_dtype, weak_type=bool(weak_type), sharding=sharding) +@export def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: """Elementwise bitcast. - Wraps XLA's `BitcastConvertType - `_ - operator, which performs a bit cast from one type to another. + This function lowers directly to the `stablehlo.bitcast_convert`_ operation. The output shape depends on the size of the input and output dtypes with the following logic:: @@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: Returns: An array of shape `output_shape` (see above) and type `new_dtype`, constructed from the same bits as operand. + + See also: + - :func:`jax.lax.convert_element_type`: value-preserving dtype conversion. + - :func:`jax.Array.view`: NumPy-style API for bitcast type conversion. + + .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert """ new_dtype = dtypes.canonicalize_dtype(new_dtype) return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)