Add a number of missing function cross-references in the docs.

Peter Hawkins 2022-08-24 09:49:51 -04:00
parent e9e014f432
commit cd84eb10a6
14 changed files with 103 additions and 95 deletions

View File

@@ -118,7 +118,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
) -> Callable:
"""Make ``fun`` recompute internal linearization points when differentiated.
-The :func:`jax.checkpoint` decorator, aliased to ``jax.remat``, provides a
+The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a
way to trade off computation time and memory cost in the context of automatic
differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
and :func:`jax.vjp` but also with :func:`jax.linearize`.
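
For concreteness, a minimal sketch of the trade-off, assuming nothing beyond
the :func:`jax.checkpoint` and :func:`jax.grad` APIs described here::

    import jax
    import jax.numpy as jnp

    @jax.checkpoint
    def g(x):
        # Under grad, sin(x) and cos(x) are recomputed on the backward pass
        # instead of being stored during the forward pass.
        return jnp.sin(jnp.sin(x))

    print(jax.grad(g)(2.0))
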
@@ -153,10 +153,11 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
generated from differentiation. This CSE prevention has costs because it
can foil other optimizations, and because it can incur high overheads on
some backends, especially GPU. The default is True because otherwise,
-under a ``jit`` or ``pmap``, CSE can defeat the purpose of this decorator.
-But in some settings, like when used inside a ``scan``, this CSE
-prevention mechanism is unnecessary, in which case ``prevent_cse`` can be
-set to False.
+under a :func:`~jax.jit` or :func:`~jax.pmap`, CSE can defeat the purpose
+of this decorator.
+But in some settings, like when used inside a :func:`~jax.lax.scan`, this
+CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` can
+be set to False.
static_argnums: Optional, int or sequence of ints, a keyword-only argument
indicating which argument values on which to specialize for tracing and
caching purposes. Specifying arguments as static can avoid
@@ -200,11 +201,11 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
-While ``jax.checkpoint`` controls what values are stored from the forward-pass
-to be used on the backward pass, the total amount of memory required to
-evaluate a function or its VJP depends on many additional internal details of
-that function. Those details include which numerical primitives are used,
-how they're composed, where jit and control flow primitives like scan
+While :func:`jax.checkpoint` controls what values are stored from the
+forward-pass to be used on the backward pass, the total amount of memory
+required to evaluate a function or its VJP depends on many additional internal
+details of that function. Those details include which numerical primitives are
+used, how they're composed, where jit and control flow primitives like scan
are used, and other factors.
The :func:`jax.checkpoint` decorator can be applied recursively to express
@@ -253,7 +254,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
As an alternative to using ``static_argnums`` (and
``jax.ensure_compile_time_eval``), it may be easier to compute some values
-outside the ``jax.checkpoint``-decorated function and then close over them.
+outside the :func:`jax.checkpoint`-decorated function and then close over them.
"""
@wraps(fun)
@api_boundary
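
A minimal sketch of the ``prevent_cse`` note above, assuming a rematerialized
body inside :func:`~jax.lax.scan`::

    from functools import partial
    import jax
    import jax.numpy as jnp

    # prevent_cse=False is safe here because the rematerialized body lives
    # inside a scan, where the CSE-prevention machinery is unnecessary.
    @partial(jax.checkpoint, prevent_cse=False)
    def body(carry, x):
        return jnp.sin(carry * x), carry

    final, ys = jax.lax.scan(body, 1.0, jnp.arange(1.0, 4.0))
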

View File

@@ -753,11 +753,11 @@ def disable_jit(disable: bool = True):
For debugging it is useful to have a mechanism that disables :py:func:`jit`
everywhere in a dynamic context. Note that this not only disables explicit
-uses of `jit` by the user, but will also remove any implicit JIT compilation
+uses of :func:`jit` by the user, but will also remove any implicit JIT compilation
used by the JAX library: this includes implicit JIT computation of `body` and
-`cond` functions passed to higher-level primitives like :func:`scan` and
-:func:`while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
-and any other case where `jit` is used within an API's implementation.
+`cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
+:func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
+and any other case where :func:`jit` is used within an API's implementation.
Values that have a data dependence on the arguments to a jitted function are
traced and abstracted. For example, an abstract value may be a
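
A minimal sketch of this behavior::

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        print("x is:", x)   # prints a Tracer while jit is active
        return jnp.sin(x)

    f(2.0)                  # x is: Traced<...>
    with jax.disable_jit():
        f(2.0)              # x is: 2.0 (a concrete value)
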
@@ -2671,7 +2671,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
"""Transpose a function that is promised to be linear.
-For linear functions, this transformation is equivalent to ``vjp``, but
+For linear functions, this transformation is equivalent to :py:func:`vjp`, but
avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes
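
A minimal sketch, assuming a scalar linear function::

    import jax

    def f(x):
        return 3.0 * x          # linear in x

    # The primal argument only supplies shape/dtype information.
    f_transpose = jax.linear_transpose(f, 1.0)
    print(f_transpose(1.0))     # (3.0,): one cotangent per primal input
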
@@ -3379,7 +3379,7 @@ def block_until_ready(x):
def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
*args: Any, **kwargs: Any):
"""Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.
"""Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
@@ -3388,12 +3388,12 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
-is safe to be called multiple times (e.g. when transformed by ``vmap`` or
-``pmap``), or not to be called at all when e.g. the output of a
+is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or
+:func:`~pmap`), or not to be called at all when e.g. the output of a
`jit`-decorated function has no data dependence on its value. Pure callbacks
may also be reordered if data-dependence allows.
-When ``pmap``-ed, the pure callback will be called several times (one on each
+When :func:`~pmap`-ed, the pure callback will be called several times (one on each
axis of the map). When `vmap`-ed the behavior will depend on the value of the
``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
is assumed to obey
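
A minimal sketch of the calling convention, assuming a NumPy function on the
host::

    import jax
    import jax.numpy as jnp
    import numpy as np

    def host_sin(x):
        return np.sin(x)        # runs on the host, on NumPy arrays

    @jax.jit
    def f(x):
        # result_shape_dtypes declares the output's shape/dtype up front.
        return jax.pure_callback(host_sin,
                                 jax.ShapeDtypeStruct(x.shape, x.dtype), x)

    print(f(jnp.arange(3.0)))
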

View File

@@ -373,7 +373,7 @@ def check(pred: Bool, msg: str) -> None:
"""Check a predicate, add an error with msg if predicate is False.
This is an effectful operation, and can't be staged (jitted/scanned/...).
-Before staging a function with checks, ``checkify`` it!
+Before staging a function with checks, :func:`~checkify` it!
Args:
pred: if False, an error is added.
@@ -407,7 +407,7 @@ def is_scalar_pred(pred) -> bool:
pred.dtype == jnp.dtype('bool'))
def check_error(error: Error) -> None:
"""Raise an Exception if ``error`` represents a failure. Functionalized by ``checkify``.
"""Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.
The semantics of this function are equivalent to:
@@ -415,34 +415,38 @@ def check_error(error: Error) -> None:
... err.throw() # can raise ValueError
But unlike that implementation, ``check_error`` can be functionalized using
-the ``checkify`` transformation.
+the :func:`~checkify` transformation.
-This function is similar to ``check`` but with a different signature: whereas
-``check`` takes as arguments a boolean predicate and a new error message
-string, this function takes an ``Error`` value as argument. Both ``check``
+This function is similar to :func:`~check` but with a different signature: whereas
+:func:`~check` takes as arguments a boolean predicate and a new error message
+string, this function takes an ``Error`` value as argument. Both :func:`~check`
and this function raise a Python Exception on failure (a side-effect), and
-thus cannot be staged out by ``jit``, ``pmap``, ``scan``, etc. Both also can
-be functionalized by using ``checkify``.
+thus cannot be staged out by :func:`~jax.jit`, :func:`~jax.pmap`,
+:func:`~jax.lax.scan`, etc. Both also can be functionalized by using
+:func:`~checkify`.
-But unlike ``check``, this function is like a direct inverse of ``checkify``:
-whereas ``checkify`` takes as input a function which can raise a Python
+But unlike :func:`~check`, this function is like a direct inverse of
+:func:`~checkify`: whereas :func:`~checkify` takes as input a function which
+can raise a Python
Exception and produces a new function without that effect but which produces
an ``Error`` value as output, this ``check_error`` function can accept an
``Error`` value as input and can produce the side-effect of raising an
-Exception. That is, while ``checkify`` goes from functionalizable Exception
+Exception. That is, while :func:`~checkify` goes from functionalizable
+Exception
effect to error value, this ``check_error`` goes from error value to
functionalizable Exception effect.
``check_error`` is useful when you want to turn checks represented by an
-``Error`` value (produced by functionalizing ``checks`` via ``checkify``)
-back into Python Exceptions.
+``Error`` value (produced by functionalizing ``checks`` via
+:func:`~checkify`) back into Python Exceptions.
Args:
error: Error to check.
For example, you might want to functionalize part of your program through
-checkify, stage out your functionalized code through ``jit``, then re-inject
-your error value outside of the ``jit``:
+checkify, stage out your functionalized code through :func:`~jax.jit`, then
+re-inject your error value outside of the :func:`~jax.jit`:
>>> import jax
>>> from jax.experimental import checkify
@@ -874,7 +878,7 @@ def checkify(fun: Callable[..., Out],
) -> Callable[..., Tuple[Error, Out]]:
"""Functionalize `check` calls in `fun`, and optionally add run-time error checks.
-Run-time errors are either user-added ``checkify.check`` assertions, or
+Run-time errors are either user-added :func:`~check` assertions, or
automatically added checks like NaN checks, depending on the ``errors``
argument.
@@ -884,11 +888,11 @@ def checkify(fun: Callable[..., Out],
will correspond to the first error which occurred. ``err.throw()`` will raise
a ValueError with the error message if an error occurred.
-By default only user-added ``checkify.check`` assertions are enabled. You can
+By default only user-added :func:`~check` assertions are enabled. You can
enable automatic checks through the ``errors`` argument.
The automatic check sets which can be enabled, and when an error is generated:
-- ``user_checks``: a ``checkify.check`` evaluated to False.
+- ``user_checks``: a :func:`~check` evaluated to False.
- ``nan_checks``: a floating-point operation generated a NaN value
as output.
- ``div_checks``: a division by zero.
@@ -899,7 +903,7 @@ def checkify(fun: Callable[..., Out],
re-combined (eg. ``errors=float_checks|user_checks``)
Args:
-fun: Callable which can contain user checks (see ``check``).
+fun: Callable which can contain user checks (see :func:`~check`).
errors: A set of ErrorCategory values which defines the set of enabled
checks. By default only explicit ``checks`` are enabled
(``user_checks``). You can also for example enable NAN and
@@ -909,7 +913,7 @@ def checkify(fun: Callable[..., Out],
Returns:
A function which accepts the same arguments as ``fun`` and returns as output
a pair where the first element is an ``Error`` value, representing the first
-failed ``check``, and the second element is the original output of ``fun``.
+failed :func:`~check`, and the second element is the original output of ``fun``.
For example:
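
A minimal sketch of that workflow, assuming the :func:`~check` and
:func:`~checkify` APIs described above::

    import jax
    import jax.numpy as jnp
    from jax.experimental import checkify

    def f(x):
        checkify.check(x > 0, "x must be positive!")
        return jnp.log(x)

    checked_f = checkify.checkify(f)
    err, out = jax.jit(checked_f)(-1.0)
    err.throw()   # raises ValueError: x must be positive!
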

View File

@@ -430,9 +430,9 @@ class UnexpectedTracerError(JAXTypeError):
in an outer scope, return that value from the transformed function explicitly.
Specifically, a ``Tracer`` is JAX's internal representation of a function's
-intermediate values during transformations, e.g. within ``jit``, ``pmap``,
-``vmap``, etc. Encountering a ``Tracer`` outside of a transformation implies a
-leak.
+intermediate values during transformations, e.g. within :func:`~jax.jit`,
+:func:`~jax.pmap`, :func:`~jax.vmap`, etc. Encountering a ``Tracer`` outside
+of a transformation implies a leak.
Life-cycle of a leaked value
Consider the following example of a transformed function which leaks a value
@@ -460,7 +460,7 @@ class UnexpectedTracerError(JAXTypeError):
This example also demonstrates the life-cycle of a leaked value:
-1. A function is transformed (in this case, by ``jit``)
+1. A function is transformed (in this case, by :func:`~jax.jit`)
2. The transformed function is called (initiating an abstract trace of the
function and turning ``x`` into a ``Tracer``)
3. The intermediate value ``y``, which will later be leaked, is created
@@ -473,7 +473,7 @@ class UnexpectedTracerError(JAXTypeError):
code by including information about each stage. Respectively:
1. The name of the transformed function (``side_effecting``) and which
-transform kicked of the trace (``jit``).
+transform kicked off the trace (:func:`~jax.jit`).
2. A reconstructed stack trace of where the leaked Tracer was created,
which includes where the transformed function was called.
(``When the Tracer was created, the final 5 stack frames were...``).
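
A minimal sketch of this life-cycle, mirroring the ``side_effecting`` function
the text refers to::

    import jax

    outs = []

    @jax.jit
    def side_effecting(x):
        y = x + 1
        outs.append(y)      # leaks the Tracer y into an outer scope
        return y

    side_effecting(1)
    outs[0] + 1             # raises UnexpectedTracerError here, not above
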

View File

@@ -128,7 +128,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
each with an additional leading axis.
When ``a`` is an array type or None, and ``b`` is an array type, the semantics
-of ``scan`` are given roughly by this Python implementation::
+of :func:`~scan` are given roughly by this Python implementation::
def scan(f, init, xs, length=None):
if xs is None:
@@ -144,10 +144,11 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
types, and so multiple arrays can be scanned over at once and produce multiple
output arrays. (None is actually an empty pytree.)
-Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
-a single XLA While HLO. That makes it useful for reducing compilation times
-for jit-compiled functions, since native Python loop constructs in an ``@jit``
-function are unrolled, leading to large XLA computations.
+Also unlike that Python version, :func:`~scan` is a JAX primitive and is
+lowered to a single XLA While HLO. That makes it useful for reducing
+compilation times for JIT-compiled functions, since native Python
+loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
+XLA computations.
Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
across all iterations (and not just be consistent up to NumPy rank/shape
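
A minimal concrete use of :func:`~scan`, a cumulative sum::

    import jax
    import jax.numpy as jnp

    def step(carry, x):
        carry = carry + x
        return carry, carry          # (new carry, per-step output)

    total, cumsum = jax.lax.scan(step, 0.0, jnp.arange(5.0))
    # total == 10.0, cumsum == [0., 1., 3., 6., 10.]
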
@@ -1632,7 +1633,7 @@ def fori_loop(lower, upper, body_fun, init_val):
call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the
trip count is static (meaning known at tracing time, perhaps because ``lower``
and ``upper`` are Python integer literals) then the ``fori_loop`` is
-implemented in terms of ``scan`` and reverse-mode autodiff is supported;
+implemented in terms of :func:`~scan` and reverse-mode autodiff is supported;
otherwise, a ``while_loop`` is used and reverse-mode autodiff is not
supported. See those functions' docstrings for more information.
@@ -1702,19 +1703,20 @@ def map(f, xs):
"""Map a function over leading array axes.
Like Python's builtin map, except inputs and outputs are in the form of
-stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
+stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
need to apply a function element by element for reduced memory usage or
heterogeneous computation with other control flow primitives.
-When ``xs`` is an array type, the semantics of ``map`` are given by this
+When ``xs`` is an array type, the semantics of :func:`~map` are given by this
Python implementation::
def map(f, xs):
return np.stack([f(x) for x in xs])
-Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
-the same advantages over a Python loop apply: ``xs`` may be an arbitrary
-nested pytree type, and the mapped computation is compiled only once.
+Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
+many of the same advantages over a Python loop apply: ``xs`` may be an
+arbitrary nested pytree type, and the mapped computation is compiled only
+once.
Args:
f: a Python function to apply element-wise over the first axis or axes of
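
A minimal sketch of :func:`~map` over a leading axis::

    import jax
    import jax.numpy as jnp

    xs = jnp.arange(6.0).reshape(3, 2)
    ys = jax.lax.map(lambda x: jnp.dot(x, x), xs)   # maps over axis 0
    # ys == [1., 13., 41.]
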

View File

@@ -209,7 +209,7 @@ def segment_sum(data: Array,
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_sum`` in a ``jit``-compiled function.
+must be provided to use ``segment_sum`` in a JIT-compiled function.
indices_are_sorted: whether ``segment_ids`` is known to be sorted.
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_sum`` is
@@ -265,7 +265,7 @@ def segment_prod(data: Array,
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_prod`` in a ``jit``-compiled function.
+must be provided to use ``segment_prod`` in a JIT-compiled function.
indices_are_sorted: whether ``segment_ids`` is known to be sorted.
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_prod`` is
@@ -321,7 +321,7 @@ def segment_max(data: Array,
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_max`` in a ``jit``-compiled function.
+must be provided to use ``segment_max`` in a JIT-compiled function.
indices_are_sorted: whether ``segment_ids`` is known to be sorted.
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_max`` is
@@ -376,7 +376,7 @@ def segment_min(data: Array,
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_min`` in a ``jit``-compiled function.
+must be provided to use ``segment_min`` in a JIT-compiled function.
indices_are_sorted: whether ``segment_ids`` is known to be sorted.
unique_indices: whether `segment_ids` is known to be free of duplicates.
bucket_size: size of bucket to group indices into. ``segment_min`` is
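
A minimal sketch of ``jax.ops.segment_sum`` with a static ``num_segments``
under :func:`~jax.jit`, per the requirement above::

    import jax
    import jax.numpy as jnp

    data = jnp.array([1.0, 2.0, 3.0, 4.0])
    segment_ids = jnp.array([0, 0, 1, 1])

    @jax.jit
    def f(data, segment_ids):
        # num_segments is a static Python int, so this works under jit.
        return jax.ops.segment_sum(data, segment_ids, num_segments=2)

    print(f(data, segment_ids))   # [3. 7.]
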

View File

@@ -84,14 +84,14 @@ def start_trace(log_dir, create_perfetto_link: bool = False,
"""Starts a profiler trace.
The trace will capture CPU, GPU, and/or TPU activity, including Python
-functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
+functions and JAX on-device operations. Use :func:`stop_trace` to end the trace
and save the results to ``log_dir``.
The resulting trace can be viewed with TensorBoard. Note that TensorBoard
doesn't need to be running when collecting the trace.
Only one trace may be collected at a time. A RuntimeError will be raised if
-``start_trace()`` is called while another trace is running.
+:func:`start_trace` is called while another trace is running.
Args:
log_dir: The directory to save the profiler trace to (usually the
@@ -188,7 +188,7 @@ def stop_trace():
"""Stops the currently-running profiler trace.
The trace will be saved to the ``log_dir`` passed to the corresponding
-``start_trace()`` call. Raises a RuntimeError if a trace hasn't been started.
+:func:`start_trace` call. Raises a RuntimeError if a trace hasn't been started.
"""
with _profile_state.lock:
if _profile_state.profile_session is None:
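
A minimal sketch of the intended usage; the trace directory here is
hypothetical::

    import jax
    import jax.numpy as jnp

    jax.profiler.start_trace("/tmp/jax-trace")    # hypothetical log_dir
    x = jnp.dot(jnp.ones((100, 100)), jnp.ones((100, 100)))
    x.block_until_ready()
    jax.profiler.stop_trace()                     # saves results to log_dir
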

View File

@@ -66,7 +66,7 @@ def minimize(
- Optimization results may differ from SciPy due to differences in the line
search implementation.
-``minimize`` supports ``jit`` compilation. It does not yet support
+``minimize`` supports :func:`~jax.jit` compilation. It does not yet support
differentiation or arguments in the form of multi-dimensional arrays, but
support for both is planned.
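
A minimal sketch of a JIT-compiled call, using the BFGS method::

    import jax
    import jax.numpy as jnp
    from jax.scipy.optimize import minimize

    def loss(x):
        return jnp.sum((x - 3.0) ** 2)

    res = jax.jit(lambda x0: minimize(loss, x0, method="BFGS"))(jnp.zeros(2))
    # res.x is approximately [3., 3.]
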

View File

@@ -483,7 +483,7 @@ class Lowered(Stage):
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
-lowering paths (``jit``, ``pmap``, etc.).
+lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
"""
__slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]
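
A minimal sketch of how a ``Lowered`` arises on the :func:`~jax.jit` path::

    import jax
    import jax.numpy as jnp

    lowered = jax.jit(lambda x: x * 2).lower(jnp.float32(1.0))  # a Lowered
    compiled = lowered.compile()                                # a Compiled
    print(compiled(jnp.float32(3.0)))                           # 6.0
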

View File

@@ -140,19 +140,19 @@ def api_boundary(fun: C) -> C:
stack trace of the original exception, but with JAX-internal frames removed.
This boundary annotation works in composition with itself. The topmost frame
-corresponding to an ``api_boundary`` is the one below which stack traces are
-filtered. In other words, if ``api_boundary(f)`` calls ``api_boundary(g)``,
-directly or indirectly, the filtered stack trace provided is the same as if
-``api_boundary(f)`` were to simply call ``g`` instead.
+corresponding to an :func:`~api_boundary` is the one below which stack traces
+are filtered. In other words, if ``api_boundary(f)`` calls
+``api_boundary(g)``, directly or indirectly, the filtered stack trace provided
+is the same as if ``api_boundary(f)`` were to simply call ``g`` instead.
This annotation is primarily useful in wrapping functions output by JAX's
transformations. For example, consider ``g = jax.jit(f)``. When ``g`` is
called, JAX's JIT compilation machinery is invoked, which in turn calls ``f``
in order to trace and translate it. If the function ``f`` raises an exception,
the stack unwinds through JAX's JIT internals up to the original call site of
-``g``. Because the function returned by ``jax.jit`` is annotated as an
-``api_boundary``, such an exception is accompanied by an additional traceback
-that excludes the frames specific to JAX's implementation.
+``g``. Because the function returned by :func:`~jax.jit` is annotated as an
+:func:`~api_boundary`, such an exception is accompanied by an additional
+traceback that excludes the frames specific to JAX's implementation.
'''
@util.wraps(fun)
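
A minimal sketch of the user-visible effect::

    import jax

    @jax.jit
    def f(x):
        raise ValueError("boom")    # raised while f is being traced

    try:
        f(1.0)
    except ValueError as err:
        # The traceback attached here is filtered: frames internal to
        # jax.jit's machinery are removed by the api_boundary annotation.
        print(err)
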

View File

@@ -1019,15 +1019,15 @@ def new_base_main(trace_type: Type[Trace],
def ensure_compile_time_eval():
"""Context manager to ensure evaluation at trace/compile time (or error).
-Some JAX APIs like ``jax.jit`` and ``jax.lax.scan`` involve staging, i.e.
-delaying the evaluation of numerical expressions (like jax.numpy function
-applications) so that instead of performing those computations eagerly while
-evaluating the corresponding Python expressions, their computation is carried
-out separately, e.g. after optimized compilation. But this delay can be
-undesirable. For example, numerical values might be needed to evaluate Python
-control flow and so their evaluation cannot be delayed. As another example, it
-may be beneficial to ensure compile time evaluation (or "constant folding")
-for performance reasons.
+Some JAX APIs like :func:`jax.jit` and :func:`jax.lax.scan` involve staging,
+i.e., delaying the evaluation of numerical expressions (like :mod:`jax.numpy`
+function applications) so that instead of performing those computations
+eagerly while evaluating the corresponding Python expressions, their
+computation is carried out separately, e.g. after optimized compilation. But
+this delay can be undesirable. For example, numerical values might be needed
+to evaluate Python control flow and so their evaluation cannot be delayed. As
+another example, it may be beneficial to ensure compile time evaluation (or
+"constant folding") for performance reasons.
This context manager ensures that JAX computations are evaluated eagerly. If
eager evaluation is not possible, a ``ConcretizationError`` is raised.
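
A minimal sketch in which a concrete value is needed to drive Python control
flow::

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        with jax.ensure_compile_time_eval():
            y = jnp.sin(3.0)        # evaluated eagerly, at trace time
            y_pos = y > 0           # a concrete boolean
        if y_pos:                   # ordinary Python control flow
            return x + 1
        else:
            return x - 1
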

View File

@@ -457,8 +457,8 @@ receiver, it may be hard to debug the calls. In particular, the stack trace
will not include the calling code. You can use the flag
``jax_host_callback_inline`` (or the environment variable
``JAX_HOST_CALLBACK_INLINE``) to ensure that the calls to the callbacks are
-inlined. This works only if the calls are outside a staging context (``jit``
-or a control-flow primitive).
+inlined. This works only if the calls are outside a staging context
+(:func:`~jax.jit` or a control-flow primitive).
The C++ `receiver
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_
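
A minimal sketch of setting the flag from Python; the tap function is
hypothetical::

    import jax
    from jax.experimental import host_callback as hcb

    # Can also be set via the JAX_HOST_CALLBACK_INLINE environment variable.
    jax.config.update("jax_host_callback_inline", True)

    def tap(arg, transforms):
        print("saw:", arg)          # illustrative host-side receiver

    x = hcb.id_tap(tap, 3.0)        # outside jit, the call is inlined
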

View File

@@ -232,22 +232,23 @@ def pjit(fun: Callable,
the output partitioning specified in ``out_axis_resources``. The resources
specified in those two arguments must refer to mesh axes, as defined by
the :py:func:`jax.experimental.maps.Mesh` context manager. Note that the mesh
-definition at ``pjit`` application time is ignored, and the returned function
+definition at :func:`~pjit` application time is ignored, and the returned function
will use the mesh definition available at each call site.
-Inputs to a pjit'd function will be automatically partitioned across devices
+Inputs to a :func:`~pjit`'d function will be automatically partitioned across devices
if they're not already correctly partitioned based on ``in_axis_resources``.
In some scenarios, ensuring that the inputs are already correctly pre-partitioned
-can increase performance. For example, if passing the output of one pjit'd function
-to another pjitd function (or the same pjitd function in a loop), make sure the
-relevant ``out_axis_resources`` match the corresponding ``in_axis_resources``.
+can increase performance. For example, if passing the output of one
+:func:`~pjit`'d function to another :func:`~pjit`'d function (or the same
+:func:`~pjit`'d function in a loop), make sure the relevant
+``out_axis_resources`` match the corresponding ``in_axis_resources``.
.. note::
**Multi-process platforms:** On multi-process platforms such as TPU pods,
-``pjit`` can be used to run computations across all available devices across
-processes. To achieve this, ``pjit`` is designed to be used in SPMD Python
+:func:`~pjit` can be used to run computations across all available devices across
+processes. To achieve this, :func:`~pjit` is designed to be used in SPMD Python
programs, where every process is running the same Python code such that all
-processes run the same pjit'd function in the same order.
+processes run the same :func:`~pjit`'d function in the same order.
When running in this configuration, the mesh should contain devices across
all processes. However, any input argument dimensions partitioned over
@@ -256,7 +257,7 @@ def pjit(fun: Callable,
mesh. ``fun`` will still be executed across *all* devices in the mesh,
including those from other processes, and will be given a global view of the
data spread across multiple processes as a single array. However, outside
-of ``pjit`` every process only "sees" its local piece of the input and output,
+of :func:`~pjit` every process only "sees" its local piece of the input and output,
corresponding to its local sub-mesh.
This means that each process's participating local devices must form a
@@ -264,7 +265,7 @@ def pjit(fun: Callable,
sub-mesh is one where all of its devices are adjacent within the global
mesh, and form a rectangular prism.
-The SPMD model also requires that the same multi-process ``pjit``'d
+The SPMD model also requires that the same multi-process :func:`~pjit`'d
functions must be run in the same order on all processes, but they can be
interspersed with arbitrary operations running in a single process.
@@ -320,7 +321,7 @@ def pjit(fun: Callable,
automatically partitioned by the mesh available at each call site.
For example, a convolution operator can be automatically partitioned over
-an arbitrary set of devices by a single ``pjit`` application:
+an arbitrary set of devices by a single :func:`~pjit` application:
>>> import jax
>>> import jax.numpy as jnp
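
A minimal sketch, assuming a one-dimensional mesh over however many devices
are available (the array size must divide evenly across them)::

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.maps import Mesh
    from jax.experimental.pjit import pjit, PartitionSpec

    devices = np.array(jax.devices())
    f = pjit(lambda a: a * 2,
             in_axis_resources=PartitionSpec("x"),
             out_axis_resources=PartitionSpec("x"))
    with Mesh(devices, ("x",)):
        out = f(jnp.arange(8.0))    # partitioned over the 'x' mesh axis
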

View File

@@ -14,7 +14,7 @@
"""Utilities for pseudo-random number generation.
-The ``jax.random`` package provides a number of routines for deterministic
+The :mod:`jax.random` package provides a number of routines for deterministic
generation of sequences of pseudorandom numbers.
Basic usage
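
A minimal sketch of the key-based workflow::

    import jax

    key = jax.random.PRNGKey(0)             # deterministic root key
    key, subkey = jax.random.split(key)     # never reuse a consumed key
    x = jax.random.normal(subkey, (3,))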