[docs] Pmap compiles functions with XLA (#2021)

This commit is contained in:
Jamie Townsend 2020-01-17 17:48:27 +00:00 committed by Matthew Johnson
parent 71323b5d02
commit 3974df0aee
2 changed files with 10 additions and 7 deletions

View File

@ -261,7 +261,9 @@ differentiation for fast Jacobian and Hessian matrix calculations in
For parallel programming of multiple accelerators, like multiple GPUs, use
[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
With `pmap` you write single-program multiple-data (SPMD) programs, including
fast parallel collective communication operations.
fast parallel collective communication operations. Applying `pmap` will mean
that the function you write is compiled by XLA (similarly to `jit`), then
replicated and executed in parallel accross devices.
Here's an example on an 8-GPU machine:

View File

@ -721,12 +721,13 @@ def pmap(fun, axis_name=None, devices=None, backend=None, axis_size=None):
"""Parallel map with support for collectives.
The purpose of ``pmap`` is to express single-program multiple-data (SPMD)
programs and execute them in parallel on XLA devices, such as multiple GPUs or
multiple TPU cores. Semantically it is comparable to ``vmap`` because both
transformations map a function over array axes, but where ``vmap`` vectorizes
functions by pushing the mapped axis down into primitive operations, ``pmap``
instead replicates the function and executes each replica on its own XLA
device in parallel.
programs. Applying ``pmap`` to a function will compile the function with XLA
(similarly to ``jit``), then execute it in parallel on XLA devices, such as
multiple GPUs or multiple TPU cores. Semantically it is comparable to
``vmap`` because both transformations map a function over array axes, but
where ``vmap`` vectorizes functions by pushing the mapped axis down into
primitive operations, ``pmap`` instead replicates the function and executes
each replica on its own XLA device in parallel.
Another key difference with ``vmap`` is that while ``vmap`` can only express
pure maps, ``pmap`` enables the use of parallel SPMD collective operations,