Add a docstring for pjit

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
PiperOrigin-RevId: 372393461
Adam Paszke 2021-05-06 12:00:18 -07:00 committed by jax authors
parent ee94354589
commit ee93ee221c
4 changed files with 120 additions and 1 deletion


@@ -0,0 +1,9 @@
jax.experimental.pjit module
============================

.. automodule:: jax.experimental.pjit

API
---

.. autofunction:: pjit


@@ -10,6 +10,7 @@ jax.experimental package
jax.experimental.host_callback
jax.experimental.loops
jax.experimental.maps
jax.experimental.pjit
jax.experimental.optimizers
jax.experimental.stax


@@ -156,7 +156,7 @@ def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
"""
old_env = getattr(thread_resources, "env", None)
-thread_resources.env = ResourceEnv(Mesh(devices, axis_names))
+thread_resources.env = ResourceEnv(Mesh(np.asarray(devices, dtype=object), axis_names))
try:
yield
finally:
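
A minimal sketch of what this coercion permits, assuming a single-process
setup: a plain Python list of devices is now accepted, and ``dtype=object``
ensures the resulting array holds the Device objects as-is.

import jax
from jax.experimental.maps import mesh

with mesh(jax.devices(), ('d',)):
    pass  # computations here see a one-dimensional resource mesh named 'd'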


@@ -46,6 +46,115 @@ def pjit(fun: Callable,
out_axis_resources,
static_argnums: Union[int, Sequence[int]] = (),
donate_argnums: Union[int, Sequence[int]] = ()):
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
The returned function has semantics equivalent to those of ``fun``, but is
compiled to an XLA computation that runs across multiple devices
(e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted
version of ``fun`` would not fit in a single device's memory, or to speed up
``fun`` by running each operation in parallel across multiple devices.
The partitioning over devices happens automatically based on
propagation of input partitioning specified in ``in_axis_resources`` and
output partitioning specified in ``out_axis_resources``. The resources
specified in those two arguments must refer to mesh axes, as defined by
the :py:func:`jax.experimental.maps.mesh` context manager. Note that the mesh
definition at ``pjit`` application time is ignored, and the returned function
will use the mesh definition available at each call site.

Inputs to a pjit'd function will be automatically partitioned across devices
if they're not already correctly partitioned based on ``in_axis_resources``.
In some scenarios, ensuring that the inputs are already correctly
pre-partitioned can increase performance. For example, if passing the output
of one pjit'd function to another pjit'd function (or the same pjit'd
function in a loop), make sure the relevant ``out_axis_resources`` match the
corresponding ``in_axis_resources``.

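For instance, a minimal sketch, assuming a mesh with a single axis named
``'x'`` is active at the call site and ``xs`` is a hypothetical, suitably
sized input array. The inner call's output layout already matches the outer
call's input spec, so no resharding happens between the two calls:

>>> from jax.experimental.pjit import PartitionSpec, pjit
>>> spec = PartitionSpec('x')
>>> double = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)
>>> ys = double(double(xs))  # doctest: +SKIP
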
.. note::
**Multi-process platforms:** On multi-process platforms such as TPU pods,
``pjit`` can be used to run computations across all available devices across
processes. To achieve this, ``pjit`` is designed to be used in SPMD Python
programs, where every process is running the same Python code such that all
processes run the same pjit'd function in the same order.

When running in this configuration, the mesh should contain devices across
all processes. However, any input argument dimensions partitioned over
multi-process mesh axes should be of size equal to the corresponding *local*
mesh axis size, and outputs will be similarly sized according to the local
mesh. ``fun`` will still be executed across *all* devices in the mesh,
including those from other processes, and will be given a global view of the
data spread across multiple processes as a single array. However, outside
of ``pjit`` every process only "sees" its local piece of the input and output,
corresponding to its local sub-mesh.

The SPMD model requires that the same multi-process ``pjit``'d functions be
run in the same order on all processes, but they can be interspersed with
arbitrary operations running in a single process.

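A minimal sketch of the pattern, assuming every process runs this same code
and that ``step`` and ``local_batch`` are hypothetical user-defined values:

>>> with mesh(jax.devices(), ('data',)):  # devices from *all* processes
...   out = pjit(step,
...              in_axis_resources=PartitionSpec('data'),
...              out_axis_resources=PartitionSpec('data'))(local_batch)  # doctest: +SKIP
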
Args:
fun: Function to be compiled. Should be a pure function, as side-effects may
only be executed once. Its arguments and return value should be arrays,
scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
Positional arguments indicated by ``static_argnums`` can be anything at
all, provided they are hashable and have an equality operation defined.
Static arguments are included as part of a compilation cache key, which is
why hash and equality operators must be defined.
in_axis_resources: Pytree of structure matching that of arguments to ``fun``,
with all actual arguments replaced by resource assignment specifications.
It is also valid to specify a pytree prefix (e.g. one value in place of a
whole subtree), in which case the leaves get broadcast to all values in
that subtree.

The valid resource assignment specifications are:

- :py:obj:`None`, in which case the value will be replicated on all devices
- :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
of the partitioned value. Each element can be a :py:obj:`None`, a mesh
axis or a tuple of mesh axes, and specifies the set of resources assigned
to partition the value's dimension matching its position in the spec.
The size of every dimension has to be a multiple of the total number of
resources assigned to it. A few example specifications are sketched after
this argument list.
out_axis_resources: Like ``in_axis_resources``, but specifies resource
assignment for function outputs.
static_argnums: An optional int or collection of ints that specify which
positional arguments to treat as static (compile-time constant).
Operations that only depend on static arguments will be constant-folded in
Python (during tracing), and so the corresponding argument values can be
any Python object.
Static arguments should be hashable, meaning both ``__hash__`` and
``__eq__`` are implemented, and immutable. Calling the jitted function
with different values for these constants will trigger recompilation.
Arguments that are not arrays or containers thereof must be marked as
static.
If ``static_argnums`` is not provided, no arguments are treated as static.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
buffers to reduce the amount of memory needed to perform a computation,
for example recycling one of your input buffers to store a result. You
should not reuse buffers that you donate to a computation; JAX will raise
an error if you try to.

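To make the specifications above concrete, here is a sketch assuming a
two-axis mesh with axes named ``'x'`` and ``'y'`` and a hypothetical
two-argument function ``train_step(params, batch)``; donating the ``batch``
argument lets XLA reuse its buffer:

>>> from jax.experimental.pjit import PartitionSpec, pjit
>>> replicated = None                         # replicate on all devices
>>> row_sharded = PartitionSpec('x', None)    # dim 0 over 'x', dim 1 replicated
>>> grid_sharded = PartitionSpec('x', 'y')    # dim 0 over 'x', dim 1 over 'y'
>>> merged = PartitionSpec(('x', 'y'), None)  # dim 0 over both axes together
>>> step = pjit(train_step,
...             in_axis_resources=(replicated, row_sharded),
...             out_axis_resources=replicated,
...             donate_argnums=(1,))  # doctest: +SKIP
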
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation and
automatically partitioned over the mesh available at each call site.

For example, a convolution operator can be automatically partitioned over
an arbitrary set of devices by a single ``pjit`` application:

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental.maps import mesh
>>> from jax.experimental.pjit import PartitionSpec, pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jnp.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
... in_axis_resources=None, out_axis_resources=PartitionSpec('devices'))
>>> with mesh(jax.devices(), ('devices',)):
... print(f(x)) # doctest: +SKIP
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
"""
warn("pjit is an experimental feature and probably has bugs!")
_check_callable(fun)
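
Finally, a sketch extending the docstring's one-dimensional example to a
two-axis mesh (assumes eight available devices; the shapes are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import mesh
from jax.experimental.pjit import PartitionSpec, pjit

# Partition a matmul: rows of the first operand over mesh axis 'x',
# columns of the second operand over mesh axis 'y'.
f = pjit(jnp.dot,
         in_axis_resources=(PartitionSpec('x', None), PartitionSpec(None, 'y')),
         out_axis_resources=PartitionSpec('x', 'y'))
with mesh(np.asarray(jax.devices(), dtype=object).reshape(2, 4), ('x', 'y')):
    z = f(jnp.ones((8, 4)), jnp.ones((4, 8)))  # z is sharded 2x4 across the mesh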