document SPMD pipeline parallelism
PiperOrigin-RevId: 746543312
  By adjusting this factor, users can fine-tune the trade-off between memory
  efficiency and performance.
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
  this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
  i-th layer computation. It also enables overlapping the (i+1)-th layer
  weight `Reduce`/`ReduceScatter` with the i-th layer's computation. The
  default value is False. **There are some bugs when this flag is turned on.**
* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when
  performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a
  nonzero threshold decomposes `CollectivePermute`s into
  `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that
  computation can be performed between each corresponding
  `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the
  threshold is 0 and there is no decomposition. Setting it to a threshold > 0,
  such as `--xla_gpu_collective_permute_decomposer_threshold=1024`, enables
  this feature.
* **--xla_gpu_all_gather_combine_threshold_bytes**
  **--xla_gpu_reduce_scatter_combine_threshold_bytes**
  **--xla_gpu_all_reduce_combine_threshold_bytes**
  These flags tune when to combine multiple small
  `AllGather`/`ReduceScatter`/`AllReduce` ops into one big
  `AllGather`/`ReduceScatter`/`AllReduce` op, reducing the time spent on
  cross-device communication. For a Transformer-based workload, consider
  tuning the thresholds high enough to
  combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
  default, the `combine_threshold_bytes` is set to 256. A sketch of setting
  these flags from Python follows this list.
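All of the flags in this list can be passed to JAX through the `XLA_FLAGS`
environment variable. The following is a minimal sketch; the specific
threshold values are illustrative placeholders, not tuned recommendations:

```
import os

# Set XLA_FLAGS before JAX initializes its backends; the values shown here
# are illustrative and should be tuned per workload.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_collective_permute_decomposer_threshold=1024",
    "--xla_gpu_all_gather_combine_threshold_bytes=134217728",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728",
    "--xla_gpu_all_reduce_combine_threshold_bytes=134217728",
])

import jax  # Imported after XLA_FLAGS is set so the flags take effect.
```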
### Pipeline Parallelism on GPU
XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling
technique where the forward and backward passes are split into multiple
pipeline stages. Each device (or device group) processes the result of the
previous pipeline stage (or the pipeline input) and sends its partial result to
the next stage until the end of the pipeline is reached. This optimization
works best when the latency of the computation is larger than that of the
communication. At compile time, the operations are rearranged to overlap
communication with computation.

For an optimized schedule, we recommend these XLA flags:

```
--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''
--xla_disable_hlo_passes=collective-permute-motion
--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE
```

The following JAX example demonstrates a pattern where communication operations
are scheduled to overlap with computation. It illustrates how to set up an
optimized pipeline-parallel schedule on 4 GPUs that form a communication ring
(device 0 -> device 1 -> device 2 -> device 3 -> device 0). We refer to the
pattern `0 -> 1 -> 2 -> 3` as the forward edge, and `3 -> 0` as the back edge.

```
# Imports and setup
import functools

import jax
from jax import sharding
from jax.experimental import mesh_utils
import jax.numpy as jnp
import jax.random

NUM_DEVICES = 4
NUM_MICROBATCHES = 5
NUM_CIRC_REPEATS = 2
CONTRACTING_DIM_SIZE = 4096
NON_CONTRACTING_DIM_SIZE = 8192
COMPUTE_INTENSITY = 32


# Creates a collective permute for the "forward edge".
# 0->1, 1->2, ... (N-2)->(N-1)
def shift_right(arr):
  padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
  # Use lax.slice to guarantee the gradient is a pad.
  return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)


# Creates a collective permute for the "back edge".
# (N-1)->0
def cycle_back(arr):
  padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
  return jax.lax.slice(
      jnp.pad(arr, padding),
      [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
      (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
  )


# Returns then_value on the first device (along the sharded leading axis),
# else_value everywhere else.
def select_on_first_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
  return jnp.where(is_first_device, then_value, else_value)


# Returns then_value on the last device, else_value everywhere else.
def select_on_last_device(then_value, else_value):
  assert then_value.shape == else_value.shape
  is_last_device = (
      jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
  )
  return jnp.where(is_last_device, then_value, else_value)


# Returns then_value during the first pass over the microbatches,
# else_value during later circular repeats.
def select_on_first_cycle(i, then_value, else_value):
  assert then_value.shape == else_value.shape
  is_first_cycle = i < NUM_MICROBATCHES
  return jnp.where(is_first_cycle, then_value, else_value)


def while_body(carry, i):
  """Body of the pipeline while loop."""
  weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry

  # Read input data from the input buffer.
  input_data = jax.lax.dynamic_slice(
      input_buffer,
      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "forward edge" shifts data to the next stage.
  fwd_edge_data = shift_right(fwd_edge_data)

  # Select the compute argument based on device and pipeline cycle.
  compute_argument = select_on_first_device(
      select_on_first_cycle(i, input_data, bwd_edge_data),
      fwd_edge_data,
  ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))

  # A few matmuls to simulate compute.
  tmp = compute_argument
  for _ in range(COMPUTE_INTENSITY):
    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
  compute_result = tmp.reshape(
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  )

  # Read data from the output buffer to pass it to the first device of the
  # pipeline on the "back edge".
  bwd_edge_data = jax.lax.dynamic_slice(
      output_buffer,
      (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
  )

  # Collective permute on the "back edge" passes data to the first device.
  bwd_edge_data = cycle_back(bwd_edge_data)

  # Update the output buffer. We do this after reading from it to avoid the
  # data dependency.
  output_buffer = jax.lax.dynamic_update_slice(
      output_buffer,
      compute_result,
      (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
  )

  fwd_edge_data = compute_result
  carry = (
      weights,
      input_buffer,
      output_buffer,
      fwd_edge_data,
      bwd_edge_data,
  )
  return carry, i


@functools.partial(jax.jit, static_argnames=["mesh"])
def entry_computation(weights, input_buffer, mesh):

  # Init the output buffer.
  output_buffer = jnp.zeros_like(input_buffer)

  # Init dummy data for the forward and back edges passed through the loop.
  dummy_data = jnp.zeros(
      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
  ).astype(jnp.float32)
  dummy_data = jax.device_put(
      dummy_data,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("the_one_and_only_axis")
      ),
  )

  # Start the pipeline. The iteration count covers NUM_CIRC_REPEATS passes
  # over NUM_MICROBATCHES microbatches plus NUM_DEVICES - 1 fill steps.
  carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
  _, _, output_buffer, _, _ = carry

  return output_buffer


def main(_):

  # Expect a constant number of devices.
  assert NUM_DEVICES == jax.local_device_count()

  # Create the mesh.
  mesh = sharding.Mesh(
      mesh_utils.create_device_mesh([NUM_DEVICES]),
      axis_names=["the_one_and_only_axis"],
  )

  # Init weights.
  weights = 1.0 / CONTRACTING_DIM_SIZE
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
      broadcast_dimensions=(),
  )
  weights = jax.device_put(
      weights,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("the_one_and_only_axis")
      ),
  )

  # Init random input and replicate it across all devices.
  random_key = jax.random.key(0)
  input_buffer = jax.random.uniform(
      random_key,
      shape=(
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
  )
  input_buffer = jax.lax.broadcast_in_dim(
      input_buffer,
      shape=(
          NUM_DEVICES,
          NUM_MICROBATCHES,
          CONTRACTING_DIM_SIZE,
          NON_CONTRACTING_DIM_SIZE,
      ),
      broadcast_dimensions=[1, 2, 3],
  )
  input_buffer = jax.device_put(
      input_buffer,
      sharding.NamedSharding(
          mesh, sharding.PartitionSpec("the_one_and_only_axis")
      ),
  )

  # Run the computation.
  output_buffer = entry_computation(weights, input_buffer, mesh)
  print(f"output_buffer = \n{output_buffer}")
```
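Note a design choice in the example above: `shift_right` and `cycle_back`
express the ring communication as a pad-then-slice on the leading (sharded)
axis, which XLA's SPMD partitioner lowers to `CollectivePermute` ops. For
comparison, here is a minimal sketch (not part of the original example; it
assumes a 4-device mesh with an axis named `"stage"`) of the same forward-edge
shift written as an explicit collective with `jax.experimental.shard_map` and
`jax.lax.ppermute`:

```
import functools

import jax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh([4]), axis_names=("stage",))


@functools.partial(
    shard_map, mesh=mesh, in_specs=P("stage"), out_specs=P("stage")
)
def shift_right_explicit(arr):
  # Send each shard to the next stage: 0->1, 1->2, 2->3. Stage 0 is not a
  # destination in `perm`, so it receives zeros, mirroring the zero-padding
  # that `shift_right` introduces with jnp.pad.
  return jax.lax.ppermute(
      arr, axis_name="stage", perm=[(i, i + 1) for i in range(3)]
  )
```

Both formulations should lower to a `CollectivePermute` on the forward edge;
the pad/slice version keeps the whole example in plain jit-of-sharded-arrays
style without an explicit mesh axis name inside the loop body.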
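To confirm that the compiled module contains the expected decomposed
collectives, one option is to inspect the optimized HLO text. This is a
minimal sketch reusing `entry_computation`, `weights`, `input_buffer`, and
`mesh` from the example above; the exact HLO op names may vary across XLA
versions:

```
# Lower and compile the jitted function, then look for asynchronous
# collective-permute pairs in the optimized HLO. Compute instructions
# scheduled between a start/done pair indicate communication/computation
# overlap.
compiled = entry_computation.lower(weights, input_buffer, mesh=mesh).compile()
hlo = compiled.as_text()
print("collective-permute-start" in hlo, "collective-permute-done" in hlo)
```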
## NCCL flags
These Nvidia NCCL flag values may be useful for single-host multi-device
computations: