mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Documentation fixes. (#3282)
Improve some cross-references and poorly quoted text.
This commit is contained in:
parent
858f1e5465
commit
cf624196ed
@ -391,7 +391,7 @@ def axis_index(axis_name):
|
||||
|
||||
Args:
|
||||
axis_name: hashable Python object used to name the pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
|
||||
Returns:
|
||||
An integer representing the index.
|
||||
|
@ -1320,10 +1320,10 @@ def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
|
||||
def stop_gradient(x):
|
||||
"""Stops gradient computation.
|
||||
|
||||
Operationally `stop_gradient` is the identity function, that is, it returns
|
||||
argument `x` unchanged. However, `stop_gradient` prevents the flow of
|
||||
Operationally ``stop_gradient`` is the identity function, that is, it returns
|
||||
argument `x` unchanged. However, ``stop_gradient`` prevents the flow of
|
||||
gradients during forward or reverse-mode automatic differentiation. If there
|
||||
are multiple nested gradient computations, `stop_gradient` stops gradients
|
||||
are multiple nested gradient computations, ``stop_gradient`` stops gradients
|
||||
for all of them.
|
||||
|
||||
For example:
|
||||
|
@ -147,7 +147,7 @@ def _fori_scan_body_fun(body_fun):
|
||||
return scanned_fun
|
||||
|
||||
def fori_loop(lower, upper, body_fun, init_val):
|
||||
"""Loop from ``lower`` to ``upper`` by reduction to ``while_loop``.
|
||||
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
|
||||
|
||||
The type signature in brief is
|
||||
|
||||
@ -164,7 +164,8 @@ def fori_loop(lower, upper, body_fun, init_val):
|
||||
return val
|
||||
|
||||
Unlike that Python version, ``fori_loop`` is implemented in terms of a call to
|
||||
``while_loop``. See the docstring for ``while_loop`` for more information.
|
||||
:func:`jax.lax.while_loop`. See the :func:`jax.lax.while_loop` documentation
|
||||
for more information.
|
||||
|
||||
Also unlike the Python analogue, the loop-carried value ``val`` must hold a
|
||||
fixed shape and dtype across all iterations (and not just be consistent up to
|
||||
|
@ -49,7 +49,7 @@ def psum(x, axis_name, *, axis_index_groups=None):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
axis_index_groups: optional list of lists containing axis indices (e.g. for
|
||||
an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
|
||||
two and last two replicas). Groups must cover all axis indices exactly
|
||||
@ -87,7 +87,7 @@ def pmean(x, axis_name, *, axis_index_groups=None):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
axis_index_groups: optional list of lists containing axis indices (e.g. for
|
||||
an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first
|
||||
two and last two replicas). Groups must cover all axis indices exactly
|
||||
@ -119,7 +119,7 @@ def pmax(x, axis_name, *, axis_index_groups=None):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
axis_index_groups: optional list of lists containing axis indices (e.g. for
|
||||
an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first
|
||||
two and last two replicas). Groups must cover all axis indices exactly
|
||||
@ -142,7 +142,7 @@ def pmin(x, axis_name, *, axis_index_groups=None):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
axis_index_groups: optional list of lists containing axis indices (e.g. for
|
||||
an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first
|
||||
two and last two replicas). Groups must cover all axis indices exactly
|
||||
@ -177,8 +177,9 @@ def ppermute(x, axis_name, perm):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
perm: list of pairs of ints, representing (source_index, destination_index)
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
perm: list of pairs of ints, representing
|
||||
``(source_index, destination_index)``
|
||||
pairs that encode how the mapped axis named ``axis_name`` should be
|
||||
shuffled. The integer values are treated as indices into the mapped axis
|
||||
``axis_name``. Any two pairs should not have the same source index or the
|
||||
@ -204,7 +205,7 @@ def pshuffle(x, axis_name, perm):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
perm: list of of ints, representing the new order of the source indices
|
||||
that encode how the mapped axis named ``axis_name`` should be
|
||||
shuffled. The integer values are treated as indices into the mapped axis
|
||||
@ -236,7 +237,7 @@ def pswapaxes(x, axis_name, axis):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
axis: int indicating the unmapped axis of ``x`` to map with the name
|
||||
``axis_name``.
|
||||
|
||||
@ -263,7 +264,7 @@ def all_to_all(x, axis_name, split_axis, concat_axis):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
split_axis: int indicating the unmapped axis of ``x`` to map with the name
|
||||
``axis_name``.
|
||||
concat_axis: int indicating the position in the output to materialize the
|
||||
@ -271,6 +272,7 @@ def all_to_all(x, axis_name, split_axis, concat_axis):
|
||||
|
||||
Returns:
|
||||
Array(s) with shape given by the expression::
|
||||
|
||||
np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)
|
||||
|
||||
where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
|
||||
@ -476,7 +478,7 @@ def all_gather(x, axis_name):
|
||||
Args:
|
||||
x: array(s) with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
:func:`jax.pmap` documentation for more details).
|
||||
|
||||
Returns:
|
||||
Array(s) representing the result of an all-gather along the axis
|
||||
|
Loading…
x
Reference in New Issue
Block a user