Merge pull request #21516 from nouiz:paralell_computation
PiperOrigin-RevId: 642004618
This commit is contained in: commit af004302c1
@ -111,3 +111,12 @@ os.environ.update({
These NCCL flags can improve single-host communication speed. They do not
appear to help multi-host communication yet.

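These flags are set through `os.environ` before JAX initializes its GPU backend. Here is a minimal sketch of that pattern; the specific flag names and values are illustrative assumptions rather than tuned recommendations.

```python
import os

# Configure NCCL via environment variables before JAX initializes its GPU
# backend. The flags and values below are illustrative assumptions; consult
# the NCCL documentation before relying on them.
os.environ.update({
    "NCCL_PROTO": "SIMPLE,LL,LL128",  # restrict which NCCL protocols may be used
    "NCCL_LL_BUFFSIZE": "-2",         # buffer-size override for the LL protocol (assumed value)
    "NCCL_LL128_BUFFSIZE": "-2",      # buffer-size override for the LL128 protocol (assumed value)
})

import jax  # import JAX only after the environment is configured
```
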
## Multi-Process

We recommend using one process per GPU rather than one process per node. In some
cases this can speed up jitted computation. The
{func}`jax.distributed.initialize` API will automatically detect that
configuration when run under SLURM. However, this is only a rule of
thumb, and it may be useful to test both one process per GPU and one
process per node for your use case.

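As a minimal sketch of this setup (the explicit arguments and the launcher environment variables in the commented-out call are assumptions about your job launcher, not part of the SLURM auto-detection):

```python
import os
import jax

# Under SLURM, jax.distributed.initialize() with no arguments picks up the
# coordinator address, process count, and process id automatically.
jax.distributed.initialize()

# Outside SLURM, the same thing can be spelled out explicitly; the environment
# variable names here (WORLD_SIZE, RANK, LOCAL_RANK) are assumptions about
# the launcher, not anything JAX defines.
# jax.distributed.initialize(
#     coordinator_address="host0:1234",                  # address of process 0 (placeholder)
#     num_processes=int(os.environ["WORLD_SIZE"]),        # one process per GPU, cluster-wide
#     process_id=int(os.environ["RANK"]),                 # index of this process
#     local_device_ids=[int(os.environ["LOCAL_RANK"])],   # bind this process to a single GPU
# )

print(jax.process_count(), jax.local_device_count())  # e.g. "8 1" with 8 GPUs total
```

With one process per GPU, `jax.local_device_count()` is 1 in every process and `jax.device_count()` is the total number of GPUs across all nodes.
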
@ -28,12 +28,15 @@ Key concepts:
* Each process has a
  distinct set of *local* devices it can address. The *global* devices are the set
  of all devices across all processes (see the sketch after this list).
* Use standard JAX parallelism APIs like {func}`~jax.jit` (see the
  {doc}`/sharded-computation` tutorial) and
  {func}`~jax.experimental.shard_map.shard_map`. `jax.jit` only accepts
  globally shaped arrays, while `shard_map` lets you drop down to
  per-device shapes.
* Make sure all processes run the same parallel computations in the same
  order.
* Make sure all processes have the same number of local devices.
* Make sure all devices are the same (e.g., all V100, or all H100).

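As a rough illustration of the local/global distinction (a sketch that assumes {func}`jax.distributed.initialize` has already been called in every process):

```python
import jax

# Devices this process can address directly (its *local* devices).
print(jax.local_devices())

# All devices across all processes (the *global* devices).
print(jax.devices())

# Holds when every process has the same number of local devices.
assert jax.device_count() == jax.process_count() * jax.local_device_count()
```
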
### Launching JAX processes
@ -123,18 +126,13 @@ global devices.
So how do you actually run a computation involving cross-process communication?
**Use the same parallel evaluation APIs that you would in a single process!**

For example, {func}`~jax.experimental.shard_map.shard_map` can be used
to run a parallel computation across multiple processes. (If you’re
not already familiar with how to use `shard_map` to run across
multiple devices within a single process, check out the
{doc}`/sharded-computation` tutorial.) Conceptually, this can be
thought of as running a pmap over a single array sharded across hosts,
where each host “sees” only its local shard of the input and output.

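Here is a minimal sketch of what that can look like, assuming {func}`jax.distributed.initialize` has already been called in every process; building the global input with `jax.make_array_from_single_device_arrays` is just one possible choice:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Every process runs this same code. The mesh spans all *global* devices.
mesh = Mesh(jax.devices(), axis_names=("i",))
sharding = NamedSharding(mesh, P("i"))

# Assemble a global array from this process's *local* shards.
local_shards = [jax.device_put(jnp.ones((1,), jnp.float32), d)
                for d in jax.local_devices()]
xs = jax.make_array_from_single_device_arrays(
    (jax.device_count(),), sharding, local_shards)

# Inside the shard_mapped function each device sees its own (1,)-shaped block,
# while psum communicates across all global devices along axis "i".
@jax.jit
def global_sum(x):
  return shard_map(lambda block: jax.lax.psum(block, "i"),
                   mesh=mesh, in_specs=P("i"), out_specs=P())(x)

print(global_sum(xs))  # every process prints the same global sum
```
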
Here’s an example of multi-process pmap in action:
@ -152,12 +150,6 @@ Here’s an example of multi-process pmap in action:
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
```

{func}`~jax.experimental.maps.xmap` works similarly when using a physical
hardware mesh (see the {doc}`xmap tutorial</notebooks/xmap_tutorial>` if you’re
not familiar with the single-process version). Like {func}`~jax.pmap`, the
inputs and outputs are local and any parallel communication inside the xmapped
function is global. The mesh is also global.

**It’s very important that all processes run the same cross-process computations
in the same order.** Running the same JAX Python program in each process is
usually sufficient. Some common pitfalls to look out for that may cause