mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Merge pull request #27198 from jakevdp:lax-docs
PiperOrigin-RevId: 738038116
This commit is contained in:
commit
30941480a1
@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
|
||||
"""
|
||||
return tanh_p.bind(x)
|
||||
|
||||
@export
|
||||
def logistic(x: ArrayLike) -> Array:
|
||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
|
||||
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
|
||||
|
||||
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
|
||||
of HLO arithmetic operations.
|
||||
|
||||
Args:
|
||||
x: input array. Must have floating point or complex dtype.
|
||||
|
||||
Returns:
|
||||
Array of the same shape and dtype as ``x`` containing the element-wise
|
||||
logistic/sigmoid function.
|
||||
|
||||
See also:
|
||||
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
|
||||
"""
|
||||
return logistic_p.bind(x)
|
||||
|
||||
@export
|
||||
@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return xor_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def population_count(x: ArrayLike) -> Array:
|
||||
r"""Elementwise popcount, count the number of set bits in each element."""
|
||||
r"""Elementwise popcount, count the number of set bits in each element.
|
||||
|
||||
This function lowers directly to the `stablehlo.popcnt`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have integer dtype.
|
||||
|
||||
Returns:
|
||||
An array of the same shape and dtype as ``x``, containing the number of
|
||||
set bits in the input.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.clz`: Elementwise count leading zeros.
|
||||
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
|
||||
|
||||
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
|
||||
"""
|
||||
return population_count_p.bind(x)
|
||||
|
||||
@export
|
||||
def clz(x: ArrayLike) -> Array:
|
||||
r"""Elementwise count-leading-zeros."""
|
||||
r"""Elementwise count-leading-zeros.
|
||||
|
||||
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
|
||||
|
||||
Args:
|
||||
x: Input array. Must have integer dtype.
|
||||
|
||||
Returns:
|
||||
An array of the same shape and dtype as ``x``, containing the number of
|
||||
set bits in the input.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
|
||||
|
||||
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
|
||||
"""
|
||||
return clz_p.bind(x)
|
||||
|
||||
@export
|
||||
@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return div_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def rem(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise remainder: :math:`x \bmod y`.
|
||||
|
||||
The sign of the result is taken from the dividend,
|
||||
and the absolute value of the result is always
|
||||
less than the divisor's absolute value.
|
||||
This function lowers directly to the `stablehlo.remainder`_ operation.
|
||||
The sign of the result is taken from the dividend, and the absolute value
|
||||
of the result is always less than the divisor's absolute value.
|
||||
|
||||
Integer division overflow
|
||||
(remainder by zero or remainder of INT_SMIN with -1)
|
||||
Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
|
||||
produces an implementation defined value.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching int or float dtypes. If neither
|
||||
is a scalar, ``x`` and ``y`` must have the same number of dimensions
|
||||
and be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
An array of the same dtype as ``x`` and ``y`` containing the remainder.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
|
||||
sign semantics.
|
||||
|
||||
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
|
||||
"""
|
||||
return rem_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def max(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
|
||||
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
|
||||
|
||||
For complex numbers, uses a lexicographic comparison on the
|
||||
`(real, imaginary)` pairs."""
|
||||
This function lowers directly to the `stablehlo.maximum`_ operation for
|
||||
non-complex inputs. For complex numbers, this uses a lexicographic
|
||||
comparison on the `(real, imaginary)` pairs.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
|
||||
``x`` and ``y`` must have the same rank and be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
An array of the same dtype as ``x`` and ``y`` containing the elementwise
|
||||
maximum.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum.
|
||||
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
|
||||
- :func:`jax.lax.min`: elementwise minimum.
|
||||
|
||||
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
|
||||
"""
|
||||
return max_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def min(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
|
||||
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
|
||||
|
||||
For complex numbers, uses a lexicographic comparison on the
|
||||
`(real, imaginary)` pairs."""
|
||||
This function lowers directly to the `stablehlo.minimum`_ operation for
|
||||
non-complex inputs. For complex numbers, this uses a lexicographic
|
||||
comparison on the `(real, imaginary)` pairs.
|
||||
|
||||
Args:
|
||||
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
|
||||
``x`` and ``y`` must have the same rank and be broadcast compatible.
|
||||
|
||||
Returns:
|
||||
An array of the same dtype as ``x`` and ``y`` containing the elementwise
|
||||
minimum.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum.
|
||||
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
|
||||
- :func:`jax.lax.max`: elementwise maximum.
|
||||
|
||||
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
|
||||
"""
|
||||
return min_p.bind(x, y)
|
||||
|
||||
@export
|
||||
@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
"""
|
||||
return lt_p.bind(x, y)
|
||||
|
||||
@export
|
||||
def convert_element_type(operand: ArrayLike,
|
||||
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
|
||||
"""Elementwise cast.
|
||||
|
||||
Wraps XLA's `ConvertElementType
|
||||
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
|
||||
operator, which performs an elementwise conversion from one type to another.
|
||||
Similar to a C++ `static_cast`.
|
||||
This function lowers directly to the `stablehlo.convert`_ operation, which
|
||||
performs an elementwise conversion from one type to another, similar to a
|
||||
C++ ``static_cast``.
|
||||
|
||||
Args:
|
||||
operand: an array or scalar value to be cast.
|
||||
new_dtype: a NumPy dtype representing the target type.
|
||||
new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
|
||||
or a valid dtype name) representing the target dtype.
|
||||
|
||||
Returns:
|
||||
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
|
||||
An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
|
||||
|
||||
.. note::
|
||||
|
||||
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
|
||||
the appropriate 32-bit type will be used in its place.
|
||||
|
||||
If the input is a JAX array and the input dtype and output dtype match, then
|
||||
the input array will be returned unmodified.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
|
||||
- :meth:`jax.Array.astype`: dtype casting as an array method.
|
||||
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
|
||||
|
||||
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
|
||||
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
|
||||
"""
|
||||
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
|
||||
|
||||
@ -1500,12 +1615,11 @@ def _convert_element_type(
|
||||
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
|
||||
sharding=sharding)
|
||||
|
||||
@export
|
||||
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
||||
"""Elementwise bitcast.
|
||||
|
||||
Wraps XLA's `BitcastConvertType
|
||||
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
|
||||
operator, which performs a bit cast from one type to another.
|
||||
This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
|
||||
|
||||
The output shape depends on the size of the input and output dtypes with
|
||||
the following logic::
|
||||
@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
||||
Returns:
|
||||
An array of shape `output_shape` (see above) and type `new_dtype`,
|
||||
constructed from the same bits as operand.
|
||||
|
||||
See also:
|
||||
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
|
||||
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
|
||||
|
||||
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
|
||||
"""
|
||||
new_dtype = dtypes.canonicalize_dtype(new_dtype)
|
||||
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user