Merge pull request #21516 from nouiz:paralell_computation

PiperOrigin-RevId: 642004618
jax authors 2024-06-10 13:29:10 -07:00
commit af004302c1
2 changed files with 23 additions and 22 deletions


@@ -111,3 +111,12 @@ os.environ.update({
These NCCL flags may improve single-host communication speed. They don't
appear to help multi-host communication yet.
## Multi-Process
We recommend using one process per GPU and not one per node. In some
cases, this can speed up jitted computation. The
{func}`jax.distributed.initialize` API will automatically detect
that configuration when run under SLURM. However, this is only a rule of
thumb, and it may be useful to test both one process per GPU and one
process per node for your use case.
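A minimal sketch of the one-process-per-GPU setup under SLURM (the `srun` flags and the launch layout here are illustrative assumptions, not from the original doc):

```python
# Launched with e.g.: srun --ntasks-per-node=<GPUs per node> python train.py
import jax

# Under SLURM, initialize() picks up the coordinator address, process
# count, and process id from the environment automatically.
jax.distributed.initialize()

print(f"process {jax.process_index()}/{jax.process_count()}, "
      f"local devices: {jax.local_device_count()}")
```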


@@ -28,12 +28,15 @@ Key concepts:
* Each process has a
distinct set of *local* devices it can address. The *global* devices are the set
of all devices across all processes (see the sketch after this list).
* Use standard JAX parallelism APIs like {func}`~jax.pmap` and
{func}`~jax.experimental.maps.xmap`. Each process “sees” *local* input and
output to parallelized functions, but communication inside the computations
is *global*.
* Use standard JAX parallelism APIs like {func}`~jax.jit` (see the
{doc}`/sharded-computation` tutorial) and
{func}`~jax.experimental.shard_map.shard_map`. `jax.jit` only accepts
globally shaped arrays; `shard_map` lets you drop down to per-device
shapes.
* Make sure all processes run the same parallel computations in the same
order.
* Make sure all processes have the same number of local devices.
* Make sure all devices are the same (e.g., all V100, or all H100).
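A minimal sketch of the local/global distinction (the counts assume 2 processes with 4 GPUs each):

```python
import jax

jax.distributed.initialize()  # assumes a multi-process launch, e.g. under SLURM

print(len(jax.devices()))        # 8: the *global* devices across both processes
print(len(jax.local_devices()))  # 4: the *local* devices this process addresses
```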
### Launching JAX processes
@@ -123,18 +126,13 @@ global devices.
So how do you actually run a computation involving cross-process communication?
**Use the same parallel evaluation APIs that you would in a single process!**
For example, {func}`~jax.experimental.shard_map.shard_map` can be used to
run a parallel computation across
multiple processes. (If you're not already familiar with how to use
`shard_map` to run across multiple devices within a single process, check
out the {doc}`/sharded-computation` tutorial.) Each process should call the
same pmapped function and pass in arguments to be mapped across its *local*
devices (i.e., the pmapped axis size is equal to the number of local devices).
Similarly, the function will return outputs sharded across *local* devices only.
Inside the function, however, collective communication operations are run across
all *global* devices, across all processes. Conceptually, this can be thought of
as running a pmap over a single array sharded across hosts, where each host
“sees” only its local shard of the input and output.
For example, {func}`~jax.experimental.shard_map.shard_map` can be used
to run a parallel computation across multiple processes. (If you're
not already familiar with how to use `shard_map` to run across
multiple devices within a single process, check out the
{doc}`/sharded-computation` tutorial.) Conceptually, this can be
thought of as running a pmap over a single array sharded across hosts,
where each host “sees” only its local shard of the input and output.
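As a sketch of what this looks like with `shard_map` (the mesh construction, the input setup, and the device counts below are illustrative assumptions, not the doc's example):

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes jax.distributed.initialize() has already run. The mesh spans
# *global* devices, even though each process only addresses its local ones.
mesh = Mesh(jax.devices(), axis_names=('i',))

# Create a globally shaped, globally sharded array: one element per device.
@functools.partial(jax.jit, out_shardings=NamedSharding(mesh, P('i')))
def make_ones():
    return jnp.ones(jax.device_count())

# Each device sees only its per-device shard inside the body, but the
# psum communicates across all global devices; the result is replicated.
global_sum = shard_map(lambda s: jax.lax.psum(s, 'i'),
                       mesh=mesh, in_specs=P('i'), out_specs=P())

print(global_sum(make_ones()))  # [N.] where N == jax.device_count()
```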
Here's an example of multi-process pmap in action:
@@ -152,12 +150,6 @@ Here's an example of multi-process pmap in action:
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
```
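The code for that example is elided by the hunk above; a minimal sketch that would produce output like this, assuming 4 processes with 8 local devices each (32 global devices):

```python
import jax

# Each process passes in only its *local* shard: 8 ones per process.
xs = jax.numpy.ones(jax.local_device_count())

# The psum spans all 32 *global* devices, so every output entry is 32.
y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
print(y)  # [32. 32. 32. 32. 32. 32. 32. 32.]
```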
{func}`~jax.experimental.maps.xmap` works similarly when using a physical
hardware mesh (see the {doc}`xmap tutorial</notebooks/xmap_tutorial>` if you're
not familiar with the single-process version). Like {func}`~jax.pmap`, the
inputs and outputs are local and any parallel communication inside the xmapped
function is global. The mesh is also global.
**It's very important that all processes run the same cross-process computations
in the same order.** Running the same JAX Python program in each process is
usually sufficient. Some common pitfalls to look out for that may cause