Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)

Add a number of missing function cross-references in the docs.

commit cd84eb10a6 (parent e9e014f432)
@@ -118,7 +118,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
 ) -> Callable:
 """Make ``fun`` recompute internal linearization points when differentiated.
 
-The :func:`jax.checkpoint` decorator, aliased to ``jax.remat``, provides a
+The :func:`jax.checkpoint` decorator, aliased to :func:`jax.remat`, provides a
 way to trade off computation time and memory cost in the context of automatic
 differentiation, especially with reverse-mode autodiff like :func:`jax.grad`
 and :func:`jax.vjp` but also with :func:`jax.linearize`.
@@ -153,10 +153,11 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
 generated from differentiation. This CSE prevention has costs because it
 can foil other optimizations, and because it can incur high overheads on
 some backends, especially GPU. The default is True because otherwise,
-under a ``jit`` or ``pmap``, CSE can defeat the purpose of this decorator.
-But in some settings, like when used inside a ``scan``, this CSE
-prevention mechanism is unnecessary, in which case ``prevent_cse`` can be
-set to False.
+under a :func:`~jax.jit` or :func:`~jax.pmap`, CSE can defeat the purpose
+of this decorator.
+But in some settings, like when used inside a :func:`~jax.lax.scan`, this
+CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` can
+be set to False.
 static_argnums: Optional, int or sequence of ints, a keyword-only argument
 indicating which argument values on which to specialize for tracing and
 caching purposes. Specifying arguments as static can avoid
@@ -200,11 +201,11 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
 At that time, the value ``jnp.sin(2.0)`` is recomputed, along with the values
 ``jnp.cos(2.0)`` and ``jnp.cos(jnp.sin(2.0))``.
 
-While ``jax.checkpoint`` controls what values are stored from the forward-pass
-to be used on the backward pass, the total amount of memory required to
-evaluate a function or its VJP depends on many additional internal details of
-that function. Those details include which numerical primitives are used,
-how they're composed, where jit and control flow primitives like scan
+While :func:`jax.checkpoint` controls what values are stored from the
+forward-pass to be used on the backward pass, the total amount of memory
+required to evaluate a function or its VJP depends on many additional internal
+details of that function. Those details include which numerical primitives are
+used, how they're composed, where jit and control flow primitives like scan
 are used, and other factors.
 
 The :func:`jax.checkpoint` decorator can be applied recursively to express
@@ -253,7 +254,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
 
 As an alternative to using ``static_argnums`` (and
 ``jax.ensure_compile_time_eval``), it may be easier to compute some values
-outside the ``jax.checkpoint``-decorated function and then close over them.
+outside the :func:`jax.checkpoint`-decorated function and then close over them.
 """
 @wraps(fun)
 @api_boundary
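For context, a minimal sketch of the decorator these hunks document (illustrative usage, not taken from the patch)::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> @jax.checkpoint
    ... def g(x):
    ...   return jnp.sin(jnp.sin(x))
    >>> y = jax.grad(g)(2.0)  # jnp.sin(2.0) is recomputed on the backward pass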
@@ -753,11 +753,11 @@ def disable_jit(disable: bool = True):
 
 For debugging it is useful to have a mechanism that disables :py:func:`jit`
 everywhere in a dynamic context. Note that this not only disables explicit
-uses of `jit` by the user, but will also remove any implicit JIT compilation
+uses of :func:`jit` by the user, but will also remove any implicit JIT compilation
 used by the JAX library: this includes implicit JIT computation of `body` and
-`cond` functions passed to higher-level primitives like :func:`scan` and
-:func:`while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
-and any other case where `jit` is used within an API's implementation.
+`cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
+:func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
+and any other case where :func:`jit` is used within an API's implementation.
 
 Values that have a data dependence on the arguments to a jitted function are
 traced and abstracted. For example, an abstract value may be a
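A minimal sketch of the context manager this hunk documents (illustrative)::

    >>> import jax
    >>> with jax.disable_jit():
    ...   y = jax.jit(lambda x: x + 1)(1.0)  # runs eagerly; prints and pdb work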
@@ -2671,7 +2671,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
 def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
 """Transpose a function that is promised to be linear.
 
-For linear functions, this transformation is equivalent to ``vjp``, but
+For linear functions, this transformation is equivalent to :py:func:`vjp`, but
 avoids the overhead of computing the forward pass.
 
 The outputs of the transposed function will always have the exact same dtypes
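A minimal sketch of the transform this hunk documents (illustrative; the primals are used only for their shapes and dtypes)::

    >>> import jax
    >>> f = lambda x: 3. * x
    >>> f_transpose = jax.linear_transpose(f, 1.0)
    >>> (ct,) = f_transpose(4.0)  # ct == 12.0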
@@ -3379,7 +3379,7 @@ def block_until_ready(x):
 
 def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
 *args: Any, **kwargs: Any):
-"""Applies a functionally pure Python callable. Works under `jit`/`pmap`/etc.
+"""Applies a functionally pure Python callable. Works under :func:`jit`/:func:`~pmap`/etc.
 
 ``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
 The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
@@ -3388,12 +3388,12 @@ def pure_callback(callback: Callable[..., Any], result_shape_dtypes: Any,
 
 The callback is treated as functionally pure, meaning it has no side-effects
 and its output value depends only on its argument values. As a consequence, it
-is safe to be called multiple times (e.g. when transformed by ``vmap`` or
-``pmap``), or not to be called at all when e.g. the output of a
+is safe to be called multiple times (e.g. when transformed by :func:`~vmap` or
+:func:`~pmap`), or not to be called at all when e.g. the output of a
 `jit`-decorated function has no data dependence on its value. Pure callbacks
 may also be reordered if data-dependence allows.
 
-When ``pmap``-ed, the pure callback will be called several times (one on each
+When :func:`~pmap`-ed, the pure callback will be called several times (one on each
 axis of the map). When `vmap`-ed the behavior will depend on the value of the
 ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
 is assumed to obey
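A minimal sketch of the API this hunk documents (illustrative; ``host_sin`` is a hypothetical host-side function)::

    >>> import numpy as np
    >>> import jax
    >>> import jax.numpy as jnp
    >>> def host_sin(x):
    ...   return np.sin(x)  # receives NumPy arrays, runs outside the JIT
    >>> @jax.jit
    ... def f(x):
    ...   shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    ...   return jax.pure_callback(host_sin, shape, x)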
@@ -373,7 +373,7 @@ def check(pred: Bool, msg: str) -> None:
 """Check a predicate, add an error with msg if predicate is False.
 
 This is an effectful operation, and can't be staged (jitted/scanned/...).
-Before staging a function with checks, ``checkify`` it!
+Before staging a function with checks, :func:`~checkify` it!
 
 Args:
 pred: if False, an error is added.
@@ -407,7 +407,7 @@ def is_scalar_pred(pred) -> bool:
 pred.dtype == jnp.dtype('bool'))
 
 def check_error(error: Error) -> None:
-"""Raise an Exception if ``error`` represents a failure. Functionalized by ``checkify``.
+"""Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.
 
 The semantics of this function are equivalent to:
 
@@ -415,34 +415,38 @@ def check_error(error: Error) -> None:
 ... err.throw() # can raise ValueError
 
 But unlike that implementation, ``check_error`` can be functionalized using
-the ``checkify`` transformation.
+the :func:`~checkify` transformation.
 
-This function is similar to ``check`` but with a different signature: whereas
-``check`` takes as arguments a boolean predicate and a new error message
-string, this function takes an ``Error`` value as argument. Both ``check``
+This function is similar to :func:`~check` but with a different signature: whereas
+:func:`~check` takes as arguments a boolean predicate and a new error message
+string, this function takes an ``Error`` value as argument. Both :func:`~check`
 and this function raise a Python Exception on failure (a side-effect), and
-thus cannot be staged out by ``jit``, ``pmap``, ``scan``, etc. Both also can
-be functionalized by using ``checkify``.
+thus cannot be staged out by :func:`~jax.jit`, :func:`~jax.pmap`,
+:func:`~jax.lax.scan`, etc. Both also can
+be functionalized by using :func:`~checkify`.
 
-But unlike ``check``, this function is like a direct inverse of ``checkify``:
-whereas ``checkify`` takes as input a function which can raise a Python
+But unlike :func:`~check`, this function is like a direct inverse of
+:func:`~checkify`:
+whereas :func:`~checkify` takes as input a function which
+can raise a Python
 Exception and produces a new function without that effect but which produces
 an ``Error`` value as output, this ``check_error`` function can accept an
 ``Error`` value as input and can produce the side-effect of raising an
-Exception. That is, while ``checkify`` goes from functionalizable Exception
+Exception. That is, while :func:`~checkify` goes from
+functionalizable Exception
 effect to error value, this ``check_error`` goes from error value to
 functionalizable Exception effect.
 
 ``check_error`` is useful when you want to turn checks represented by an
-``Error`` value (produced by functionalizing ``checks`` via ``checkify``)
-back into Python Exceptions.
+``Error`` value (produced by functionalizing ``checks`` via
+:func:`~checkify`) back into Python Exceptions.
 
 Args:
 error: Error to check.
 
 For example, you might want to functionalize part of your program through
-checkify, stage out your functionalized code through ``jit``, then re-inject
-your error value outside of the ``jit``:
+checkify, stage out your functionalized code through :func:`~jax.jit`, then
+re-inject your error value outside of the :func:`~jax.jit`:
 
 >>> import jax
 >>> from jax.experimental import checkify
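The doctest is truncated in this view; a minimal sketch of the round-trip it describes (illustrative)::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x > 0, "x must be positive!")
    ...   return jnp.log(x)
    >>> checked_f = jax.jit(checkify.checkify(f))
    >>> err, out = checked_f(jnp.float32(2.))
    >>> checkify.check_error(err)  # outside the jit; raises if the check failed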
@@ -874,7 +878,7 @@ def checkify(fun: Callable[..., Out],
 ) -> Callable[..., Tuple[Error, Out]]:
 """Functionalize `check` calls in `fun`, and optionally add run-time error checks.
 
-Run-time errors are either user-added ``checkify.check`` assertions, or
+Run-time errors are either user-added :func:`~check` assertions, or
 automatically added checks like NaN checks, depending on the ``errors``
 argument.
 
@@ -884,11 +888,11 @@ def checkify(fun: Callable[..., Out],
 will correspond to the first error which occurred. ``err.throw()`` will raise
 a ValueError with the error message if an error occurred.
 
-By default only user-added ``checkify.check`` assertions are enabled. You can
+By default only user-added :func:`~check` assertions are enabled. You can
 enable automatic checks through the ``errors`` argument.
 
 The automatic check sets which can be enabled, and when an error is generated:
-- ``user_checks``: a ``checkify.check`` evaluated to False.
+- ``user_checks``: a :func:`~check` evaluated to False.
 - ``nan_checks``: a floating-point operation generated a NaN value
 as output.
 - ``div_checks``: a division by zero.
@@ -899,7 +903,7 @@ def checkify(fun: Callable[..., Out],
 re-combined (eg. ``errors=float_checks|user_checks``)
 
 Args:
-fun: Callable which can contain user checks (see ``check``).
+fun: Callable which can contain user checks (see :func:`~check`).
 errors: A set of ErrorCategory values which defines the set of enabled
 checks. By default only explicit ``checks`` are enabled
 (``user_checks``). You can also for example enable NAN and
@@ -909,7 +913,7 @@ def checkify(fun: Callable[..., Out],
 Returns:
 A function which accepts the same arguments as ``fun`` and returns as output
 a pair where the first element is an ``Error`` value, representing the first
-failed ``check``, and the second element is the original output of ``fun``.
+failed :func:`~check`, and the second element is the original output of ``fun``.
 
 For example:
 
@@ -430,9 +430,9 @@ class UnexpectedTracerError(JAXTypeError):
 in an outer scope, return that value from the transformed function explictly.
 
 Specifically, a ``Tracer`` is JAX's internal representation of a function's
-intermediate values during transformations, e.g. within ``jit``, ``pmap``,
-``vmap``, etc. Encountering a ``Tracer`` outside of a transformation implies a
-leak.
+intermediate values during transformations, e.g. within :func:`~jax.jit`,
+:func:`~jax.pmap`, :func:`~jax.vmap`, etc. Encountering a ``Tracer`` outside
+of a transformation implies a leak.
 
 Life-cycle of a leaked value
 Consider the following example of a transformed function which leaks a value
@@ -460,7 +460,7 @@ class UnexpectedTracerError(JAXTypeError):
 
 This example also demonstrates the life-cycle of a leaked value:
 
-1. A function is transformed (in this case, by ``jit``)
+1. A function is transformed (in this case, by :func:`~jax.jit`)
 2. The transformed function is called (initiating an abstract trace of the
 function and turning ``x`` into a ``Tracer``)
 3. The intermediate value ``y``, which will later be leaked, is created
@@ -473,7 +473,7 @@ class UnexpectedTracerError(JAXTypeError):
 code by including information about each stage. Respectively:
 
 1. The name of the transformed function (``side_effecting``) and which
-transform kicked of the trace (``jit``).
+transform kicked of the trace :func:`~jax.jit`).
 2. A reconstructed stack trace of where the leaked Tracer was created,
 which includes where the transformed function was called.
 (``When the Tracer was created, the final 5 stack frames were...``).
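A minimal sketch of the kind of leak these hunks describe (illustrative)::

    >>> import jax
    >>> leaked = []
    >>> @jax.jit
    ... def side_effecting(x):
    ...   leaked.append(x)  # x is a Tracer here; storing it leaks it
    ...   return x + 1
    >>> _ = side_effecting(1.0)
    >>> # leaked[0] + 1  # would raise UnexpectedTracerError outside the trace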
@@ -128,7 +128,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
 each with an additional leading axis.
 
 When ``a`` is an array type or None, and ``b`` is an array type, the semantics
-of ``scan`` are given roughly by this Python implementation::
+of :func:`~scan` are given roughly by this Python implementation::
 
 def scan(f, init, xs, length=None):
 if xs is None:
@@ -144,10 +144,11 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
 types, and so multiple arrays can be scanned over at once and produce multiple
 output arrays. (None is actually an empty pytree.)
 
-Also unlike that Python version, ``scan`` is a JAX primitive and is lowered to
-a single XLA While HLO. That makes it useful for reducing compilation times
-for jit-compiled functions, since native Python loop constructs in an ``@jit``
-function are unrolled, leading to large XLA computations.
+Also unlike that Python version, :func:`~scan` is a JAX primitive and is
+lowered to a single XLA While HLO. That makes it useful for reducing
+compilation times for JIT-compiled functions, since native Python
+loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
+XLA computations.
 
 Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype
 across all iterations (and not just be consistent up to NumPy rank/shape
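A minimal sketch of :func:`jax.lax.scan` in action (illustrative)::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> def cumsum_step(carry, x):
    ...   carry = carry + x
    ...   return carry, carry  # (new carry, per-step output)
    >>> total, cumsum = jax.lax.scan(cumsum_step, 0.0, jnp.arange(4.0))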
@@ -1632,7 +1633,7 @@ def fori_loop(lower, upper, body_fun, init_val):
 call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the
 trip count is static (meaning known at tracing time, perhaps because ``lower``
 and ``upper`` are Python integer literals) then the ``fori_loop`` is
-implemented in terms of ``scan`` and reverse-mode autodiff is supported;
+implemented in terms of :func:`~scan` and reverse-mode autodiff is supported;
 otherwise, a ``while_loop`` is used and reverse-mode autodiff is not
 supported. See those functions' docstrings for more information.
 
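A minimal sketch of :func:`jax.lax.fori_loop` (illustrative)::

    >>> import jax
    >>> total = jax.lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)  # sums 0..9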
@@ -1702,19 +1703,20 @@ def map(f, xs):
 """Map a function over leading array axes.
 
 Like Python's builtin map, except inputs and outputs are in the form of
-stacked arrays. Consider using the ``jax.vmap`` transform instead, unless you
+stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
 need to apply a function element by element for reduced memory usage or
 heterogeneous computation with other control flow primitives.
 
-When ``xs`` is an array type, the semantics of ``map`` are given by this
+When ``xs`` is an array type, the semantics of :func:`~map` are given by this
 Python implementation::
 
 def map(f, xs):
 return np.stack([f(x) for x in xs])
 
-Like ``scan``, ``map`` is implemented in terms of JAX primitives so many of
-the same advantages over a Python loop apply: ``xs`` may be an arbitrary
-nested pytree type, and the mapped computation is compiled only once.
+Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
+many of the same advantages over a Python loop apply: ``xs`` may be an
+arbitrary nested pytree type, and the mapped computation is compiled only
+once.
 
 Args:
 f: a Python function to apply element-wise over the first axis or axes of
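A minimal sketch of :func:`jax.lax.map` (illustrative)::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> xs = jnp.arange(6.0).reshape(3, 2)
    >>> ys = jax.lax.map(jnp.sum, xs)  # like jnp.stack([jnp.sum(x) for x in xs])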
@@ -209,7 +209,7 @@ def segment_sum(data: Array,
 would support all indices in ``segment_ids``, calculated as
 ``max(segment_ids) + 1``.
 Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_sum`` in a ``jit``-compiled function.
+must be provided to use ``segment_sum`` in a JIT-compiled function.
 indices_are_sorted: whether ``segment_ids`` is known to be sorted.
 unique_indices: whether `segment_ids` is known to be free of duplicates.
 bucket_size: size of bucket to group indices into. ``segment_sum`` is
@@ -265,7 +265,7 @@ def segment_prod(data: Array,
 would support all indices in ``segment_ids``, calculated as
 ``max(segment_ids) + 1``.
 Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_prod`` in a ``jit``-compiled function.
+must be provided to use ``segment_prod`` in a JIT-compiled function.
 indices_are_sorted: whether ``segment_ids`` is known to be sorted.
 unique_indices: whether `segment_ids` is known to be free of duplicates.
 bucket_size: size of bucket to group indices into. ``segment_prod`` is
@@ -321,7 +321,7 @@ def segment_max(data: Array,
 would support all indices in ``segment_ids``, calculated as
 ``max(segment_ids) + 1``.
 Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_max`` in a ``jit``-compiled function.
+must be provided to use ``segment_max`` in a JIT-compiled function.
 indices_are_sorted: whether ``segment_ids`` is known to be sorted.
 unique_indices: whether `segment_ids` is known to be free of duplicates.
 bucket_size: size of bucket to group indices into. ``segment_max`` is
@@ -376,7 +376,7 @@ def segment_min(data: Array,
 would support all indices in ``segment_ids``, calculated as
 ``max(segment_ids) + 1``.
 Since `num_segments` determines the size of the output, a static value
-must be provided to use ``segment_min`` in a ``jit``-compiled function.
+must be provided to use ``segment_min`` in a JIT-compiled function.
 indices_are_sorted: whether ``segment_ids`` is known to be sorted.
 unique_indices: whether `segment_ids` is known to be free of duplicates.
 bucket_size: size of bucket to group indices into. ``segment_min`` is
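A minimal sketch of the segment ops these hunks touch (illustrative; note the static ``num_segments``, as the hunks explain)::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> data = jnp.array([1., 2., 3., 4.])
    >>> segment_ids = jnp.array([0, 0, 1, 1])
    >>> sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)  # [3., 7.]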
@@ -84,14 +84,14 @@ def start_trace(log_dir, create_perfetto_link: bool = False,
 """Starts a profiler trace.
 
 The trace will capture CPU, GPU, and/or TPU activity, including Python
-functions and JAX on-device operations. Use ``stop_trace()`` to end the trace
+functions and JAX on-device operations. Use :func:`stop_trace` to end the trace
 and save the results to ``log_dir``.
 
 The resulting trace can be viewed with TensorBoard. Note that TensorBoard
 doesn't need to be running when collecting the trace.
 
 Only once trace may be collected a time. A RuntimeError will be raised if
-``start_trace()`` is called while another trace is running.
+:func:`start_trace` is called while another trace is running.
 
 Args:
 log_dir: The directory to save the profiler trace to (usually the
@@ -188,7 +188,7 @@ def stop_trace():
 """Stops the currently-running profiler trace.
 
 The trace will be saved to the ``log_dir`` passed to the corresponding
-``start_trace()`` call. Raises a RuntimeError if a trace hasn't been started.
+:func:`start_trace` call. Raises a RuntimeError if a trace hasn't been started.
 """
 with _profile_state.lock:
 if _profile_state.profile_session is None:
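A minimal sketch of the profiler pair these hunks document (the log directory is illustrative)::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> jax.profiler.start_trace("/tmp/jax-trace")  # doctest: +SKIP
    >>> jnp.dot(jnp.ones((100, 100)), jnp.ones((100, 100))).block_until_ready()  # doctest: +SKIP
    >>> jax.profiler.stop_trace()  # doctest: +SKIP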
@@ -66,7 +66,7 @@ def minimize(
 - Optimization results may differ from SciPy due to differences in the line
 search implementation.
 
-``minimize`` supports ``jit`` compilation. It does not yet support
+``minimize`` supports :func:`~jax.jit` compilation. It does not yet support
 differentiation or arguments in the form of multi-dimensional arrays, but
 support for both is planned.
 
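A minimal sketch of the optimizer this hunk documents (illustrative)::

    >>> import jax.numpy as jnp
    >>> from jax.scipy.optimize import minimize
    >>> res = minimize(lambda x: jnp.sum(x ** 2), jnp.array([1., 2.]), method="BFGS")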
@@ -483,7 +483,7 @@ class Lowered(Stage):
 carries a lowering together with the remaining information needed to
 later compile and execute it. It also provides a common API for
 querying properties of lowered computations across JAX's various
-lowering paths (``jit``, ``pmap``, etc.).
+lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
 """
 __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]
 
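A minimal sketch of how a ``Lowered`` object is obtained (assuming the ``.lower()``/``.compile()`` staging API)::

    >>> import jax
    >>> lowered = jax.jit(lambda x: x + 1).lower(1.0)  # a Lowered
    >>> compiled = lowered.compile()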
@@ -140,19 +140,19 @@ def api_boundary(fun: C) -> C:
 stack trace of the original exception, but with JAX-internal frames removed.
 
 This boundary annotation works in composition with itself. The topmost frame
-corresponding to an ``api_boundary`` is the one below which stack traces are
-filtered. In other words, if ``api_boundary(f)`` calls ``api_boundary(g)``,
-directly or indirectly, the filtered stack trace provided is the same as if
-``api_boundary(f)`` were to simply call ``g`` instead.
+corresponding to an :func:`~api_boundary` is the one below which stack traces
+are filtered. In other words, if ``api_boundary(f)`` calls
+``api_boundary(g)``, directly or indirectly, the filtered stack trace provided
+is the same as if ``api_boundary(f)`` were to simply call ``g`` instead.
 
 This annotation is primarily useful in wrapping functions output by JAX's
 transformations. For example, consider ``g = jax.jit(f)``. When ``g`` is
 called, JAX's JIT compilation machinery is invoked, which in turn calls ``f``
 in order to trace and translate it. If the function ``f`` raises an exception,
 the stack unwinds through JAX's JIT internals up to the original call site of
-``g``. Because the function returned by ``jax.jit`` is annotated as an
-``api_boundary``, such an exception is accompanied by an additional traceback
-that excludes the frames specific to JAX's implementation.
+``g``. Because the function returned by :func:`~jax.jit` is annotated as an
+:func:`~api_boundary`, such an exception is accompanied by an additional
+traceback that excludes the frames specific to JAX's implementation.
 '''
 
 @util.wraps(fun)
jax/core.py (18 changed lines)
@@ -1019,15 +1019,15 @@ def new_base_main(trace_type: Type[Trace],
 def ensure_compile_time_eval():
 """Context manager to ensure evaluation at trace/compile time (or error).
 
-Some JAX APIs like ``jax.jit`` and ``jax.lax.scan`` involve staging, i.e.
-delaying the evaluation of numerical expressions (like jax.numpy function
-applications) so that instead of performing those computations eagerly while
-evaluating the corresponding Python expressions, their computation is carried
-out separately, e.g. after optimized compilation. But this delay can be
-undesirable. For example, numerical values might be needed to evaluate Python
-control flow and so their evaluation cannot be delayed. As another example, it
-may be beneficial to ensure compile time evaluation (or "constant folding")
-for performance reasons.
+Some JAX APIs like :func:`jax.jit`` and :func:`jax.lax.scan` involve staging,
+i.e., delaying the evaluation of numerical expressions (like :mod:`jax.numpy`
+function applications) so that instead of performing those computations
+eagerly while evaluating the corresponding Python expressions, their
+computation is carried out separately, e.g. after optimized compilation. But
+this delay can be undesirable. For example, numerical values might be needed
+to evaluate Python control flow and so their evaluation cannot be delayed. As
+another example, it may be beneficial to ensure compile time evaluation (or
+"constant folding") for performance reasons.
 
 This context manager ensures that JAX computations are evaluated eagerly. If
 eager evaluation is not possible, a ``ConcretizationError`` is raised.
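A minimal sketch of the context manager this hunk documents (illustrative)::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> @jax.jit
    ... def f(x):
    ...   with jax.ensure_compile_time_eval():
    ...     y = jnp.sin(3.0)  # evaluated eagerly, at trace time
    ...   return x + y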
@@ -457,8 +457,8 @@ receiver, it may be hard to debug the calls. In particular, the stack trace
 will not include the calling code. You can use the flag
 ``jax_host_callback_inline`` (or the environment variable
 ``JAX_HOST_CALLBACK_INLINE``) to ensure that the calls to the callbacks are
-inlined. This works only if the calls are outside a staging context (``jit``
-or a control-flow primitive).
+inlined. This works only if the calls are outside a staging context
+(:func:`~jax.jit` or a control-flow primitive).
 
 The C++ `receiver
 <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_
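A minimal sketch of setting the flag this hunk mentions (an assumption that the flag is settable through ``jax.config.update``, as JAX flags generally are)::

    >>> import jax
    >>> jax.config.update("jax_host_callback_inline", True)  # doctest: +SKIP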
@@ -232,22 +232,23 @@ def pjit(fun: Callable,
 the output partitioning specified in ``out_axis_resources``. The resources
 specified in those two arguments must refer to mesh axes, as defined by
 the :py:func:`jax.experimental.maps.Mesh` context manager. Note that the mesh
-definition at ``pjit`` application time is ignored, and the returned function
+definition at :func:`~pjit` application time is ignored, and the returned function
 will use the mesh definition available at each call site.
 
-Inputs to a pjit'd function will be automatically partitioned across devices
+Inputs to a :func:`~pjit`'d function will be automatically partitioned across devices
 if they're not already correctly partitioned based on ``in_axis_resources``.
 In some scenarios, ensuring that the inputs are already correctly pre-partitioned
-can increase performance. For example, if passing the output of one pjit'd function
-to another pjit’d function (or the same pjit’d function in a loop), make sure the
-relevant ``out_axis_resources`` match the corresponding ``in_axis_resources``.
+can increase performance. For example, if passing the output of one
+:func:`~pjit`'d function to another :func:`~pjit`’d function (or the same
+:func:`~pjit`’d function in a loop), make sure the relevant
+``out_axis_resources`` match the corresponding ``in_axis_resources``.
 
 .. note::
 **Multi-process platforms:** On multi-process platforms such as TPU pods,
-``pjit`` can be used to run computations across all available devices across
-processes. To achieve this, ``pjit`` is designed to be used in SPMD Python
+:func:`~pjit` can be used to run computations across all available devices across
+processes. To achieve this, :func:`~pjit` is designed to be used in SPMD Python
 programs, where every process is running the same Python code such that all
-processes run the same pjit'd function in the same order.
+processes run the same :func:`~pjit`'d function in the same order.
 
 When running in this configuration, the mesh should contain devices across
 all processes. However, any input argument dimensions partitioned over
@@ -256,7 +257,7 @@ def pjit(fun: Callable,
 mesh. ``fun`` will still be executed across *all* devices in the mesh,
 including those from other processes, and will be given a global view of the
 data spread across multiple processes as a single array. However, outside
-of ``pjit`` every process only "sees" its local piece of the input and output,
+of :func:`~pjit` every process only "sees" its local piece of the input and output,
 corresponding to its local sub-mesh.
 
 This means that each process's participating local devices must form a
@@ -264,7 +265,7 @@ def pjit(fun: Callable,
 sub-mesh is one where all of its devices are adjacent within the global
 mesh, and form a rectangular prism.
 
-The SPMD model also requires that the same multi-process ``pjit``'d
+The SPMD model also requires that the same multi-process :func:`~pjit`'d
 functions must be run in the same order on all processes, but they can be
 interspersed with arbitrary operations running in a single process.
 
@@ -320,7 +321,7 @@ def pjit(fun: Callable,
 automaticly partitioned by the mesh available at each call site.
 
 For example, a convolution operator can be automatically partitioned over
-an arbitrary set of devices by a single ``pjit`` application:
+an arbitrary set of devices by a single :func:`~pjit` application:
 
 >>> import jax
 >>> import jax.numpy as jnp
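The doctest is truncated in this view; a minimal sketch of the pattern the docstring describes (illustrative; assumes the era's ``jax.experimental`` pjit API and 4 available devices)::

    >>> import numpy as np
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import PartitionSpec as P
    >>> from jax.experimental.pjit import pjit
    >>> from jax.experimental.maps import Mesh
    >>> f = pjit(lambda x: 2 * x,
    ...          in_axis_resources=P('a', 'b'), out_axis_resources=P('a', 'b'))
    >>> devices = np.array(jax.devices()).reshape(2, 2)  # doctest: +SKIP
    >>> with Mesh(devices, ('a', 'b')):  # the mesh is taken from the call site
    ...   y = f(jnp.ones((8, 8)))  # doctest: +SKIP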
@@ -14,7 +14,7 @@
 
 """Utilities for pseudo-random number generation.
 
-The ``jax.random`` package provides a number of routines for deterministic
+The :mod:`jax.random` package provides a number of routines for deterministic
 generation of sequences of pseudorandom numbers.
 
 Basic usage
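A minimal sketch of the basic usage the docstring goes on to describe (illustrative)::

    >>> from jax import random
    >>> key = random.PRNGKey(0)
    >>> key, subkey = random.split(key)
    >>> x = random.normal(subkey, (3,))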