mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[docs] Pmap compiles functions with XLA (#2021)
This commit is contained in:
parent
71323b5d02
commit
3974df0aee
@ -261,7 +261,9 @@ differentiation for fast Jacobian and Hessian matrix calculations in
|
||||
For parallel programming of multiple accelerators, like multiple GPUs, use
|
||||
[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
|
||||
With `pmap` you write single-program multiple-data (SPMD) programs, including
|
||||
fast parallel collective communication operations.
|
||||
fast parallel collective communication operations. Applying `pmap` will mean
|
||||
that the function you write is compiled by XLA (similarly to `jit`), then
|
||||
replicated and executed in parallel accross devices.
|
||||
|
||||
Here's an example on an 8-GPU machine:
|
||||
|
||||
|
13
jax/api.py
13
jax/api.py
@ -721,12 +721,13 @@ def pmap(fun, axis_name=None, devices=None, backend=None, axis_size=None):
|
||||
"""Parallel map with support for collectives.
|
||||
|
||||
The purpose of ``pmap`` is to express single-program multiple-data (SPMD)
|
||||
programs and execute them in parallel on XLA devices, such as multiple GPUs or
|
||||
multiple TPU cores. Semantically it is comparable to ``vmap`` because both
|
||||
transformations map a function over array axes, but where ``vmap`` vectorizes
|
||||
functions by pushing the mapped axis down into primitive operations, ``pmap``
|
||||
instead replicates the function and executes each replica on its own XLA
|
||||
device in parallel.
|
||||
programs. Applying ``pmap`` to a function will compile the function with XLA
|
||||
(similarly to ``jit``), then execute it in parallel on XLA devices, such as
|
||||
multiple GPUs or multiple TPU cores. Semantically it is comparable to
|
||||
``vmap`` because both transformations map a function over array axes, but
|
||||
where ``vmap`` vectorizes functions by pushing the mapped axis down into
|
||||
primitive operations, ``pmap`` instead replicates the function and executes
|
||||
each replica on its own XLA device in parallel.
|
||||
|
||||
Another key difference with ``vmap`` is that while ``vmap`` can only express
|
||||
pure maps, ``pmap`` enables the use of parallel SPMD collective operations,
|
||||
|
Loading…
x
Reference in New Issue
Block a user