mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

document SPMD pipeline parallelism

PiperOrigin-RevId: 746543312
jax authors 2025-04-11 12:02:51 -07:00
parent ab88273596
commit 8e9fca1d08

@@ -243,20 +243,6 @@ Run the real workflow, if you found these loggings in the running log, it means
By adjusting this factor, users can fine-tune the trade-off between memory efficiency
and performance optimizations.
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
i-th layer computation. It also enables overlapping (i+1)-th layer
weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default
value is False. **There are some bugs when this flag is turned on.**
* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when
performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a
nonzero threshold decomposes `CollectivePermute`s into
`CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that
computation can be performed between each corresponding
`ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the
threshold is 0 and there is no decomposition. Setting it to threshold > 0 such
as `--xla_gpu_collective_permute_decomposer_threshold=1024` can enable this
feature.
* **--xla_gpu_all_gather_combine_threshold_bytes**
**--xla_gpu_reduce_scatter_combine_threshold_bytes**
**--xla_gpu_all_reduce_combine_threshold_bytes**
@@ -268,6 +254,227 @@ Run the real workflow, if you found these loggings in the running log, it means
combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
default, the `combine_threshold_bytes` is set to 256.
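As a purely illustrative configuration (the right values depend on your model's layer
sizes and are not prescribed by this document), thresholds large enough to combine a
full Transformer layer's weight collectives might look like:
```
--xla_gpu_all_gather_combine_threshold_bytes=8589934592
--xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
--xla_gpu_all_reduce_combine_threshold_bytes=8589934592
```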
### Pipeline Parallelism on GPU
XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling technique
where the forward and backward passes are split into multiple pipeline stages.
Each device (or device group) processes the result of the previous
pipeline stage (or the pipeline input) and sends its partial result to the next
stage until the end of the pipeline is reached. This optimization works best
when the latency of the computation is larger than that of the communication. At
compile time, the operations are rearranged to overlap communication with
computation.
For an optimized schedule, we recommend these XLA flags:
```
--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''
--xla_disable_hlo_passes=collective-permute-motion
--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE
```
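One way to pass these flags to a JAX program is through the `XLA_FLAGS` environment
variable, which must be set before JAX initializes its XLA backend. The snippet below
is a minimal sketch of doing so from Python; it assumes the empty value for
`--xla_gpu_enable_command_buffer` is what the `''` above denotes.
```
# Illustrative only: set the recommended flags via XLA_FLAGS before importing JAX.
import os

os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_latency_hiding_scheduler=true "
    "--xla_gpu_enable_command_buffer= "
    "--xla_disable_hlo_passes=collective-permute-motion "
    "--xla_gpu_experimental_pipeline_parallelism_opt_level="
    "PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE"
)

import jax  # Import (and backend initialization) happens after the flags are set.
```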
The following JAX example demonstrates a pattern where communication operations
are scheduled to overlap with computations. In this example we illustrate
how to set up an optimized pipeline parallelism schedule using 4 GPUs that
form a communication ring (device 0 -> device 1 -> device 2 -> device 3 ->
device 0). We refer to the pattern `0 -> 1 -> 2 -> 3` as the forward edge, and
`3 -> 0` as the back edge. A shorter, self-contained sketch of the same ring
pattern, written with an explicit collective, follows the full example below.
```
# Imports and setup
import functools
import jax
from jax import sharding
from jax.experimental import mesh_utils
import jax.numpy as jnp
import jax.random
NUM_DEVICES = 4
NUM_MICROBATCHES = 5
NUM_CIRC_REPEATS = 2
CONTRACTING_DIM_SIZE = 4096
NON_CONTRACTING_DIM_SIZE = 8192
COMPUTE_INTENSITY = 32
# Creates a collective permute for the "forward edge".
# 0->1, 1->2, ... (N-2)->(N-1)
def shift_right(arr):
padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
# Use lax.slice to guarantee the gradient is a pad.
return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)
# Creates a collective permute for the "back edge".
# (N-1)->0
def cycle_back(arr):
padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
return jax.lax.slice(
jnp.pad(arr, padding),
[NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
(NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
)
def select_on_first_device(then_value, else_value):
assert then_value.shape == else_value.shape
is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
return jnp.where(is_first_device, then_value, else_value)
def select_on_last_device(then_value, else_value):
assert then_value.shape == else_value.shape
is_last_device = (
jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
)
return jnp.where(is_last_device, then_value, else_value)
def select_on_first_cycle(i, then_value, else_value):
assert then_value.shape == else_value.shape
is_first_cycle = i < NUM_MICROBATCHES
return jnp.where(is_first_cycle, then_value, else_value)
def while_body(carry, i):
"""Body of the pipeline while loop."""
weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry
# Read input data from input buffer.
input_data = jax.lax.dynamic_slice(
input_buffer,
(0, (i + 0) % NUM_MICROBATCHES, 0, 0),
(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
)
# Collective permute on the "forward edge" shifts data to the next stage.
fwd_edge_data = shift_right(fwd_edge_data)
# Select compute argument based on device and pipeline cycle.
compute_argument = select_on_first_device(
select_on_first_cycle(i, input_data, bwd_edge_data),
fwd_edge_data,
).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))
# A few matmuls to simulate compute.
tmp = compute_argument
for _ in range(COMPUTE_INTENSITY):
tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
compute_result = tmp.reshape(
(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
)
# Read data from buffer to pass it to the first device of the pipeline on the
# "back edge".
bwd_edge_data = jax.lax.dynamic_slice(
output_buffer,
(0, (1 + i) % NUM_MICROBATCHES, 0, 0),
(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
)
  # Collective permute on the "back edge" passes data to the first device.
bwd_edge_data = cycle_back(bwd_edge_data)
# Update output buffer. We do this after reading from it to avoid the data
# dependency.
output_buffer = jax.lax.dynamic_update_slice(
output_buffer,
compute_result,
(0, (2 + i) % NUM_MICROBATCHES, 0, 0),
)
fwd_edge_data = compute_result
carry = (
weights,
input_buffer,
output_buffer,
fwd_edge_data,
bwd_edge_data,
)
return carry, i
@functools.partial(jax.jit, static_argnames=["mesh"])
def entry_computation(weights, input_buffer, mesh):
# Init output buffer.
output_buffer = jnp.zeros_like(input_buffer)
# Init dummy data for forward and backward edge passed through the while loop.
dummy_data = jnp.zeros(
shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
).astype(jnp.float32)
dummy_data = jax.device_put(
dummy_data,
sharding.NamedSharding(
mesh, sharding.PartitionSpec("the_one_and_only_axis")
),
)
# Start pipeline.
carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
_, _, output_buffer, _, _ = carry
return output_buffer
def main(_):
# Expect constant number of devices.
assert NUM_DEVICES == jax.local_device_count()
# Create mesh.
mesh = sharding.Mesh(
mesh_utils.create_device_mesh([NUM_DEVICES]),
axis_names=["the_one_and_only_axis"],
)
# Init weights.
weights = 1.0 / CONTRACTING_DIM_SIZE
weights = jax.lax.broadcast_in_dim(
weights,
shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
broadcast_dimensions=(),
)
weights = jax.device_put(
weights,
sharding.NamedSharding(
mesh, sharding.PartitionSpec("the_one_and_only_axis")
),
)
# Init random input and replicate it across all devices.
random_key = jax.random.key(0)
input_buffer = jax.random.uniform(
random_key,
shape=(
NUM_MICROBATCHES,
CONTRACTING_DIM_SIZE,
NON_CONTRACTING_DIM_SIZE,
),
)
input_buffer = jax.lax.broadcast_in_dim(
input_buffer,
shape=(
NUM_DEVICES,
NUM_MICROBATCHES,
CONTRACTING_DIM_SIZE,
NON_CONTRACTING_DIM_SIZE,
),
broadcast_dimensions=[1, 2, 3],
)
input_buffer = jax.device_put(
input_buffer,
sharding.NamedSharding(
mesh, sharding.PartitionSpec("the_one_and_only_axis")
),
)
# Run computation.
output_buffer = entry_computation(weights, input_buffer, mesh)
print(f"output_buffer = \n{output_buffer}")
```
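For readers who want to see the communication ring in isolation, here is a minimal
sketch of the same `0 -> 1 -> 2 -> 3 -> 0` pattern written with an explicit
`jax.lax.ppermute` under `jax.experimental.shard_map`. It is not part of the example
above (which instead relies on pad/slice patterns that the SPMD partitioner lowers to
`CollectivePermute`s); the axis name `"stages"` and the function names are
illustrative assumptions.
```
# Minimal, illustrative sketch of the forward edge (0->1->2->3) and back edge (3->0)
# expressed as a single cyclic collective permute. Requires 4 devices, like the
# example above.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

NUM_DEVICES = 4
mesh = Mesh(mesh_utils.create_device_mesh([NUM_DEVICES]), axis_names=("stages",))

def ring_step(block):
  # Each device sends its block to the next device in the ring.
  perm = [(src, (src + 1) % NUM_DEVICES) for src in range(NUM_DEVICES)]
  return jax.lax.ppermute(block, axis_name="stages", perm=perm)

ring_step_spmd = jax.jit(
    shard_map(ring_step, mesh=mesh, in_specs=P("stages"), out_specs=P("stages"))
)

x = jnp.arange(NUM_DEVICES * 8, dtype=jnp.float32).reshape(NUM_DEVICES, 8)
x = jax.device_put(x, NamedSharding(mesh, P("stages")))
y = ring_step_spmd(x)  # Row i of y is row (i - 1) % NUM_DEVICES of x.
```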
## NCCL flags
These Nvidia NCCL flag values may be useful for single-host multi-device