mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Improve documentation for a number of lax functions.
This commit is contained in:
parent
fb01f59020
commit
db43e21b1d
@ -64,6 +64,7 @@ Operators
|
||||
dynamic_index_in_dim
|
||||
dynamic_slice
|
||||
dynamic_slice_in_dim
|
||||
dynamic_update_slice
|
||||
dynamic_update_index_in_dim
|
||||
dynamic_update_slice_in_dim
|
||||
eq
|
||||
|
@ -691,9 +691,23 @@ def reshape(operand: Array, new_sizes: Shape,
|
||||
|
||||
def pad(operand: Array, padding_value: Array,
|
||||
padding_config: Sequence[Tuple[int, int, int]]) -> Array:
|
||||
"""Wraps XLA's `Pad
|
||||
"""Applies low, high, and/or interior padding to an array.
|
||||
|
||||
Wraps XLA's `Pad
|
||||
<https://www.tensorflow.org/xla/operation_semantics#pad>`_
|
||||
operator.
|
||||
|
||||
Args:
|
||||
operand: an array to be padded.
|
||||
padding_value: the value to be inserted as padding. Must have the same dtype
|
||||
as ``operand``.
|
||||
padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
|
||||
giving the amount of low, high, and interior (dilation) padding to insert
|
||||
in each dimension.
|
||||
|
||||
Returns:
|
||||
The ``operand`` array with padding value ``padding_value`` inserted in each
|
||||
dimension according to the ``padding_config``.
|
||||
"""
|
||||
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
|
||||
|
||||
@ -1686,7 +1700,24 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
|
||||
return full(fill_shape, fill_value, dtype or _dtype(x))
|
||||
|
||||
|
||||
def collapse(operand: Array, start_dimension: int, stop_dimension: int) -> Array:
|
||||
def collapse(operand: Array, start_dimension: int,
|
||||
stop_dimension: int) -> Array:
|
||||
"""Collapses dimensions of an array into a single dimension.
|
||||
|
||||
For example, if ``operand`` is an array with shape ``[2, 3, 4]``,
|
||||
``collapse(operand, 0, 2).shape == [6, 4]``. The elements of the collapsed
|
||||
dimension are laid out major-to-minor, i.e., with the lowest-numbered
|
||||
dimension as the slowest varying dimension.
|
||||
|
||||
Args:
|
||||
operand: an input array.
|
||||
start_dimension: the start of the dimensions to collapse (inclusive).
|
||||
stop_dimension: the end of the dimensions to collapse (exclusive).
|
||||
|
||||
Returns:
|
||||
An array where dimensions ``[start_dimension, stop_dimension)`` have been
|
||||
collapsed (raveled) into a single dimension.
|
||||
"""
|
||||
lo, hi = start_dimension, stop_dimension
|
||||
size = prod(operand.shape[lo:hi])
|
||||
new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
|
||||
@ -1760,6 +1791,9 @@ def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
|
||||
|
||||
def dynamic_update_slice_in_dim(operand: Array, update: Array,
|
||||
start_index: Array, axis: int) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
|
||||
in a single ``axis``.
|
||||
"""
|
||||
axis = int(axis)
|
||||
start_indices = [_zero(start_index)] * _ndim(operand)
|
||||
start_indices[axis] = start_index
|
||||
@ -1768,6 +1802,9 @@ def dynamic_update_slice_in_dim(operand: Array, update: Array,
|
||||
|
||||
def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
|
||||
axis: int) -> Array:
|
||||
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
|
||||
of size 1 in a single ``axis``.
|
||||
"""
|
||||
axis = int(axis)
|
||||
if _ndim(update) != _ndim(operand):
|
||||
assert _ndim(update) + 1 == _ndim(operand)
|
||||
|
@ -2364,29 +2364,34 @@ def _interleave(a, b):
|
||||
return jnp.reshape(jnp.stack([a, b], axis=1),
|
||||
(2 * half_num_elems,) + a.shape[1:])
|
||||
|
||||
def associative_scan(fn, elems, reverse=False):
|
||||
"""Perform a scan with an associative binary operation, in parallel.
|
||||
def associative_scan(fn: Callable, elems, reverse: bool = False):
|
||||
"""Performs a scan with an associative binary operation, in parallel.
|
||||
|
||||
Args:
|
||||
fn: Python callable implementing an associative binary operation with
|
||||
fn: A Python callable implementing an associative binary operation with
|
||||
signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
|
||||
must satisfy the equation
|
||||
``fn(a, fn(b, c)) == fn(fn(a, b), c)``.
|
||||
|
||||
signature ``r = fn(a, b)``. This must satisfy associativity:
|
||||
``fn(a, fn(b, c)) == fn(fn(a, b), c)``. The inputs and result are
|
||||
(possibly nested structures of) array(s) matching ``elems``. Each
|
||||
array has a leading dimension in place of ``num_elems``; the `fn`
|
||||
is expected to be scanned over this dimension. The result `r` has the same
|
||||
shape (and structure) as the two inputs ``a`` and ``b``.
|
||||
The inputs and result are (possibly nested Python tree structures of)
|
||||
array(s) matching ``elems``. Each array has a leading dimension in place
|
||||
of the ``num_elems`` dimension. `fn` should be applied elementwise over
|
||||
the leading dimension (for example, by using :func:`jax.vmap` over the
|
||||
elementwise function.)
|
||||
|
||||
The result `r` has the same shape (and structure) as the two inputs ``a``
|
||||
and ``b``.
|
||||
elems: A (possibly nested structure of) array(s), each with leading
|
||||
dimension ``num_elems``.
|
||||
reverse: A boolean stating if the scan should be reversed with respect to
|
||||
the leading dimension.
|
||||
|
||||
Returns:
|
||||
result: A (possibly nested structure of) array(s) of the same shape
|
||||
and structure as ``elems``, in which the ``k``th element is the result of
|
||||
recursively applying ``fn`` to combine the first ``k`` elements of
|
||||
``elems``. For example, given ``elems = [a, b, c, ...]``, the result
|
||||
would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
|
||||
A (possibly nested Python tree structure of) array(s) of the same shape
|
||||
and structure as ``elems``, in which the ``k``'th element is the result of
|
||||
recursively applying ``fn`` to combine the first ``k`` elements of
|
||||
``elems``. For example, given ``elems = [a, b, c, ...]``, the result
|
||||
would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
|
||||
|
||||
Example 1: partial sums of an array of numbers:
|
||||
|
||||
@ -2395,7 +2400,7 @@ def associative_scan(fn, elems, reverse=False):
|
||||
|
||||
Example 2: partial products of an array of matrices
|
||||
|
||||
>>> mats = random.uniform(random.PRNGKey(0), (4, 2, 2))
|
||||
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
|
||||
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
|
||||
>>> partial_prods.shape
|
||||
(4, 2, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user