Documentation fixes. (#3282)

Improve some cross-references and poorly quoted text.
This commit is contained in:
Peter Hawkins 2020-06-01 18:09:45 -04:00 committed by GitHub
parent 858f1e5465
commit cf624196ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 16 deletions

View File

@ -391,7 +391,7 @@ def axis_index(axis_name):
Args:
axis_name: hashable Python object used to name the pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
Returns:
An integer representing the index.

View File

@ -1320,10 +1320,10 @@ def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
def stop_gradient(x):
"""Stops gradient computation.
Operationally `stop_gradient` is the identity function, that is, it returns
argument `x` unchanged. However, `stop_gradient` prevents the flow of
Operationally ``stop_gradient`` is the identity function, that is, it returns
argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
gradients during forward or reverse-mode automatic differentiation. If there
are multiple nested gradient computations, `stop_gradient` stops gradients
are multiple nested gradient computations, ``stop_gradient`` stops gradients
for all of them.
For example:

View File

@ -147,7 +147,7 @@ def _fori_scan_body_fun(body_fun):
return scanned_fun
def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to ``while_loop``.
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
The type signature in brief is
@ -164,7 +164,8 @@ def fori_loop(lower, upper, body_fun, init_val):
return val
Unlike that Python version, ``fori_loop`` is implemented in terms of a call to
``while_loop``. See the docstring for ``while_loop`` for more information.
:func:`jax.lax.while_loop`. See the :func:`jax.lax.while_loop` documentation
for more information.
Also unlike the Python analogue, the loop-carried value ``val`` must hold a
fixed shape and dtype across all iterations (and not just be consistent up to

View File

@ -49,7 +49,7 @@ def psum(x, axis_name, *, axis_index_groups=None):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
two and last two replicas). Groups must cover all axis indices exactly
@ -87,7 +87,7 @@ def pmean(x, axis_name, *, axis_index_groups=None):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first
two and last two replicas). Groups must cover all axis indices exactly
@ -119,7 +119,7 @@ def pmax(x, axis_name, *, axis_index_groups=None):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first
two and last two replicas). Groups must cover all axis indices exactly
@ -142,7 +142,7 @@ def pmin(x, axis_name, *, axis_index_groups=None):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
axis_index_groups: optional list of lists containing axis indices (e.g. for
an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first
two and last two replicas). Groups must cover all axis indices exactly
@ -177,8 +177,9 @@ def ppermute(x, axis_name, perm):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
perm: list of pairs of ints, representing (source_index, destination_index)
:func:`jax.pmap` documentation for more details).
perm: list of pairs of ints, representing
``(source_index, destination_index)``
pairs that encode how the mapped axis named ``axis_name`` should be
shuffled. The integer values are treated as indices into the mapped axis
``axis_name``. Any two pairs should not have the same source index or the
@ -204,7 +205,7 @@ def pshuffle(x, axis_name, perm):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
perm: list of of ints, representing the new order of the source indices
that encode how the mapped axis named ``axis_name`` should be
shuffled. The integer values are treated as indices into the mapped axis
@ -236,7 +237,7 @@ def pswapaxes(x, axis_name, axis):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
axis: int indicating the unmapped axis of ``x`` to map with the name
``axis_name``.
@ -263,7 +264,7 @@ def all_to_all(x, axis_name, split_axis, concat_axis):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
split_axis: int indicating the unmapped axis of ``x`` to map with the name
``axis_name``.
concat_axis: int indicating the position in the output to materialize the
@ -271,6 +272,7 @@ def all_to_all(x, axis_name, split_axis, concat_axis):
Returns:
Array(s) with shape given by the expression::
np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)
where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
@ -476,7 +478,7 @@ def all_gather(x, axis_name):
Args:
x: array(s) with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
:func:`jax.pmap` documentation for more details).
Returns:
Array(s) representing the result of an all-gather along the axis