mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
jax.lax: improve docs for several APIs
This commit is contained in:
parent
1e36cbe597
commit
8b46e53a4f
@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
|
|||||||
"""
|
"""
|
||||||
return tanh_p.bind(x)
|
return tanh_p.bind(x)
|
||||||
|
|
||||||
|
@export
|
||||||
def logistic(x: ArrayLike) -> Array:
|
def logistic(x: ArrayLike) -> Array:
|
||||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
|
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
|
||||||
|
|
||||||
|
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
|
||||||
|
of HLO arithmetic operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: input array. Must have floating point or complex dtype.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||||
|
logistic/sigmoid function.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
|
||||||
|
"""
|
||||||
return logistic_p.bind(x)
|
return logistic_p.bind(x)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
|
|||||||
"""
|
"""
|
||||||
return xor_p.bind(x, y)
|
return xor_p.bind(x, y)
|
||||||
|
|
||||||
|
@export
|
||||||
def population_count(x: ArrayLike) -> Array:
|
def population_count(x: ArrayLike) -> Array:
|
||||||
r"""Elementwise popcount, count the number of set bits in each element."""
|
r"""Elementwise popcount, count the number of set bits in each element.
|
||||||
|
|
||||||
|
This function lowers directly to the `stablehlo.popcnt`_ operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input array. Must have integer dtype.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An array of the same shape and dtype as ``x``, containing the number of
|
||||||
|
set bits in the input.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.lax.clz`: Elementwise count leading zeros.
|
||||||
|
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
|
||||||
|
|
||||||
|
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
|
||||||
|
"""
|
||||||
return population_count_p.bind(x)
|
return population_count_p.bind(x)
|
||||||
|
|
||||||
|
@export
|
||||||
def clz(x: ArrayLike) -> Array:
|
def clz(x: ArrayLike) -> Array:
|
||||||
r"""Elementwise count-leading-zeros."""
|
r"""Elementwise count-leading-zeros.
|
||||||
|
|
||||||
|
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input array. Must have integer dtype.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An array of the same shape and dtype as ``x``, containing the number of
|
||||||
|
set bits in the input.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
|
||||||
|
|
||||||
|
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
|
||||||
|
"""
|
||||||
return clz_p.bind(x)
|
return clz_p.bind(x)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
|
|||||||
"""
|
"""
|
||||||
return div_p.bind(x, y)
|
return div_p.bind(x, y)
|
||||||
|
|
||||||
|
@export
|
||||||
def rem(x: ArrayLike, y: ArrayLike) -> Array:
|
def rem(x: ArrayLike, y: ArrayLike) -> Array:
|
||||||
r"""Elementwise remainder: :math:`x \bmod y`.
|
r"""Elementwise remainder: :math:`x \bmod y`.
|
||||||
|
|
||||||
The sign of the result is taken from the dividend,
|
This function lowers directly to the `stablehlo.remainder`_ operation.
|
||||||
and the absolute value of the result is always
|
The sign of the result is taken from the dividend, and the absolute value
|
||||||
less than the divisor's absolute value.
|
of the result is always less than the divisor's absolute value.
|
||||||
|
|
||||||
Integer division overflow
|
Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
|
||||||
(remainder by zero or remainder of INT_SMIN with -1)
|
|
||||||
produces an implementation defined value.
|
produces an implementation defined value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x, y: Input arrays. Must have matching int or float dtypes. If neither
|
||||||
|
is a scalar, ``x`` and ``y`` must have the same number of dimensions
|
||||||
|
and be broadcast compatible.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An array of the same dtype as ``x`` and ``y`` containing the remainder.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
|
||||||
|
sign semantics.
|
||||||
|
|
||||||
|
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
|
||||||
"""
|
"""
|
||||||
return rem_p.bind(x, y)
|
return rem_p.bind(x, y)
|
||||||
|
|
||||||
|
@export
|
||||||
def max(x: ArrayLike, y: ArrayLike) -> Array:
|
def max(x: ArrayLike, y: ArrayLike) -> Array:
|
||||||
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
|
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
|
||||||
|
|
||||||
For complex numbers, uses a lexicographic comparison on the
|
This function lowers directly to the `stablehlo.maximum`_ operation for
|
||||||
`(real, imaginary)` pairs."""
|
non-complex inputs. For complex numbers, this uses a lexicographic
|
||||||
|
comparison on the `(real, imaginary)` pairs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
|
||||||
|
``x`` and ``y`` must have the same rank and be broadcast compatible.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An array of the same dtype as ``x`` and ``y`` containing the elementwise
|
||||||
|
maximum.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum.
|
||||||
|
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
|
||||||
|
- :func:`jax.lax.min`: elementwise minimum.
|
||||||
|
|
||||||
|
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
|
||||||
|
"""
|
||||||
return max_p.bind(x, y)
|
return max_p.bind(x, y)
|
||||||
|
|
||||||
|
@export
|
||||||
def min(x: ArrayLike, y: ArrayLike) -> Array:
|
def min(x: ArrayLike, y: ArrayLike) -> Array:
|
||||||
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
|
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
|
||||||
|
|
||||||
For complex numbers, uses a lexicographic comparison on the
|
This function lowers directly to the `stablehlo.minimum`_ operation for
|
||||||
`(real, imaginary)` pairs."""
|
non-complex inputs. For complex numbers, this uses a lexicographic
|
||||||
|
comparison on the `(real, imaginary)` pairs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
|
||||||
|
``x`` and ``y`` must have the same rank and be broadcast compatible.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An array of the same dtype as ``x`` and ``y`` containing the elementwise
|
||||||
|
minimum.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum.
|
||||||
|
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
|
||||||
|
- :func:`jax.lax.max`: elementwise maximum.
|
||||||
|
|
||||||
|
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
|
||||||
|
"""
|
||||||
return min_p.bind(x, y)
|
return min_p.bind(x, y)
|
||||||
|
|
||||||
@export
|
@export
|
||||||
@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
|
|||||||
"""
|
"""
|
||||||
return lt_p.bind(x, y)
|
return lt_p.bind(x, y)
|
||||||
|
|
||||||
|
@export
|
||||||
def convert_element_type(operand: ArrayLike,
|
def convert_element_type(operand: ArrayLike,
|
||||||
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
|
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
|
||||||
"""Elementwise cast.
|
"""Elementwise cast.
|
||||||
|
|
||||||
Wraps XLA's `ConvertElementType
|
This function lowers directly to the `stablehlo.convert`_ operation, which
|
||||||
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
|
performs an elementwise conversion from one type to another, similar to a
|
||||||
operator, which performs an elementwise conversion from one type to another.
|
C++ ``static_cast``.
|
||||||
Similar to a C++ `static_cast`.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
operand: an array or scalar value to be cast.
|
operand: an array or scalar value to be cast.
|
||||||
new_dtype: a NumPy dtype representing the target type.
|
new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
|
||||||
|
or a valid dtype name) representing the target dtype.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
|
An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
|
||||||
|
the appropriate 32-bit type will be used in its place.
|
||||||
|
|
||||||
|
If the input is a JAX array and the input dtype and output dtype match, then
|
||||||
|
the input array will be returned unmodified.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
|
||||||
|
- :meth:`jax.Array.astype`: dtype casting as an array method.
|
||||||
|
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
|
||||||
|
|
||||||
|
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
|
||||||
|
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
|
||||||
"""
|
"""
|
||||||
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
|
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
|
||||||
|
|
||||||
@ -1500,12 +1615,11 @@ def _convert_element_type(
|
|||||||
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
|
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
|
||||||
|
@export
|
||||||
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
||||||
"""Elementwise bitcast.
|
"""Elementwise bitcast.
|
||||||
|
|
||||||
Wraps XLA's `BitcastConvertType
|
This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
|
||||||
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
|
|
||||||
operator, which performs a bit cast from one type to another.
|
|
||||||
|
|
||||||
The output shape depends on the size of the input and output dtypes with
|
The output shape depends on the size of the input and output dtypes with
|
||||||
the following logic::
|
the following logic::
|
||||||
@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
|||||||
Returns:
|
Returns:
|
||||||
An array of shape `output_shape` (see above) and type `new_dtype`,
|
An array of shape `output_shape` (see above) and type `new_dtype`,
|
||||||
constructed from the same bits as operand.
|
constructed from the same bits as operand.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
|
||||||
|
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
|
||||||
|
|
||||||
|
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
|
||||||
"""
|
"""
|
||||||
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
||||||
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user