Improve documentation for a number of lax functions.

This commit is contained in:
Peter Hawkins 2020-10-14 21:18:09 -04:00
parent fb01f59020
commit db43e21b1d
3 changed files with 60 additions and 17 deletions

View File

@@ -64,6 +64,7 @@ Operators
dynamic_index_in_dim
dynamic_slice
dynamic_slice_in_dim
dynamic_update_slice
dynamic_update_index_in_dim
dynamic_update_slice_in_dim
eq

View File

@@ -691,9 +691,23 @@ def reshape(operand: Array, new_sizes: Shape,
def pad(operand: Array, padding_value: Array,
padding_config: Sequence[Tuple[int, int, int]]) -> Array:
"""Wraps XLA's `Pad
"""Applies low, high, and/or interior padding to an array.
Wraps XLA's `Pad
<https://www.tensorflow.org/xla/operation_semantics#pad>`_
operator.
Args:
operand: an array to be padded.
padding_value: the value to be inserted as padding. Must have the same dtype
as ``operand``.
padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
giving the amount of low, high, and interior (dilation) padding to insert
in each dimension.
Returns:
The ``operand`` array with padding value ``padding_value`` inserted in each
dimension according to the ``padding_config``.
"""
return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config))
@@ -1686,7 +1700,24 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
return full(fill_shape, fill_value, dtype or _dtype(x))
def collapse(operand: Array, start_dimension: int, stop_dimension: int) -> Array:
def collapse(operand: Array, start_dimension: int,
stop_dimension: int) -> Array:
"""Collapses dimensions of an array into a single dimension.
For example, if ``operand`` is an array with shape ``[2, 3, 4]``,
``collapse(operand, 0, 2).shape == [6, 4]``. The elements of the collapsed
dimension are laid out major-to-minor, i.e., with the lowest-numbered
dimension as the slowest varying dimension.
Args:
operand: an input array.
start_dimension: the start of the dimensions to collapse (inclusive).
stop_dimension: the end of the dimensions to collapse (exclusive).
Returns:
An array where dimensions ``[start_dimension, stop_dimension)`` have been
collapsed (raveled) into a single dimension.
"""
lo, hi = start_dimension, stop_dimension
size = prod(operand.shape[lo:hi])
new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
@@ -1760,6 +1791,9 @@ def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
def dynamic_update_slice_in_dim(operand: Array, update: Array,
start_index: Array, axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
in a single ``axis``.
"""
axis = int(axis)
start_indices = [_zero(start_index)] * _ndim(operand)
start_indices[axis] = start_index
@@ -1768,6 +1802,9 @@ def dynamic_update_slice_in_dim(operand: Array, update: Array,
def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
of size 1 in a single ``axis``.
"""
axis = int(axis)
if _ndim(update) != _ndim(operand):
assert _ndim(update) + 1 == _ndim(operand)

View File

@@ -2364,29 +2364,34 @@ def _interleave(a, b):
return jnp.reshape(jnp.stack([a, b], axis=1),
(2 * half_num_elems,) + a.shape[1:])
def associative_scan(fn, elems, reverse=False):
"""Perform a scan with an associative binary operation, in parallel.
def associative_scan(fn: Callable, elems, reverse: bool = False):
"""Performs a scan with an associative binary operation, in parallel.
Args:
fn: Python callable implementing an associative binary operation with
fn: A Python callable implementing an associative binary operation with
signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it
must satisfy the equation
``fn(a, fn(b, c)) == fn(fn(a, b), c)``.
signature ``r = fn(a, b)``. This must satisfy associativity:
``fn(a, fn(b, c)) == fn(fn(a, b), c)``. The inputs and result are
(possibly nested structures of) array(s) matching ``elems``. Each
array has a leading dimension in place of ``num_elems``; the `fn`
is expected to be scanned over this dimension. The result `r` has the same
shape (and structure) as the two inputs ``a`` and ``b``.
The inputs and result are (possibly nested Python tree structures of)
array(s) matching ``elems``. Each array has a leading dimension in place
of the ``num_elems`` dimension. `fn` should be applied elementwise over
the leading dimension (for example, by using :func:`jax.vmap` over the
elementwise function.)
The result `r` has the same shape (and structure) as the two inputs ``a``
and ``b``.
elems: A (possibly nested structure of) array(s), each with leading
dimension ``num_elems``.
reverse: A boolean stating if the scan should be reversed with respect to
the leading dimension.
Returns:
result: A (possibly nested structure of) array(s) of the same shape
and structure as ``elems``, in which the ``k``th element is the result of
recursively applying ``fn`` to combine the first ``k`` elements of
``elems``. For example, given ``elems = [a, b, c, ...]``, the result
would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
A (possibly nested Python tree structure of) array(s) of the same shape
and structure as ``elems``, in which the ``k``'th element is the result of
recursively applying ``fn`` to combine the first ``k`` elements of
``elems``. For example, given ``elems = [a, b, c, ...]``, the result
would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``.
Example 1: partial sums of an array of numbers:
@@ -2395,7 +2400,7 @@ def associative_scan(fn, elems, reverse=False):
Example 2: partial products of an array of matrices
>>> mats = random.uniform(random.PRNGKey(0), (4, 2, 2))
>>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)