Merge pull request #26564 from gspschmid:gschmid/mini_mpmd

PiperOrigin-RevId: 730912043
jax authors 2025-02-25 09:15:40 -08:00
commit a6b8384aed
10 changed files with 1555 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_library")
licenses(["notice"])
package(
default_applicable_licenses = [],
)
py_library(
name = "private_mm",
srcs = ["__init__.py"],
deps = [":private_mm_internal"],
)
py_library(
name = "private_mm_internal",
srcs = [
"example_utils.py",
"mini_dime.py",
"mm.py",
"profile_utils.py",
],
tags = ["pytype_unchecked_annotations"],
deps = [
"//jax",
],
)

View File

@@ -0,0 +1,20 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental._private_mm.mm import (
device_put as device_put,
jit as jit,
MpmdArray as MpmdArray
)
from jax.experimental._private_mm import profile_utils as profile_utils
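# Illustrative usage sketch (not executed here; `sharding` is assumed to be a
# NamedSharding constructed by the caller, possibly over devices that this
# process cannot address):
#
#   import jax.numpy as jnp
#   from jax.experimental import _private_mm as mm
#
#   x = mm.device_put(jnp.zeros((8, 8)), sharding)
#   y = mm.jit(lambda v: v + 1, out_shardings=sharding)(x)
#   if not y.is_fully_remote:    # only true on processes owning local shards
#       print(y.jax_array)       # materializes a regular jax.Array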

View File

@@ -0,0 +1,79 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A basic educational example."""
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import _private_mm as mm
from jax.experimental._private_mm.examples import launch_utils
def step():
devices = jax.devices()
mesh1 = Mesh(devices[:4], ('data',))
mesh2 = Mesh(devices[4:], ('data',))
sharding1 = NamedSharding(mesh1, P('data'))
sharding2 = NamedSharding(mesh2, P('data'))
shape = (512, 2**20)
@partial(mm.jit, in_shardings=sharding1, out_shardings=sharding1)
def stage1(x):
return x + 1
@partial(mm.jit, in_shardings=sharding2, out_shardings=sharding2)
def stage2(x):
return x * 2
a0: mm.MpmdArray = mm.device_put(jnp.zeros(shape), sharding1)
b0: mm.MpmdArray = mm.device_put(jnp.ones(shape), sharding1)
# Enqueue all work on [a]
a1 = stage1(a0)
a1 = mm.device_put(a1, sharding2)
a2 = stage2(a1)
# Enqueue all work on [b]
b1 = stage1(b0)
b1 = mm.device_put(b1, sharding2)
b2 = stage2(b1)
# Only print if a2/b2 are resident (i.e., we belong to the last stage):
if not a2.is_fully_remote:
assert not b2.is_fully_remote
print(a2.jax_array)
print(b2.jax_array)
def example_basic(num_processes, process_id):
assert jax.device_count() == 8
# FIXME: Support stages spread across multiple processes.
assert 2 % num_processes == 0
for i in range(3):
step()
if __name__ == '__main__':
import sys
num_processes = 2
if len(sys.argv) >= 2:
num_processes = int(sys.argv[1])
success = launch_utils.launch_example(num_processes, example_basic)
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,191 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example showcasing overlap on a (forward-only) PP-like workload."""
from dataclasses import dataclass
from typing import Any, Callable
import time
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import _private_mm as mm
from jax.experimental._private_mm import profile_utils
from jax.experimental._private_mm.examples import launch_utils
@dataclass(frozen=True)
class Stage:
fwd: Callable[[Any, Any], Any] # (params, acts) -> acts
mesh: Mesh
def transfer(arr, stage):
sharding = NamedSharding(stage.mesh, P()) # just replicate
return mm.device_put(arr, device=sharding)
def stages_step_fn(stages, num_mubatches, params_by_stage, xs):
# One task per mubatch and stage (e.g. forward stages only)
tasks = [
(mubatch_idx, stage_idx)
for stage_idx in range(len(stages))
for mubatch_idx in range(num_mubatches)
]
# We want to be careful with the order in which we enqueue work, since
# a single process is managing multiple devices.
# Assuming a GPipe-like schedule we traverse tasks in the following order:
#            t=0  t=1  t=2  t=3  t=4  t=5  t=6
#   stage=0    1    2    4    7
#   stage=1         3    5    8   11
#   stage=2              6    9   12   14
#   stage=3                  10   13   15   16
def task_key(task):
mubatch_idx, stage_idx = task
return (mubatch_idx + stage_idx, stage_idx)
tasks.sort(key=task_key)
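# For instance, with 2 stages and 2 microbatches task_key yields the order
# mub0/F0, mub1/F0, mub0/F1, mub1/F1: each stage's work is enqueued in
# microbatch order, and stage s's first task sorts after microbatch 0 has
# cleared stage s-1.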
input = {
(mubatch_idx, 0): xs
for mubatch_idx in range(num_mubatches)
}
for task_id in tasks:
mubatch_idx, stage_idx = task_id
stage = stages[stage_idx]
params = params_by_stage[stage_idx]
with profile_utils.annotate(
f'mub{mubatch_idx}/F{stage_idx}',
color='cyan',
):
# Invoke the stage and immediately enqueue the transfer of the
# result to the next stage. We want the transfer to be overlapped
# with subsequent computation on the same stage.
local_output = stage.fwd(params, input[task_id])
if stage_idx + 1 < len(stages):
with profile_utils.annotate(
f'Tx mub{mubatch_idx} to {stage_idx+1}',
color='yellow',
):
input[(mubatch_idx, stage_idx+1)] = transfer(
local_output,
stages[stage_idx+1],
)
return local_output
### Example usage
def example_overlap(num_processes, process_id):
assert jax.device_count() == 8
NUM_STAGES = 4
NUM_MUBATCHES = 4
# FIXME: Support stages spread across multiple processes.
assert NUM_STAGES % num_processes == 0
# Takes ~5ms/stage/microbatch on H100s:
LAYER_SIZE = 8192
# # a) Several layers per stage, little communication (32MB activations)
# NUM_LAYERS = NUM_STAGES * 16
# BATCH_SIZE = 1024
# b) One layer per stage, more communication (512MB activations)
NUM_LAYERS = NUM_STAGES
BATCH_SIZE = 1024 * 16
def mlp(params, xs):
for W in params:
xs = xs @ W
return xs
def init_params(key):
params = []
for _ in range(NUM_LAYERS):
key, key_W = jax.random.split(key)
params.append(jax.random.normal(key_W, (LAYER_SIZE, LAYER_SIZE)))
return params, key
# Two devices per stage (running fully-replicated)
num_devices_per_stage = 2
stages = []
for i in range(NUM_STAGES):
devices = jax.devices()[
num_devices_per_stage*i : num_devices_per_stage*(i+1)
]
assert all(d.process_index == devices[0].process_index for d in devices)
mesh = Mesh(np.asarray(devices), ('repl',))
jitted_fun = mm.jit(
mlp,
in_shardings=(NamedSharding(mesh, P()), NamedSharding(mesh, P())),
out_shardings=NamedSharding(mesh, P()),
)
stages.append(Stage(jitted_fun, mesh))
def step_fn(params_by_stage, xs):
return stages_step_fn(stages, NUM_MUBATCHES, params_by_stage, xs)
def shard_params_by_stage(params):
num_per_stage, rem = divmod(len(params), NUM_STAGES)
assert num_per_stage > 0
assert rem == 0
params_by_stage = [
jax.tree.map(
lambda arr: transfer(arr, stages[stage_idx]),
params[num_per_stage*stage_idx:num_per_stage*(stage_idx+1)],
)
for stage_idx in range(NUM_STAGES)
]
return params_by_stage
key = jax.random.PRNGKey(0)
params, key = init_params(key)
params_by_stage = shard_params_by_stage(params)
key, key_xs = jax.random.split(key)
xs_batch = jax.random.uniform(key_xs, (BATCH_SIZE, LAYER_SIZE))
NUM_STEPS = 50
NUM_STEPS_PROFILED = 3
for i in range(NUM_STEPS):
print(f'===== STEP {i} {process_id=} =====')
if i == 1:
# The overhead from compilations during warm-up ends up
# staggering executions on devices of the same stage. The sleep
# below allows them to catch up. In a real model collectives
# within each stage would likely have the same effect of keeping
# devices in sync.
time.sleep(0.2)
if i == NUM_STEPS - NUM_STEPS_PROFILED:
profile_utils.maybe_start_profile(f"overlap_trace/p{process_id}")
xs_batch = transfer(xs_batch, stages[0])
with profile_utils.annotate(f'step{i}', color='white'):
xs_batch = step_fn(params_by_stage, xs_batch)
if __name__ == '__main__':
import sys
num_processes = 4
if len(sys.argv) >= 2:
num_processes = int(sys.argv[1])
success = launch_utils.launch_example(num_processes, example_overlap)
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,361 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A toy model with MPMD pipeline parallelism."""
from dataclasses import dataclass
from functools import cached_property, partial
from typing import Any, Callable
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import _private_mm as mm
from jax.experimental._private_mm import profile_utils
from jax.experimental._private_mm.examples import launch_utils
LR = 0.01
@dataclass(frozen=True)
class Stage:
raw_fwd: Callable[[Any, Any, Any], Any] # (params, acts, ys) -> acts
mesh: Mesh
params_specs: Any # pytree of PartitionSpecs
def sharding(self, spec):
return NamedSharding(self.mesh, spec)
def params_shardings(self):
return jax.tree.map(self.sharding, self.params_specs)
@cached_property
def fwd(self):
raw_fwd = self.raw_fwd
@partial(
mm.jit,
in_shardings=(
self.params_shardings(),
self.sharding(P()),
self.sharding(P()),
),
out_shardings=self.sharding(P()),
)
def _fwd(params, acts, ys):
return raw_fwd(params, acts, ys)
return _fwd
@cached_property
def grad_init(self):
@partial(
mm.jit,
in_shardings=(self.params_shardings(),),
out_shardings=self.params_shardings(),
)
def _grad_init(params):
return jax.tree.map(jnp.zeros_like, params)
return _grad_init
@cached_property
def bwd_and_grad_acc(self):
raw_fwd = self.raw_fwd
@partial(
mm.jit,
in_shardings=(
self.params_shardings(),
self.sharding(P()),
self.sharding(P()),
self.params_shardings(),
self.sharding(P()),
),
out_shardings=(
self.params_shardings(),
self.sharding(P()),
),
)
def _bwd_and_grad_acc(params, fwd_activation, ys, grads_acc, activation):
with jax.named_scope('bwd'):
fwd_with_ys = lambda params, xs: raw_fwd(params, xs, ys)
_, bwd = jax.vjp(fwd_with_ys, params, fwd_activation)
grads, activation = bwd(activation)
with jax.named_scope('grad-acc'):
grads = jax.tree.map(jnp.add, grads_acc, grads)
return grads, activation
return _bwd_and_grad_acc
@cached_property
def update(self):
@partial(
mm.jit,
in_shardings=(
self.params_shardings(),
self.params_shardings(),
),
out_shardings=self.params_shardings(),
)
def _update(params, grads):
return jax.tree.map(lambda v, dv: v - dv * LR, params, grads)
return _update
def print_sharding(prefix, arr):
sharding_str = str(arr.sharding)
if hasattr(arr.sharding, 'mesh'):
mesh_str = str(arr.sharding.mesh.devices).replace('\n', ' ')
sharding_str += f' / {mesh_str}'
print(f'{prefix} {sharding_str}')
def transfer(arr, stage, spec):
sharding = stage.sharding(spec)
return mm.device_put(arr, device=sharding)
def _mpmd_constant(stage, spec, shape, value, dtype=jnp.float32):
# TODO: Better support for constants in mm (bake into jit executable?)
return transfer(jnp.full(shape, value, dtype=dtype), stage, spec)
def stages_step_fn(stages, num_mubatches, params_by_stage, xs, ys):
num_stages = len(stages)
### Schedule
tasks = [
(mubatch_idx, stage_idx, is_fwd)
for stage_idx in range(num_stages)
for mubatch_idx in range(num_mubatches)
for is_fwd in (False, True)
]
# We want to be careful with the order in which we enqueue work, since
# a single process is managing multiple devices.
# Assuming a GPipe-like schedule we traverse tasks in the following order:
#            t=0  t=1  t=2  t=3  t=4  t=5  t=6
#   stage=0    1    2    4    7
#   stage=1         3    5    8   11
#   stage=2              6    9   12   14
#   stage=3                  10   13   15   16
def task_key(task):
mubatch_idx, stage_idx, is_bwd = task
if is_bwd:
stage_idx = -stage_idx
return (is_bwd, mubatch_idx + stage_idx, stage_idx)
tasks.sort(key=task_key)
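# For instance, with 2 stages and 2 microbatches the sorted order is
# mub0/F0, mub1/F0, mub0/F1, mub1/F1, mub0/B1, mub1/B1, mub0/B0, mub1/B0:
# all forward tasks sort before all backward tasks (is_bwd is the leading
# key component), and the backward wave walks the stages in reverse since
# stage_idx is negated.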
### State
# fwd_input : (mubatch_idx, stage_idx) -> input/activation
# TODO: Actually slice the input data into separate microbatches
fwd_input = {
(mubatch_idx, 0): xs
for mubatch_idx in range(num_mubatches)
}
# bwd_input : (mubatch_idx, stage_idx) -> activation
bwd_input = {
(mubatch_idx, num_stages-1): _mpmd_constant(
stages[-1], P(), shape=(), value=1.0)
for mubatch_idx in range(num_mubatches)
}
# grads_by_stage : stage_idx -> grads
grads_by_stage = []
for stage_idx, stage in enumerate(stages):
with profile_utils.annotate(f'grad-init{stage_idx}', color='red'):
grads_by_stage.append(stage.grad_init(params_by_stage[stage_idx]))
# loss : mubatch_idx -> loss
# TODO: Add a leading mubatch dim to loss instead of making it a list
loss = [None] * num_mubatches
def maybe_ys(stage_idx):
if stage_idx == num_stages-1:
return ys
else:
return _mpmd_constant(stages[stage_idx], P(), shape=(), value=jnp.nan)
### Microbatched forward+backward
for mubatch_idx, stage_idx, is_bwd in tasks:
stage = stages[stage_idx]
params = params_by_stage[stage_idx]
fwd_bwd_str = 'B' if is_bwd else 'F'
with profile_utils.annotate(
f'mub{mubatch_idx}/{fwd_bwd_str}{stage_idx}', color='cyan'
):
curr_id = (mubatch_idx, stage_idx)
if not is_bwd:
### Forward
succ_id = (mubatch_idx, stage_idx+1)
activation = stage.fwd(
params,
fwd_input[curr_id],
maybe_ys(stage_idx),
)
if stage_idx+1 < num_stages:
with profile_utils.annotate(
f'Tx mub{mubatch_idx} to {stage_idx+1}', color='yellow',
):
fwd_input[succ_id] = transfer(
activation,
stages[stage_idx+1],
P(),
)
else:
loss[mubatch_idx] = activation
else:
### Backward
succ_id = (mubatch_idx, stage_idx-1)
grads_by_stage[stage_idx], activation = stage.bwd_and_grad_acc(
params,
fwd_input.pop(curr_id), # NB: Frees activation afterwards.
maybe_ys(stage_idx),
grads_by_stage[stage_idx],
bwd_input.pop(curr_id), # NB: Frees activation afterwards.
)
if stage_idx-1 >= 0:
with profile_utils.annotate(
f'Tx mub{mubatch_idx} to {stage_idx-1}', color='yellow',
):
bwd_input[succ_id] = transfer(
activation,
stages[stage_idx-1],
P(),
)
### Update params
for stage_idx, stage in enumerate(stages):
with profile_utils.annotate(f'U{stage_idx}', color='green'):
params_by_stage[stage_idx] = stage.update(
params_by_stage[stage_idx],
grads_by_stage[stage_idx],
)
return loss, params_by_stage
### Example usage
def example_pp(num_processes, process_id):
assert jax.device_count() == 8
NUM_STAGES = 4
NUM_MUBATCHES = 4
# FIXME: Support stages spread across multiple processes.
assert NUM_STAGES % num_processes == 0
LAYER_SIZE = 8192
# a) Several layers per stage, little communication (32MB activations)
# NUM_LAYERS = NUM_STAGES * 16
# BATCH_SIZE = 1024
# b) One layer per stage, more communication (512MB activations)
NUM_LAYERS = NUM_STAGES
BATCH_SIZE = 1024 * 16
ENABLE_TP = True
@jax.jit
def mlp(params, xs):
for WA, WB in params:
xs = xs @ WA @ WB
return xs
@jax.jit
def mse(act, ys):
return jnp.mean(jnp.square(act - ys))
def init_params(key):
params = []
for _ in range(NUM_LAYERS):
key, key_WA, key_WB = jax.random.split(key, 3)
WA = jax.random.normal(key_WA, (LAYER_SIZE, LAYER_SIZE))
WB = jax.random.normal(key_WB, (LAYER_SIZE, LAYER_SIZE))
params.append((WA, WB))
return params, key
def shard_params_by_stage(params, stages):
num_per_stage, rem = divmod(len(params), len(stages))
assert num_per_stage > 0
assert rem == 0
params_by_stage = [
jax.tree.map(
lambda arr, spec: transfer(arr, stage, spec),
params[num_per_stage*stage_idx:num_per_stage*(stage_idx+1)],
stage.params_specs,
)
for stage_idx, stage in enumerate(stages)
]
return params_by_stage
# Define stages -- two devices per stage (running fully-replicated).
num_devices_per_stage = jax.device_count() // NUM_STAGES
stages = []
for i in range(NUM_STAGES):
devices = jax.devices()[num_devices_per_stage*i : num_devices_per_stage*(i+1)]
assert all(d.process_index == devices[0].process_index for d in devices)
mesh = Mesh(np.asarray(devices), ('model',))
if i == NUM_STAGES - 1:
fwd = lambda params, xs, ys: mse(mlp(params, xs), ys)
else:
fwd = lambda params, xs, _ys: mlp(params, xs)
num_layers_per_stage = NUM_LAYERS // NUM_STAGES
if ENABLE_TP:
params_specs = [(P(None, 'model'), P('model', None))] * num_layers_per_stage
else:
params_specs = [(P(), P())] * num_layers_per_stage
stages.append(Stage(fwd, mesh, params_specs))
def step_fn(params_by_stage, xs, ys):
return stages_step_fn(stages, NUM_MUBATCHES, params_by_stage, xs, ys)
key = jax.random.PRNGKey(0)
params, key = init_params(key)
params_by_stage = shard_params_by_stage(params, stages)
# Just keep reusing one batch, so we don't have to worry about infeed.
key, key_xs = jax.random.split(key)
xs_batch = jax.random.uniform(
key_xs,
(BATCH_SIZE, LAYER_SIZE),
)
ys_batch = 7 * xs_batch
xs_batch = transfer(xs_batch, stages[0], P())
ys_batch = transfer(ys_batch, stages[-1], P())
NUM_STEPS = 50
NUM_STEPS_PROFILED = 3
for i in range(NUM_STEPS):
print(f'===== STEP {i} {process_id=} =====')
if i == NUM_STEPS - NUM_STEPS_PROFILED:
profile_utils.maybe_start_profile(f"pp_trace/p{process_id}")
with profile_utils.annotate(f'step{i}', color='white'):
loss, params_by_stage = step_fn(params_by_stage, xs_batch, ys_batch)
if __name__ == '__main__':
import sys
num_processes = 4
if len(sys.argv) >= 2:
num_processes = int(sys.argv[1])
success = launch_utils.launch_example(num_processes, example_pp)
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,106 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs some simple mm operations on varying numbers of processes."""
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import _private_mm as mm
from jax.experimental._private_mm.examples import launch_utils
def make_two_meshes():
devices = jax.devices()
num_devices = len(devices)
assert num_devices % 2 == 0
mesh1 = Mesh(devices[:num_devices//2], ('data',))
mesh2 = Mesh(devices[num_devices//2:], ('data',))
sharding1 = NamedSharding(mesh1, P('data'))
sharding2 = NamedSharding(mesh2, P('data'))
return mesh1, mesh2, sharding1, sharding2
def test_device_put_uncommitted(_num_processes, process_id):
_, _, sharding1, sharding2 = make_two_meshes()
x = mm.device_put(jnp.ones((16,16)), sharding1)
x.block_until_ready()
def test_device_put_across_meshes(_num_processes, process_id):
_, _, sharding1, sharding2 = make_two_meshes()
x = mm.device_put(jnp.ones((16,16)), sharding1)
y = mm.device_put(x, sharding2)
if y.is_fully_remote:
y.block_until_ready()
else:
np.testing.assert_array_equal(y.jax_array, jnp.ones((16,16)))
def test_jit_and_transfer(_num_processes, process_id):
_, _, sharding1, sharding2 = make_two_meshes()
x1 = mm.device_put(jnp.ones((16,16)), sharding1)
x2 = mm.jit(lambda x: x + 1, out_shardings=sharding1)(x1)
y1 = mm.device_put(x2, sharding2)
y2 = mm.jit(lambda x: x * 2, out_shardings=sharding2)(y1)
if y2.is_fully_remote:
y2.block_until_ready()
else:
np.testing.assert_array_equal(y2.jax_array, jnp.full((16,16), 4))
def run_test(num_processes, test_fun, name):
print(f' - {name} ... ', end='', flush=True)
success = launch_utils.launch_example(num_processes, test_fun)
if success:
print('OK')
else:
print('FAIL')
return success
def run_tests():
# For 1 process mm.device_puts simply reduce to jax.device_puts.
# For 2 processes and tests involving two meshes we require NCCL comms,
# but all devices of a mesh are managed by the same process.
# For 4 processes and tests involving two meshes we additionally have to
# deal with devices of a mesh being managed by multipel processes.
# (The latter currently doesn't work.)
NUM_PROCESSES = (1, 2, 4)
TESTS = [
('device_put_uncommitted', test_device_put_uncommitted),
('device_put_across_meshes', test_device_put_across_meshes),
('jit_and_transfer', test_jit_and_transfer),
]
num_failures = 0
for num_processes in NUM_PROCESSES:
print(f'=== {num_processes=} ===')
for test_name, test_fun in TESTS:
success = run_test(num_processes, test_fun, test_name)
if not success:
num_failures += 1
if num_failures == 0:
print('All tests succeeded!')
return True
else:
print(f'{num_failures} tests failed!')
return False
if __name__ == '__main__':
import sys
success = run_tests()
sys.exit(0 if success else 1)

View File

@@ -0,0 +1,95 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to launch multi-process JAX examples on a single host."""
from functools import partial
import multiprocessing
def init_multi_process_local(num_processes, process_id, num_devices):
assert 0 <= process_id < num_processes
# Assume all processes run on a single node
assert num_devices % num_processes == 0
num_devices_per_process = num_devices // num_processes
local_device_ids = [
process_id*num_devices_per_process + i
for i in range(num_devices_per_process)
]
import jax
jax.distributed.initialize(
coordinator_address="localhost:1234",
num_processes=num_processes,
process_id=process_id,
local_device_ids=local_device_ids,
)
# Needs to be a top-level function to pickle as part of multiprocessing.
def _entrypoint(num_processes, process_id, user_main, num_devices):
# Only import these in the subprocess, not the launcher process.
import jax.experimental.multihost_utils
from jax.experimental._private_mm import profile_utils
init_multi_process_local(num_processes, process_id, num_devices)
jax.experimental.multihost_utils.sync_global_devices("start_user_main")
user_main(num_processes, process_id)
profile_utils.maybe_stop_profile()
def launch_example(num_processes, user_main, num_devices=8):
"""
A launcher for examples running across multiple processes on a single node.
Returns true iff all processes exited successfully.
Example code my_example.py:
def my_example(num_processes, process_id):
# Do some distributed JAX stuff.
...
if __name__ == '__main__':
import sys
num_processes = int(sys.argv[1])
launch_utils.launch_example(num_processes, my_example)
Usage:
# Run without profiling
python3 my_example.py 4
# Run with jax.profiler + annotations
PROFILE=jax python3 my_example.py 4
# Run with nsys profiling + annotations
PROFILE=nsys nsys profile --output my_example.nsys-rep --cpuctxsw=none --trace=cublas,cuda,cudnn,cusolver,nvtx,osrt,python-gil --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop --cuda-graph-trace=node --python-sampling=true python3 my_example.py 4
"""
assert num_processes > 0
# Spawn subprocesses to avoid timeouts when profiling using nsys.
ctx = multiprocessing.get_context('spawn')
ps = [
ctx.Process(
target=partial(
_entrypoint,
num_processes,
process_id,
user_main,
num_devices,
),
name=f'example_proc{process_id}',
)
for process_id in range(num_processes)
]
for p in ps:
p.start()
for p in ps:
p.join()
return all(p.exitcode == 0 for p in ps)

View File

@@ -0,0 +1,285 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Explicit NCCL communicators (via CuPy) integrated with JAX. Based on the
communication library used in JaxPP (https://arxiv.org/abs/2412.14374).
Requires `pip install cupy-cuda12x`.
"""
import enum
import pickle
from collections import OrderedDict
from functools import cached_property
try:
import cupy # type: ignore[import-not-found]
from cupy.cuda import nccl # type: ignore[import-not-found]
# CuPy NCCL utils from https://github.com/cupy/cupy/blob/118ade4a146d1cc68519f7f661f2c145f0b942c9/cupyx/distributed/_nccl_comm.py#L46-L55
_nccl_dtypes = {
"b": nccl.NCCL_INT8,
"B": nccl.NCCL_UINT8,
"i": nccl.NCCL_INT32,
"I": nccl.NCCL_UINT32,
"l": nccl.NCCL_INT64,
"L": nccl.NCCL_UINT64,
"q": nccl.NCCL_INT64,
"Q": nccl.NCCL_UINT64,
"e": nccl.NCCL_FLOAT16,
"f": nccl.NCCL_FLOAT32,
"d": nccl.NCCL_FLOAT64,
# Size of array will be doubled
"F": nccl.NCCL_FLOAT32,
"D": nccl.NCCL_FLOAT64,
}
except ImportError:
cupy = None
nccl = None
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as xe
from jax._src import array
from jax._src.op_shardings import are_op_shardings_equal
def _get_nccl_dtype_and_count(arr, count=None):
dtype = arr.dtype.char
if dtype not in _nccl_dtypes:
raise TypeError(f"Unknown dtype {arr.dtype} for NCCL")
nccl_dtype = _nccl_dtypes[dtype]
if count is None:
count = arr.size
if dtype in "FD":
return nccl_dtype, 2 * count
return nccl_dtype, count
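# E.g. a complex64 array of 5 elements is transferred as 10 NCCL_FLOAT32
# values, hence the doubled count above.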
def get_distributed_client() -> xe.DistributedRuntimeClient:
from jax._src.distributed import global_state
assert isinstance(global_state.client, xe.DistributedRuntimeClient)
return global_state.client
class UniqueDevices(tuple[jax.Device, ...]):
def __new__(cls, *args):
return super().__new__(cls, sorted(set(args), key=lambda d: d.id))
@cached_property
def ranks(self):
return OrderedDict((d, idx) for idx, d in enumerate(self))
@property
def leader(self):
return self[0]
@cached_property
def key(self) -> str:
return ",".join(str(d.id) for d in self)
local_comms: dict = {}
def get_or_create_comm(devs: UniqueDevices):
TIMEOUT = 5_000
comm = local_comms.get(devs)
my_process_index = jax.process_index()
if comm is None:
if devs.leader.process_index == my_process_index:
nccl_id = nccl.get_unique_id()
get_distributed_client().key_value_set_bytes(
devs.key, pickle.dumps(nccl_id)
)
else:
nccl_id = get_distributed_client().blocking_key_value_get_bytes(
devs.key, TIMEOUT
)
nccl_id = pickle.loads(nccl_id)
nccl.groupStart()
for d in devs:
if d.process_index == my_process_index:
with cupy.cuda.Device(d.local_hardware_id):
comm = nccl.NcclCommunicator(len(devs), nccl_id, devs.ranks[d])
nccl.groupEnd()
local_comms[devs] = comm
return comm
local_streams: dict = {}
class OpT(enum.Enum):
SEND = 0
RECV = 1
def get_or_create_stream(op: OpT, local_device: jax.Device):
# XXX: I think this can be one stream per local_device for this specific example.
# It depends on the use case.
stream = local_streams.get((op, local_device))
if stream is None:
with cupy.cuda.Device(local_device.local_hardware_id):
stream = cupy.cuda.Stream(non_blocking=True)
local_streams[(op, local_device)] = stream
return stream
def shardings_are_compatible(
self: jax.sharding.Sharding, other: jax.sharding.Sharding, ndim: int
):
# NOTE: Variant of `jax.sharding.Sharding.is_equivalent_to` that skips _internal_device_list check
return (
are_op_shardings_equal(
self._to_xla_hlo_sharding(ndim), other._to_xla_hlo_sharding(ndim)
)
# and self._internal_device_list == other._internal_device_list # type: ignore
and self.memory_kind == other.memory_kind
)
## API
def send_or_recv(
x: jax.Array,
tgt_sharding: jax.sharding.Sharding,
src_sharding: jax.sharding.Sharding | None = None,
):
"""
When `src_sharding is None` this function corresponds to a send: `x` holds
the local data to be sent and `tgt_sharding` describes the receiving devices
(it must be compatible with `x.sharding`).
When `src_sharding is not None` this function corresponds to a receive: `x`
is the pre-allocated receive buffer (so `x.sharding` must equal
`tgt_sharding`) and `x` will be consumed, i.e. it's unsafe to use `x` after
`send_or_recv(x, src_sharding=...)`.
`x` can be a "global" array spanning multiple processes/hosts. In that case,
this process will send/receive only its corresponding addressable_shards.
(See the illustrative usage sketch at the end of this file.)
"""
if src_sharding is None:
is_send = True
other_sharding = tgt_sharding
else:
is_send = False
other_sharding = src_sharding
if not is_send:
# XXX: x.sharding and tgt_sharding must be equal since this is a recv.
# This seems redundant to me. Not sure what the final version from Skye
# will look like.
assert x.sharding == tgt_sharding
# TODO: implement reshard for 4 devs -> 2 devs or 2->4 reshards
assert shardings_are_compatible(x.sharding, other_sharding, x.ndim), \
f'incompatible shardings: {x.sharding=} vs {other_sharding=}'
# Create communicators lazily as needed. This can be a separate "setup function"
for pair in zip(
x.sharding._device_assignment,
other_sharding._device_assignment,
strict=True,
):
if pair[0].process_index == jax.process_index():
_ = get_or_create_comm(UniqueDevices(*pair))
shards_by_device = {shard.device: shard for shard in x.addressable_shards}
cpy_arrays_and_streams = []
# FIXME: maybe narrow `nccl_group_{start,end}` scope by first accumulating
# arguments in a list and then performing the operation
nccl.groupStart()
for x_device, other_device in zip(
x.sharding._device_assignment,
other_sharding._device_assignment,
strict=True,
):
if x_device.process_index == jax.process_index():
shard = shards_by_device[x_device]
stream = get_or_create_stream(OpT.SEND if is_send else OpT.RECV, x_device)
# FIXME: cupy doesn't support bf16. Use capsule/ctype APIs
cpy_arr = cupy.from_dlpack(
jax.dlpack.to_dlpack(shard.data, stream=stream.ptr)
)
cpy_arrays_and_streams.append((cpy_arr, stream))
nccl_dtype, count = _get_nccl_dtype_and_count(cpy_arr)
key = UniqueDevices(x_device, other_device)
comm = get_or_create_comm(key)
with cupy.cuda.Device(x_device.local_hardware_id):
op = comm.send if is_send else comm.recv
op(
cpy_arr.data.ptr,
count,
nccl_dtype,
key.ranks[other_device],
stream.ptr,
)
nccl.groupEnd()
# NOTE: since communicators are blocking, after the group_end operation
# above, all the send/recvs have been enqueued into the stream. Therefore,
# we can record events on the stream
# XXX: I don't like different return types below, however I am not sure
# what's a better alternative given we want a "symmetric"
# `send_or_recv` API
if is_send:
def wait():
for _, stream in cpy_arrays_and_streams:
stream.synchronize()
# NOTE: Keep the objects below alive so that they are not
# deleted/overwritten by XLA while still in use.
return (x, cpy_arrays_and_streams)
return wait
else:
def enqueue_wait():
jax_single_arrays = []
for x_device, (cpy_arr, stream) in zip(
x.sharding._device_assignment, cpy_arrays_and_streams, strict=True
):
with cupy.cuda.Device(x_device.local_hardware_id):
event = stream.record()
ready_events_stream = (
x_device.get_stream_for_external_ready_events()
)
cupy.cuda.ExternalStream(ready_events_stream).wait_event(event)
jax_sda = jnp.array(
jax._src.lib.xla_client._xla.dlpack_managed_tensor_to_buffer(
cpy_arr.toDlpack(),
x_device,
ready_events_stream,
),
copy=True, # XXX: Just to be safe
)
jax_single_arrays.append(jax_sda)
return array.ArrayImpl(
x.aval,
x.sharding,
jax_single_arrays,
committed=True,
# NOTE: _skip_checks could be set to True, but since this happens
# asynchronously there's no perf harm in keeping it False.
_skip_checks=False,
)
return enqueue_wait
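# Illustrative usage sketch (a minimal, untested outline; `x`, `shape`,
# `dtype`, `src_sharding` and `dst_sharding` are placeholders -- two
# compatible shardings whose devices live on different processes):
#
#   # Sending process: `x` already lives on src_sharding's devices.
#   wait_send = send_or_recv(x, dst_sharding)
#   keepalive = wait_send()   # drains the send streams, keeps buffers alive
#
#   # Receiving process: pre-allocate a buffer with the target sharding and
#   # pass the peer's sharding as `src_sharding`.
#   buf = jax.jit(lambda: jnp.zeros(shape, dtype), out_shardings=dst_sharding)()
#   enqueue = send_or_recv(buf, dst_sharding, src_sharding)
#   y = enqueue()             # a committed jax.Array laid out as dst_sharding
#
# mm.device_put (in mm.py) drives send_or_recv in essentially this way.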

View File

@@ -0,0 +1,300 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Poor-man's MPMD for JAX."""
from dataclasses import dataclass
from functools import cached_property, lru_cache, partial, wraps
from typing import Callable
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Sharding, SingleDeviceSharding
from jax._src.tree_util import broadcast_prefix, prefix_errors, tree_leaves_with_path
from jax.experimental._private_mm import mini_dime
@dataclass
class MpmdArray:
"""A generalization of jax.Array that also supports fully remote arrays."""
aval: jax.core.ShapedArray
sharding: Sharding
_complete: Callable[[], jax.Array | tuple] | None
_result: jax.Array | tuple | None = None
def __repr__(self):
remote_str = ', fully-remote' if self.is_fully_remote else ''
return (
f'MpmdArray({self.aval}, sharding={self.sharding}, '
f'devices={self.sharding.mesh.devices}{remote_str})'
)
def block_until_ready(self):
if self._complete is None:
# Already awaited.
assert self._result is not None
return self
result = self._complete()
if isinstance(result, jax.Array):
# Recv result, store array.
self._result = result
else:
# No-op result or send result. Drop objects kept alive, but register
# completion.
self._result = ()
# Drop the closure.
self._complete = None
return self
@cached_property
def is_fully_remote(self):
return is_fully_remote_sharding(self.sharding)
@property
def jax_array(self):
if self.is_fully_remote:
raise ValueError('cannot convert fully-remote MpmdArray to jax.Array')
self.block_until_ready()
assert isinstance(self._result, jax.Array), (
'expected non-fully-remote MpmdArray to hold some local data, but got: '
f'{self._result} (mesh devices: {self.sharding.mesh.devices})'
)
return self._result
@property
def shape(self):
return self.aval.shape
@property
def dtype(self):
return self.aval.dtype
JaxOrMpmdArray = jax.Array | MpmdArray
def is_local_device(device) -> bool:
return device.process_index == jax.process_index()
def is_fully_remote_sharding(sharding: Sharding) -> bool:
# TODO: Handle shardings other than NamedSharding?
assert isinstance(sharding, NamedSharding)
return not any(map(is_local_device, sharding.mesh.devices.flat))
def is_fully_local_sharding(sharding: Sharding) -> bool:
# TODO: Handle shardings other than NamedSharding?
assert isinstance(sharding, NamedSharding)
return all(map(is_local_device, sharding.mesh.devices.flat))
def is_fully_remote_array(arr: JaxOrMpmdArray) -> bool:
return isinstance(arr, MpmdArray) and arr.is_fully_remote
def as_jax_array(arr: JaxOrMpmdArray) -> jax.Array:
if isinstance(arr, MpmdArray):
return arr.jax_array
assert isinstance(arr, jax.Array)
return arr
def fix_sharding(sharding: Sharding) -> Sharding:
# FIXME: During jax.device_put(..., sharding) jaxlib/XLA fills in a memory
# kind if none was explicitly given. We don't always call into
# jax.device_put here, but we want to mirror this behavior so that even
# processes that don't call jax.device_put end up with the exact same
# metadata. (The bandaid below is likely incomplete.)
if sharding.memory_kind is None:
sharding = sharding.with_memory_kind('device')
return sharding
@lru_cache
def recv_buf_factory(shape, dtype, tgt_sharding):
@partial(jax.jit, out_shardings=tgt_sharding)
def recv_buf_init():
return jnp.zeros(shape, dtype)
return recv_buf_init
# TODO: Generalize mm.device_put to mix jax.device_put, send and recv as
# needed. For the moment, we only allow cases that fall neatly into one of
# those three categories, i.e. the present process either issues a
# jax.device_put, an NCCL send, or an NCCL recv. This means that every
# submesh (e.g. a stage) needs to be managed by a single process for now.
def device_put(arr: JaxOrMpmdArray, device: Sharding) -> MpmdArray:
assert isinstance(device, Sharding)
tgt_sharding = fix_sharding(device)
src_sharding = fix_sharding(arr.sharding)
def complete_with(complete):
return MpmdArray(
aval=arr.aval,
sharding=tgt_sharding,
_complete=complete,
)
if is_fully_remote_array(arr):
if is_fully_remote_sharding(tgt_sharding):
# FullyRemote->FullyRemote: Nothing to be done.
return complete_with(lambda: ())
else:
# FullyRemote->NonFullyRemote: Recv.
# NOTE: We run the same jitted fun on each participating device,
# rather than jax.device_put(jnp.zeros(...), tgt_sharding). The
# latter produces jnp.zeros first on one local device and then P2P-
# copies to the others, which anecdotally appears to be slower, but
# also litters the profile, so we avoid it.
recv_buf = recv_buf_factory(
arr.aval.shape,
arr.aval.dtype,
tgt_sharding,
)()
return complete_with(
mini_dime.send_or_recv(
recv_buf,
tgt_sharding,
src_sharding,
)
)
# arr has some locally-addressable shards.
jax_array = as_jax_array(arr)
if jax_array.committed:
if is_fully_remote_sharding(tgt_sharding):
# NonFullyRemote->FullyRemote: Send.
# FIXME: Should force completion at some point.
return complete_with(
mini_dime.send_or_recv(
jax_array,
tgt_sharding,
)
)
elif (
is_fully_local_sharding(src_sharding) and
is_fully_local_sharding(tgt_sharding)
):
# NonFullyRemote->NonFullyRemote: jax.device_put
new_jax_array = jax.device_put(jax_array, tgt_sharding)
return complete_with(lambda: new_jax_array)
else:
# NOTE: We exclude cases of NonFullyRemote -> NonFullyRemote
# which would require a mix of jax.device_put, Send and Recv.
raise NotImplementedError('unsupported transfer')
else:
# Uncommitted array.
assert isinstance(jax_array.sharding, SingleDeviceSharding)
if is_fully_remote_sharding(tgt_sharding):
# Uncommitted->FullyRemote: Nothing to be done
return complete_with(lambda: ())
else:
# Uncommitted->NonFullyRemote: jax.device_put
# NOTE: Uncommitted arrays arise when the user hasn't yet specified
# a device or sharding, so the current (single-device) sharding is
# somewhat arbitrary.
# An important assumption here is that, though said device will vary
# from process to process, we expect all of the processes to have
# the same values.
#
# Now we'd like to do something like
# new_jax_array = jax.device_put(jax_array, tgt_sharding)
# where we'd expect jax.device_put to simply transfer from
# the current local single device to all the other relevant local
# devices.
#
# This unfortunately doesn't work, because jax.device_put will check
# the above assumption of same-values-everywhere by introducing a
# broadcast from process 0 to all others. But in an MPMD program
# only a subset of processes will participate in any given
# device_put, so this might lead to hangs!
#
# We could likely work around this by doing appropriate device_puts
# with single-device shardings and subsequently using
# jax.make_array_from_single_device_arrays to build a global array.
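# A rough sketch of that workaround (untested; shown only to make the idea
# concrete, relying on the standard pattern around
# jax.make_array_from_single_device_arrays):
#
#   idx_map = tgt_sharding.addressable_devices_indices_map(jax_array.shape)
#   shards = [
#       jax.device_put(jax_array[idx], d) for d, idx in idx_map.items()
#   ]
#   new_jax_array = jax.make_array_from_single_device_arrays(
#       jax_array.shape, tgt_sharding, shards)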
if not is_fully_local_sharding(tgt_sharding):
raise NotImplementedError('unsupported transfer')
new_jax_array = jax.device_put(jax_array, tgt_sharding)
return complete_with(lambda: new_jax_array)
def jit(*args, **kwargs):
if (out_shardings := kwargs.get('out_shardings')) is None:
raise ValueError('missing out_shardings')
fun = jax.jit(*args, **kwargs)
@wraps(fun)
def wrapped(*in_vals):
first_fully_remote_input = next(
(
(path, in_val)
for path, in_val in tree_leaves_with_path(in_vals)
if is_fully_remote_array(in_val)
),
None,
)
# This computation does not concern us; return fully-remote arrays.
if first_fully_remote_input is not None:
out_shape_dtypes = jax.eval_shape(fun, *in_vals)
# Allow out_shardings to be a prefix tree
try:
out_shardings_flat = broadcast_prefix(
out_shardings,
out_shape_dtypes,
is_leaf=lambda x: x is None, # FIXME: Correct?
)
except ValueError:
e, *_ = prefix_errors(out_shardings, out_shape_dtypes)
raise e('mm.jit out_shardings') from None
out_shardings_full = jax.tree.unflatten(
jax.tree.structure(out_shape_dtypes),
out_shardings_flat,
)
# Make an MpmdArray for every out value
def make_fully_remote_output(shape_dtype, sharding):
if not is_fully_remote_sharding(sharding):
path, in_val = first_fully_remote_input
raise ValueError(
'mm.jit produces a non-fully-remote output, but '
f'was invoked on fully-remote input: {in_val} @ {path}')
return MpmdArray(
aval=jax.core.ShapedArray(
shape_dtype.shape,
shape_dtype.dtype,
),
sharding=sharding,
_complete=lambda: (),
)
return jax.tree.map(
make_fully_remote_output,
out_shape_dtypes,
out_shardings_full,
)
# This computation concerns us; run the jax.jit-ed function.
in_vals = jax.tree.map(as_jax_array, in_vals)
out_vals = fun(*in_vals)
return jax.tree.map(
lambda jax_array: MpmdArray(
jax_array.aval,
jax_array.sharding,
lambda: jax_array,
),
out_vals,
)
return wrapped
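# For example (hypothetical two-mesh setup, mirroring the examples/):
#   y = jit(lambda v: v * 2, out_shardings=sharding2)(x)
# On a process where `x` is fully remote this takes the fully-remote branch
# above, returning a fully-remote MpmdArray of the right shape/dtype/sharding
# without compiling or executing anything locally; on processes holding local
# shards of `x` it simply runs the underlying jax.jit-ed function.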

View File

@@ -0,0 +1,77 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for profiling, abstracting over jax.profiler and nsys."""
from contextlib import contextmanager
import os
import jax
def get_profiling_mode() -> str | None:
mode = os.environ.get('PROFILE')
if mode is not None:
mode = mode.lower()
assert mode in ('jax', 'nsys')
return mode
return None
if get_profiling_mode() == 'nsys':
from ctypes import cdll
libcudart = cdll.LoadLibrary('libcudart.so')
import nvtx # type: ignore[import-not-found]
def maybe_start_profile(path):
profiling_mode = get_profiling_mode()
if profiling_mode is None:
pass
elif profiling_mode == 'jax':
jax.profiler.start_trace(path)
elif profiling_mode == 'nsys':
libcudart.cudaProfilerStart()
else:
assert False
def maybe_stop_profile():
profiling_mode = get_profiling_mode()
if profiling_mode is None:
pass
elif profiling_mode == 'jax':
try:
jax.profiler.stop_trace()
except RuntimeError as e:
if e.args[0] != 'No profile started':
raise
elif profiling_mode == 'nsys':
libcudart.cudaProfilerStop()
else:
assert False
@contextmanager
def annotate(label, color=None):
profiling_mode = get_profiling_mode()
if profiling_mode is None:
yield
elif profiling_mode == 'jax':
with jax.profiler.TraceAnnotation(label):
yield
elif profiling_mode == 'nsys':
with nvtx.annotate(label, color=color or 'red'):
yield
else:
assert False
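# Illustrative usage (assuming the process was launched with PROFILE=jax or
# PROFILE=nsys in its environment; with PROFILE unset these are all no-ops):
#
#   maybe_start_profile('/tmp/mm_trace/p0')
#   with annotate('step0', color='white'):
#       ...  # run a training step
#   maybe_stop_profile()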