Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 04:46:06 +00:00)

Add a docstring for pjit

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
PiperOrigin-RevId: 372393461

parent ee94354589
commit ee93ee221c
docs/jax.experimental.pjit.rst (new file, 9 lines)

@@ -0,0 +1,9 @@
jax.experimental.pjit module
============================

.. automodule:: jax.experimental.pjit

API
---

.. autofunction:: pjit
@@ -10,6 +10,7 @@ jax.experimental package
   jax.experimental.host_callback
   jax.experimental.loops
   jax.experimental.maps
+  jax.experimental.pjit
   jax.experimental.optimizers
   jax.experimental.stax

@@ -156,7 +156,7 @@ def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
                  axis_resources={'left': 'x', 'right': 'y'})(x, x.T)
  """
  old_env = getattr(thread_resources, "env", None)
-  thread_resources.env = ResourceEnv(Mesh(devices, axis_names))
+  thread_resources.env = ResourceEnv(Mesh(np.asarray(devices, dtype=object), axis_names))
  try:
    yield
  finally:
@@ -46,6 +46,115 @@ def pjit(fun: Callable,
         out_axis_resources,
         static_argnums: Union[int, Sequence[int]] = (),
         donate_argnums: Union[int, Sequence[int]] = ()):
  """Makes ``fun`` compiled and automatically partitioned across multiple devices.

  The returned function has semantics equivalent to those of ``fun``, but is
  compiled to an XLA computation that runs across multiple devices
  (e.g. multiple GPUs or multiple TPU cores). This can be useful if the jitted
  version of ``fun`` would not fit in a single device's memory, or to speed up
  ``fun`` by running each operation in parallel across multiple devices.

  The partitioning over devices happens automatically based on propagation of
  the input partitioning specified in ``in_axis_resources`` and the output
  partitioning specified in ``out_axis_resources``. The resources specified in
  those two arguments must refer to mesh axes, as defined by the
  :py:func:`jax.experimental.maps.mesh` context manager. Note that the mesh
  definition at ``pjit`` application time is ignored, and the returned function
  will use the mesh definition available at each call site.
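
  For instance (a minimal sketch, assuming at least one available device; the
  mesh axis name ``'x'`` is arbitrary), a single pjit'd function picks up
  whichever mesh is active where it is called:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental.maps import mesh
  >>> from jax.experimental.pjit import PartitionSpec, pjit
  >>> f = pjit(lambda x: x + 1,
  ...          in_axis_resources=PartitionSpec('x'),
  ...          out_axis_resources=PartitionSpec('x'))
  >>> with mesh(jax.devices(), ('x',)):      # doctest: +SKIP
  ...   y = f(jnp.zeros(8 * jax.device_count()))  # partitioned over all devices
  >>> with mesh(jax.devices()[:1], ('x',)):  # doctest: +SKIP
  ...   y = f(jnp.zeros(8))                       # same f, now on a one-device mesh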

  Inputs to a pjit'd function will be automatically partitioned across devices
  if they're not already correctly partitioned based on ``in_axis_resources``.
  In some scenarios, ensuring that the inputs are already correctly
  pre-partitioned can increase performance. For example, if passing the output
  of one pjit'd function to another pjit'd function (or the same pjit'd function
  in a loop), make sure the relevant ``out_axis_resources`` match the
  corresponding ``in_axis_resources``.
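
  Continuing the sketch above (same assumptions), chaining two pjit'd functions
  whose specifications match means ``f``'s output is already partitioned the
  way ``g`` expects it:

  >>> g = pjit(lambda x: x * 2,
  ...          in_axis_resources=PartitionSpec('x'),   # matches f's out_axis_resources
  ...          out_axis_resources=PartitionSpec('x'))
  >>> with mesh(jax.devices(), ('x',)):  # doctest: +SKIP
  ...   z = g(f(jnp.zeros(8 * jax.device_count())))    # no repartitioning of f's output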

  .. note::
    **Multi-process platforms:** On multi-process platforms such as TPU pods,
    ``pjit`` can be used to run computations across all available devices across
    processes. To achieve this, ``pjit`` is designed to be used in SPMD Python
    programs, where every process is running the same Python code such that all
    processes run the same pjit'd function in the same order.

    When running in this configuration, the mesh should contain devices across
    all processes. However, any input argument dimensions partitioned over
    multi-process mesh axes should be of size equal to the corresponding *local*
    mesh axis size, and outputs will be similarly sized according to the local
    mesh. ``fun`` will still be executed across *all* devices in the mesh,
    including those from other processes, and will be given a global view of the
    data spread across multiple processes as a single array. However, outside
    of ``pjit`` every process only "sees" its local piece of the input and output,
    corresponding to its local sub-mesh.

    The SPMD model requires that the same multi-process ``pjit``'d functions must
    be run in the same order on all processes, but they can be interspersed with
    arbitrary operations running in a single process.
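
    As a minimal sketch (hypothetical sizes: each process contributes
    ``jax.local_device_count()`` devices to a single global mesh axis ``'x'``),
    every process passes in only its local shard, while ``fun`` sees the full
    global array:

    >>> f = pjit(lambda x: x.sum(),
    ...          in_axis_resources=PartitionSpec('x'),
    ...          out_axis_resources=None)
    >>> local_x = jnp.ones(jax.local_device_count())  # local piece only
    >>> with mesh(jax.devices(), ('x',)):  # doctest: +SKIP
    ...   total = f(local_x)  # inside fun, x has global size jax.device_count()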

  Args:
    fun: Function to be compiled. Should be a pure function, as side-effects may
      only be executed once. Its arguments and return value should be arrays,
      scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
      Positional arguments indicated by ``static_argnums`` can be anything at
      all, provided they are hashable and have an equality operation defined.
      Static arguments are included as part of a compilation cache key, which is
      why hash and equality operators must be defined.
    in_axis_resources: Pytree of structure matching that of arguments to ``fun``,
      with all actual arguments replaced by resource assignment specifications.
      It is also valid to specify a pytree prefix (e.g. one value in place of a
      whole subtree), in which case the leaves get broadcast to all values in
      that subtree.

      The valid resource assignment specifications are:
      - :py:obj:`None`, in which case the value will be replicated on all devices
      - :py:class:`PartitionSpec`, a tuple of length at most equal to the rank
        of the partitioned value. Each element can be a :py:obj:`None`, a mesh
        axis or a tuple of mesh axes, and specifies the set of resources assigned
        to partition the value's dimension matching its position in the spec.

      The size of every dimension has to be a multiple of the total number of
      resources assigned to it. (A short sketch combining these forms follows
      this argument list.)
    out_axis_resources: Like ``in_axis_resources``, but specifies resource
      assignment for function outputs.
    static_argnums: An optional int or collection of ints that specify which
      positional arguments to treat as static (compile-time constant).
      Operations that only depend on static arguments will be constant-folded in
      Python (during tracing), and so the corresponding argument values can be
      any Python object.

      Static arguments should be hashable, meaning both ``__hash__`` and
      ``__eq__`` are implemented, and immutable. Calling the jitted function
      with different values for these constants will trigger recompilation.
      Arguments that are not arrays or containers thereof must be marked as
      static.

      If ``static_argnums`` is not provided, no arguments are treated as static.
    donate_argnums: Specify which arguments are "donated" to the computation.
      It is safe to donate arguments if you no longer need them once the
      computation has finished. In some cases XLA can make use of donated
      buffers to reduce the amount of memory needed to perform a computation,
      for example recycling one of your input buffers to store a result. You
      should not reuse buffers that you donate to a computation; JAX will raise
      an error if you try to. (A short sketch of donation follows below.)
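
  For illustration, a minimal sketch (assuming the imports and the arbitrary
  mesh axis name ``'x'`` from the sketches above) that combines the resource
  assignment forms described under ``in_axis_resources``: a
  :py:class:`PartitionSpec` with a mesh axis and a ``None`` entry for the first
  argument, and ``None`` (full replication) for the second:

  >>> h = pjit(lambda w, v: w @ v,
  ...          in_axis_resources=(PartitionSpec('x', None),  # shard w's rows over 'x'
  ...                             None),                     # replicate v on all devices
  ...          out_axis_resources=PartitionSpec('x'))
  >>> with mesh(jax.devices(), ('x',)):  # doctest: +SKIP
  ...   out = h(jnp.ones((4 * jax.device_count(), 4)), jnp.ones(4))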
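
  Similarly, a sketch (same assumptions) of donating an argument buffer via
  ``donate_argnums``; the donated buffer should not be reused after the call:

  >>> add = pjit(lambda a, b: a + b,
  ...            in_axis_resources=(PartitionSpec('x'), PartitionSpec('x')),
  ...            out_axis_resources=PartitionSpec('x'),
  ...            donate_argnums=0)  # a's buffer may be reused for the output
  >>> with mesh(jax.devices(), ('x',)):  # doctest: +SKIP
  ...   out = add(jnp.ones(8 * jax.device_count()),
  ...             jnp.ones(8 * jax.device_count()))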

  Returns:
    A wrapped version of ``fun``, set up for just-in-time compilation and
    automatically partitioned by the mesh available at each call site.

  For example, a convolution operator can be automatically partitioned over
  an arbitrary set of devices by a single ``pjit`` application:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.experimental.maps import mesh
  >>> from jax.experimental.pjit import PartitionSpec, pjit
  >>>
  >>> x = jnp.arange(8, dtype=jnp.float32)
  >>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
  ...          in_axis_resources=None, out_axis_resources=PartitionSpec('devices'))
  >>> with mesh(jax.devices(), ('devices',)):
  ...   print(f(x)) # doctest: +SKIP
  [ 0.5 2. 4. 6. 8. 10. 12. 10. ]
  """
  warn("pjit is an experimental feature and probably has bugs!")
  _check_callable(fun)