Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Merge pull request #26564 from gspschmid:gschmid/mini_mpmd

PiperOrigin-RevId: 730912043

Commit a6b8384aed
jax/experimental/_private_mm/BUILD (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_library")

licenses(["notice"])

package(
    default_applicable_licenses = [],
)

py_library(
    name = "private_mm",
    srcs = ["__init__.py"],
    deps = [":private_mm_internal"],
)

py_library(
    name = "private_mm_internal",
    srcs = [
        "example_utils.py",
        "mini_dime.py",
        "mm.py",
        "profile_utils.py",
    ],
    tags = ["pytype_unchecked_annotations"],
    deps = [
        "//jax",
    ],
)
jax/experimental/_private_mm/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax.experimental._private_mm.mm import (
    device_put as device_put,
    jit as jit,
    MpmdArray as MpmdArray,
)
from jax.experimental._private_mm import profile_utils as profile_utils
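The package surface re-exported here is just `device_put`, `jit`, and `MpmdArray` (plus `profile_utils`). A minimal sketch of how they compose, condensed from `examples/example_basic.py` below; the 8-device, two-mesh split is that example's assumption, not a requirement of the API:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental import _private_mm as mm

    # Two disjoint meshes, assuming 8 visible devices (as in example_basic.py).
    devices = jax.devices()
    s1 = NamedSharding(Mesh(devices[:4], ('data',)), P('data'))
    s2 = NamedSharding(Mesh(devices[4:], ('data',)), P('data'))

    x = mm.device_put(jnp.ones((512, 1024)), s1)       # MpmdArray on mesh 1
    y = mm.jit(lambda v: v + 1, out_shardings=s1)(x)   # executes only where mesh 1 is local
    y = mm.device_put(y, s2)                           # cross-mesh transfer
    z = mm.jit(lambda v: v * 2, out_shardings=s2)(y)
    if not z.is_fully_remote:                          # only the owning processes hold data
        print(z.jax_array)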
jax/experimental/_private_mm/examples/example_basic.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A basic educational example."""

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from jax.experimental import _private_mm as mm
from jax.experimental._private_mm.examples import launch_utils


def step():
    devices = jax.devices()
    mesh1 = Mesh(devices[:4], ('data',))
    mesh2 = Mesh(devices[4:], ('data',))

    sharding1 = NamedSharding(mesh1, P('data'))
    sharding2 = NamedSharding(mesh2, P('data'))

    shape = (512, 2**20)

    @partial(mm.jit, in_shardings=sharding1, out_shardings=sharding1)
    def stage1(x):
        return x + 1

    @partial(mm.jit, in_shardings=sharding2, out_shardings=sharding2)
    def stage2(x):
        return x * 2

    a0: mm.MpmdArray = mm.device_put(jnp.zeros(shape), sharding1)
    b0: mm.MpmdArray = mm.device_put(jnp.ones(shape), sharding1)

    # Enqueue all work on [a]
    a1 = stage1(a0)
    a1 = mm.device_put(a1, sharding2)
    a2 = stage2(a1)

    # Enqueue all work on [b]
    b1 = stage1(b0)
    b1 = mm.device_put(b1, sharding2)
    b2 = stage2(b1)

    # Only print if a2/b2 are resident (i.e., we belong to the last stage):
    if not a2.is_fully_remote:
        assert not b2.is_fully_remote
        print(a2.jax_array)
        print(b2.jax_array)


def example_basic(num_processes, process_id):
    assert jax.device_count() == 8
    # FIXME: Support stages spread across multiple processes.
    assert 2 % num_processes == 0

    for i in range(3):
        step()


if __name__ == '__main__':
    import sys
    num_processes = 2
    if len(sys.argv) >= 2:
        num_processes = int(sys.argv[1])
    success = launch_utils.launch_example(num_processes, example_basic)
    sys.exit(0 if success else 1)
jax/experimental/_private_mm/examples/example_overlap.py (new file, 191 lines)
@@ -0,0 +1,191 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An example showcasing overlap on a (forward-only) PP-like workload."""

from dataclasses import dataclass
from typing import Any, Callable
import time

import numpy as np

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from jax.experimental import _private_mm as mm
from jax.experimental._private_mm import profile_utils
from jax.experimental._private_mm.examples import launch_utils


@dataclass(frozen=True)
class Stage:
    fwd: Callable[[Any, Any], Any]  # (params, acts) -> acts
    mesh: Mesh


def transfer(arr, stage):
    sharding = NamedSharding(stage.mesh, P())  # just replicate
    return mm.device_put(arr, device=sharding)


def stages_step_fn(stages, num_mubatches, params_by_stage, xs):
    # One task per mubatch and stage (e.g. forward stages only)
    tasks = [
        (mubatch_idx, stage_idx)
        for stage_idx in range(len(stages))
        for mubatch_idx in range(num_mubatches)
    ]
    # We want to be careful with the order in which we enqueue work, since
    # a single process is managing multiple devices.
    # Assuming a GPipe-like schedule we traverse tasks in the following order:
    #             t=0  t=1  t=2  t=3  t=4  t=5  t=6
    #   stage=0     1    2    4    7
    #   stage=1          3    5    8   11
    #   stage=2               6    9   12   14
    #   stage=3                   10   13   15   16
    def task_key(task):
        mubatch_idx, stage_idx = task
        return (mubatch_idx + stage_idx, stage_idx)
    tasks.sort(key=task_key)

    input = {
        (mubatch_idx, 0): xs
        for mubatch_idx in range(num_mubatches)
    }

    for task_id in tasks:
        mubatch_idx, stage_idx = task_id
        stage = stages[stage_idx]
        params = params_by_stage[stage_idx]
        with profile_utils.annotate(
            f'mub{mubatch_idx}/F{stage_idx}',
            color='cyan',
        ):
            # Invoke the stage and immediately enqueue the transfer of the
            # result to the next stage. We want the transfer to be overlapped
            # with subsequent computation on the same stage.
            local_output = stage.fwd(params, input[task_id])
            if stage_idx + 1 < len(stages):
                with profile_utils.annotate(
                    f'Tx mub{mubatch_idx} to {stage_idx+1}',
                    color='yellow',
                ):
                    input[(mubatch_idx, stage_idx+1)] = transfer(
                        local_output,
                        stages[stage_idx+1],
                    )

    return local_output


### Example usage

def example_overlap(num_processes, process_id):
    assert jax.device_count() == 8

    NUM_STAGES = 4
    NUM_MUBATCHES = 4

    # FIXME: Support stages spread across multiple processes.
    assert NUM_STAGES % num_processes == 0

    # Takes ~5ms/stage/microbatch on H100s:
    LAYER_SIZE = 8192
    # # a) Several layers per stage, little communication (32MB activations)
    # NUM_LAYERS = NUM_STAGES * 16
    # BATCH_SIZE = 1024
    # b) One layer per stage, more communication (512MB activations)
    NUM_LAYERS = NUM_STAGES
    BATCH_SIZE = 1024 * 16

    def mlp(params, xs):
        for W in params:
            xs = xs @ W
        return xs

    def init_params(key):
        params = []
        for _ in range(NUM_LAYERS):
            key, key_W = jax.random.split(key)
            params.append(jax.random.normal(key_W, (LAYER_SIZE, LAYER_SIZE)))
        return params, key

    # Two devices per stage (running fully-replicated)
    num_devices_per_stage = 2
    stages = []
    for i in range(NUM_STAGES):
        devices = jax.devices()[
            num_devices_per_stage*i : num_devices_per_stage*(i+1)
        ]
        assert all(d.process_index == devices[0].process_index for d in devices)
        mesh = Mesh(np.asarray(devices), ('repl',))
        jitted_fun = mm.jit(
            mlp,
            in_shardings=(NamedSharding(mesh, P()), NamedSharding(mesh, P())),
            out_shardings=NamedSharding(mesh, P()),
        )
        stages.append(Stage(jitted_fun, mesh))

    def step_fn(params_by_stage, xs):
        return stages_step_fn(stages, NUM_MUBATCHES, params_by_stage, xs)

    def shard_params_by_stage(params):
        num_per_stage, rem = divmod(len(params), NUM_STAGES)
        assert num_per_stage > 0
        assert rem == 0
        params_by_stage = [
            jax.tree.map(
                lambda arr: transfer(arr, stages[stage_idx]),
                params[num_per_stage*stage_idx:num_per_stage*(stage_idx+1)],
            )
            for stage_idx in range(NUM_STAGES)
        ]
        return params_by_stage

    key = jax.random.PRNGKey(0)
    params, key = init_params(key)
    params_by_stage = shard_params_by_stage(params)

    key, key_xs = jax.random.split(key)
    xs_batch = jax.random.uniform(key_xs, (BATCH_SIZE, LAYER_SIZE))

    NUM_STEPS = 50
    NUM_STEPS_PROFILED = 3
    for i in range(NUM_STEPS):
        print(f'===== STEP {i} {process_id=} =====')
        if i == 1:
            # The overhead from compilations during warm-up ends up
            # staggering executions on devices of the same stage. The sleep
            # below allows them to catch up. In a real model collectives
            # within each stage would likely have the same effect of keeping
            # devices in sync.
            time.sleep(0.2)
        if i == NUM_STEPS - NUM_STEPS_PROFILED:
            profile_utils.maybe_start_profile(f"overlap_trace/p{process_id}")

        xs_batch = transfer(xs_batch, stages[0])
        with profile_utils.annotate(f'step{i}', color='white'):
            xs_batch = step_fn(params_by_stage, xs_batch)


if __name__ == '__main__':
    import sys
    num_processes = 4
    if len(sys.argv) >= 2:
        num_processes = int(sys.argv[1])
    success = launch_utils.launch_example(num_processes, example_overlap)
    sys.exit(0 if success else 1)
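The enqueue-order table in `stages_step_fn` above follows directly from `task_key`. A small standalone sketch (4 stages and 4 microbatches assumed, as in the example) that reproduces the same numbering:

    # Sketch: reproduce the enqueue order documented in stages_step_fn's comment.
    num_stages, num_mubatches = 4, 4
    tasks = [(m, s) for s in range(num_stages) for m in range(num_mubatches)]
    tasks.sort(key=lambda t: (t[0] + t[1], t[1]))  # (wavefront, stage)
    for i, (m, s) in enumerate(tasks, start=1):
        print(f'{i:2d}: stage={s} mubatch={m} (earliest start t={m + s})')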
jax/experimental/_private_mm/examples/example_pp.py (new file, 361 lines)
@@ -0,0 +1,361 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A toy model with MPMD pipeline parallelism."""

from dataclasses import dataclass
from functools import cached_property, partial
from typing import Any, Callable

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from jax.experimental import _private_mm as mm
from jax.experimental._private_mm import profile_utils
from jax.experimental._private_mm.examples import launch_utils


LR = 0.01


@dataclass(frozen=True)
class Stage:
    raw_fwd: Callable[[Any, Any, Any], Any]  # (params, acts, ys) -> acts
    mesh: Mesh
    params_specs: Any  # pytree of PartitionSpecs

    def sharding(self, spec):
        return NamedSharding(self.mesh, spec)

    def params_shardings(self):
        return jax.tree.map(self.sharding, self.params_specs)

    @cached_property
    def fwd(self):
        raw_fwd = self.raw_fwd

        @partial(
            mm.jit,
            in_shardings=(
                self.params_shardings(),
                self.sharding(P()),
                self.sharding(P()),
            ),
            out_shardings=self.sharding(P()),
        )
        def _fwd(params, acts, ys):
            return raw_fwd(params, acts, ys)

        return _fwd

    @cached_property
    def grad_init(self):
        @partial(
            mm.jit,
            in_shardings=(self.params_shardings(),),
            out_shardings=self.params_shardings(),
        )
        def _grad_init(params):
            return jax.tree.map(jnp.zeros_like, params)

        return _grad_init

    @cached_property
    def bwd_and_grad_acc(self):
        raw_fwd = self.raw_fwd

        @partial(
            mm.jit,
            in_shardings=(
                self.params_shardings(),
                self.sharding(P()),
                self.sharding(P()),
                self.params_shardings(),
                self.sharding(P()),
            ),
            out_shardings=(
                self.params_shardings(),
                self.sharding(P()),
            ),
        )
        def _bwd_and_grad_acc(params, fwd_activation, ys, grads_acc, activation):
            with jax.named_scope('bwd'):
                fwd_with_ys = lambda params, xs: raw_fwd(params, xs, ys)
                _, bwd = jax.vjp(fwd_with_ys, params, fwd_activation)
                grads, activation = bwd(activation)
            with jax.named_scope('grad-acc'):
                grads = jax.tree.map(jnp.add, grads_acc, grads)
            return grads, activation

        return _bwd_and_grad_acc

    @cached_property
    def update(self):
        @partial(
            mm.jit,
            in_shardings=(
                self.params_shardings(),
                self.params_shardings(),
            ),
            out_shardings=self.params_shardings(),
        )
        def _update(params, grads):
            return jax.tree.map(lambda v, dv: v - dv * LR, params, grads)

        return _update


def print_sharding(prefix, arr):
    sharding_str = str(arr.sharding)
    if hasattr(arr.sharding, 'mesh'):
        mesh_str = str(arr.sharding.mesh.devices).replace('\n', ' ')
        sharding_str += f' / {mesh_str}'
    print(f'{prefix} {sharding_str}')


def transfer(arr, stage, spec):
    sharding = stage.sharding(spec)
    return mm.device_put(arr, device=sharding)


def _mpmd_constant(stage, spec, shape, value, dtype=jnp.float32):
    # TODO: Better support for constants in mm (bake into jit executable?)
    return transfer(jnp.full(shape, value, dtype=dtype), stage, spec)


def stages_step_fn(stages, num_mubatches, params_by_stage, xs, ys):
    num_stages = len(stages)

    ### Schedule
    tasks = [
        (mubatch_idx, stage_idx, is_bwd)
        for stage_idx in range(num_stages)
        for mubatch_idx in range(num_mubatches)
        for is_bwd in (False, True)
    ]
    # We want to be careful with the order in which we enqueue work, since
    # a single process is managing multiple devices.
    # Assuming a GPipe-like schedule we traverse tasks in the following order:
    #             t=0  t=1  t=2  t=3  t=4  t=5  t=6
    #   stage=0     1    2    4    7
    #   stage=1          3    5    8   11
    #   stage=2               6    9   12   14
    #   stage=3                   10   13   15   16
    def task_key(task):
        mubatch_idx, stage_idx, is_bwd = task
        if is_bwd:
            stage_idx = -stage_idx
        return (is_bwd, mubatch_idx + stage_idx, stage_idx)
    tasks.sort(key=task_key)

    ### State
    # fwd_input : (mubatch_idx, stage_idx) -> input/activation
    # TODO: Actually slice the input data into separate microbatches
    fwd_input = {
        (mubatch_idx, 0): xs
        for mubatch_idx in range(num_mubatches)
    }
    # bwd_input : (mubatch_idx, stage_idx) -> activation
    bwd_input = {
        (mubatch_idx, num_stages-1): _mpmd_constant(
            stages[-1], P(), shape=(), value=1.0)
        for mubatch_idx in range(num_mubatches)
    }
    # grads_by_stage : stage_idx -> grads
    grads_by_stage = []
    for stage_idx, stage in enumerate(stages):
        with profile_utils.annotate(f'grad-init{stage_idx}', color='red'):
            grads_by_stage.append(stage.grad_init(params_by_stage[stage_idx]))
    # loss : mubatch_idx -> loss
    # TODO: Add a leading mubatch dim to loss instead of making it a list
    loss = [None] * num_mubatches

    def maybe_ys(stage_idx):
        if stage_idx == num_stages-1:
            return ys
        else:
            return _mpmd_constant(stages[stage_idx], P(), shape=(), value=jnp.nan)

    ### Microbatched forward+backward
    for mubatch_idx, stage_idx, is_bwd in tasks:
        stage = stages[stage_idx]
        params = params_by_stage[stage_idx]
        fwd_bwd_str = 'B' if is_bwd else 'F'
        with profile_utils.annotate(
            f'mub{mubatch_idx}/{fwd_bwd_str}{stage_idx}', color='cyan'
        ):
            curr_id = (mubatch_idx, stage_idx)
            if not is_bwd:
                ### Forward
                succ_id = (mubatch_idx, stage_idx+1)
                activation = stage.fwd(
                    params,
                    fwd_input[curr_id],
                    maybe_ys(stage_idx),
                )
                if stage_idx+1 < num_stages:
                    with profile_utils.annotate(
                        f'Tx mub{mubatch_idx} to {stage_idx+1}', color='yellow',
                    ):
                        fwd_input[succ_id] = transfer(
                            activation,
                            stages[stage_idx+1],
                            P(),
                        )
                else:
                    loss[mubatch_idx] = activation
            else:
                ### Backward
                succ_id = (mubatch_idx, stage_idx-1)
                grads_by_stage[stage_idx], activation = stage.bwd_and_grad_acc(
                    params,
                    fwd_input.pop(curr_id),  # NB: Frees activation afterwards.
                    maybe_ys(stage_idx),
                    grads_by_stage[stage_idx],
                    bwd_input.pop(curr_id),  # NB: Frees activation afterwards.
                )
                if stage_idx-1 >= 0:
                    with profile_utils.annotate(
                        f'Tx mub{mubatch_idx} to {stage_idx-1}', color='yellow',
                    ):
                        bwd_input[succ_id] = transfer(
                            activation,
                            stages[stage_idx-1],
                            P(),
                        )

    ### Update params
    for stage_idx, stage in enumerate(stages):
        with profile_utils.annotate(f'U{stage_idx}', color='green'):
            params_by_stage[stage_idx] = stage.update(
                params_by_stage[stage_idx],
                grads_by_stage[stage_idx],
            )

    return loss, params_by_stage


### Example usage

def example_pp(num_processes, process_id):
    assert jax.device_count() == 8

    NUM_STAGES = 4
    NUM_MUBATCHES = 4

    # FIXME: Support stages spread across multiple processes.
    assert NUM_STAGES % num_processes == 0

    LAYER_SIZE = 8192
    # a) Several layers per stage, little communication (32MB activations)
    # NUM_LAYERS = NUM_STAGES * 16
    # BATCH_SIZE = 1024
    # b) One layer per stage, more communication (512MB activations)
    NUM_LAYERS = NUM_STAGES
    BATCH_SIZE = 1024 * 16

    ENABLE_TP = True

    @jax.jit
    def mlp(params, xs):
        for WA, WB in params:
            xs = xs @ WA @ WB
        return xs

    @jax.jit
    def mse(act, ys):
        return jnp.mean(jnp.square(act - ys))

    def init_params(key):
        params = []
        for _ in range(NUM_LAYERS):
            key, key_WA, key_WB = jax.random.split(key, 3)
            WA = jax.random.normal(key_WA, (LAYER_SIZE, LAYER_SIZE))
            WB = jax.random.normal(key_WB, (LAYER_SIZE, LAYER_SIZE))
            params.append((WA, WB))
        return params, key

    def shard_params_by_stage(params, stages):
        num_per_stage, rem = divmod(len(params), len(stages))
        assert num_per_stage > 0
        assert rem == 0
        params_by_stage = [
            jax.tree.map(
                lambda arr, spec: transfer(arr, stage, spec),
                params[num_per_stage*stage_idx:num_per_stage*(stage_idx+1)],
                stage.params_specs,
            )
            for stage_idx, stage in enumerate(stages)
        ]
        return params_by_stage

    # Define stages -- two devices per stage (running fully-replicated).
    num_devices_per_stage = jax.device_count() // NUM_STAGES
    stages = []
    for i in range(NUM_STAGES):
        devices = jax.devices()[num_devices_per_stage*i : num_devices_per_stage*(i+1)]
        assert all(d.process_index == devices[0].process_index for d in devices)
        mesh = Mesh(np.asarray(devices), ('model',))
        if i == NUM_STAGES - 1:
            fwd = lambda params, xs, ys: mse(mlp(params, xs), ys)
        else:
            fwd = lambda params, xs, _ys: mlp(params, xs)
        num_layers_per_stage = NUM_LAYERS // NUM_STAGES
        if ENABLE_TP:
            params_specs = [(P(None, 'model'), P('model', None))] * num_layers_per_stage
        else:
            params_specs = [(P(), P())] * num_layers_per_stage
        stages.append(Stage(fwd, mesh, params_specs))

    def step_fn(params_by_stage, xs, ys):
        return stages_step_fn(stages, NUM_MUBATCHES, params_by_stage, xs, ys)

    key = jax.random.PRNGKey(0)
    params, key = init_params(key)
    params_by_stage = shard_params_by_stage(params, stages)

    # Just keep reusing one batch, so we don't have to worry about infeed.
    key, key_xs = jax.random.split(key)
    xs_batch = jax.random.uniform(
        key_xs,
        (BATCH_SIZE, LAYER_SIZE),
    )
    ys_batch = 7 * xs_batch

    xs_batch = transfer(xs_batch, stages[0], P())
    ys_batch = transfer(ys_batch, stages[-1], P())

    NUM_STEPS = 50
    NUM_STEPS_PROFILED = 3
    for i in range(NUM_STEPS):
        print(f'===== STEP {i} {process_id=} =====')
        if i == NUM_STEPS - NUM_STEPS_PROFILED:
            profile_utils.maybe_start_profile(f"pp_trace/p{process_id}")

        with profile_utils.annotate(f'step{i}', color='white'):
            loss, params_by_stage = step_fn(params_by_stage, xs_batch, ys_batch)


if __name__ == '__main__':
    import sys
    num_processes = 4
    if len(sys.argv) >= 2:
        num_processes = int(sys.argv[1])
    success = launch_utils.launch_example(num_processes, example_pp)
    sys.exit(0 if success else 1)
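A note on `bwd_and_grad_acc` above: the backward pass of the last stage is seeded with the scalar constant 1.0 (`_mpmd_constant(..., value=1.0)`), which is the usual `jax.vjp` pattern for differentiating a scalar loss. A plain-JAX sketch of that single-stage step; the shapes and the `fwd` function here are illustrative only, not part of the library:

    import jax
    import jax.numpy as jnp

    def fwd(params, xs, ys):  # last-stage forward: model + MSE loss
        return jnp.mean(jnp.square(xs @ params - ys))

    params = jnp.eye(4)
    xs, ys = jnp.ones((8, 4)), jnp.zeros((8, 4))

    loss, bwd = jax.vjp(lambda p, x: fwd(p, x, ys), params, xs)
    grads, d_xs = bwd(jnp.ones_like(loss))  # cotangent 1.0 seeds the backward pass
    # `grads` is what gets accumulated across microbatches; `d_xs` is the
    # activation cotangent that is transferred to the previous stage.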
jax/experimental/_private_mm/examples/example_tests.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs some simple mm operations on varying numbers of processes."""

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from jax.experimental import _private_mm as mm
from jax.experimental._private_mm.examples import launch_utils


def make_two_meshes():
    devices = jax.devices()
    num_devices = len(devices)
    assert num_devices % 2 == 0
    mesh1 = Mesh(devices[:num_devices//2], ('data',))
    mesh2 = Mesh(devices[num_devices//2:], ('data',))
    sharding1 = NamedSharding(mesh1, P('data'))
    sharding2 = NamedSharding(mesh2, P('data'))
    return mesh1, mesh2, sharding1, sharding2


def test_device_put_uncommitted(_num_processes, process_id):
    _, _, sharding1, sharding2 = make_two_meshes()
    x = mm.device_put(jnp.ones((16,16)), sharding1)
    x.block_until_ready()


def test_device_put_across_meshes(_num_processes, process_id):
    _, _, sharding1, sharding2 = make_two_meshes()
    x = mm.device_put(jnp.ones((16,16)), sharding1)
    y = mm.device_put(x, sharding2)
    if y.is_fully_remote:
        y.block_until_ready()
    else:
        np.testing.assert_array_equal(y.jax_array, jnp.ones((16,16)))


def test_jit_and_transfer(_num_processes, process_id):
    _, _, sharding1, sharding2 = make_two_meshes()
    x1 = mm.device_put(jnp.ones((16,16)), sharding1)
    x2 = mm.jit(lambda x: x + 1, out_shardings=sharding1)(x1)
    y1 = mm.device_put(x2, sharding2)
    y2 = mm.jit(lambda x: x * 2, out_shardings=sharding2)(y1)
    if y2.is_fully_remote:
        y2.block_until_ready()
    else:
        np.testing.assert_array_equal(y2.jax_array, jnp.full((16,16), 4))


def run_test(num_processes, test_fun, name):
    print(f' - {name} ... ', end='', flush=True)
    success = launch_utils.launch_example(num_processes, test_fun)
    if success:
        print('OK')
    else:
        print('FAIL')
    return success


def run_tests():
    # For 1 process mm.device_puts simply reduce to jax.device_puts.
    # For 2 processes and tests involving two meshes we require NCCL comms,
    # but all devices of a mesh are managed by the same process.
    # For 4 processes and tests involving two meshes we additionally have to
    # deal with devices of a mesh being managed by multiple processes.
    # (The latter currently doesn't work.)
    NUM_PROCESSES = (1, 2, 4)
    TESTS = [
        ('device_put_uncommitted', test_device_put_uncommitted),
        ('device_put_across_meshes', test_device_put_across_meshes),
        ('jit_and_transfer', test_jit_and_transfer),
    ]
    num_failures = 0
    for num_processes in NUM_PROCESSES:
        print(f'=== {num_processes=} ===')
        for test_name, test_fun in TESTS:
            success = run_test(num_processes, test_fun, test_name)
            if not success:
                num_failures += 1
    if num_failures == 0:
        print('All tests succeeded!')
        return True
    else:
        print(f'{num_failures} tests failed!')
        return False


if __name__ == '__main__':
    import sys
    success = run_tests()
    sys.exit(0 if success else 1)
jax/experimental/_private_mm/examples/launch_utils.py (new file, 95 lines)
@@ -0,0 +1,95 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to launch multi-process JAX examples on a single host."""

from functools import partial
import multiprocessing


def init_multi_process_local(num_processes, process_id, num_devices):
    assert 0 <= process_id < num_processes

    # Assume all processes run on a single node
    assert num_devices % num_processes == 0
    num_devices_per_process = num_devices // num_processes
    local_device_ids = [
        process_id*num_devices_per_process + i
        for i in range(num_devices_per_process)
    ]

    import jax
    jax.distributed.initialize(
        coordinator_address="localhost:1234",
        num_processes=num_processes,
        process_id=process_id,
        local_device_ids=local_device_ids,
    )


# Needs to be a top-level function to pickle as part of multiprocessing.
def _entrypoint(num_processes, process_id, user_main, num_devices):
    # Only import these in the subprocess, not the launcher process.
    import jax.experimental.multihost_utils
    from jax.experimental._private_mm import profile_utils

    init_multi_process_local(num_processes, process_id, num_devices)
    jax.experimental.multihost_utils.sync_global_devices("start_user_main")
    user_main(num_processes, process_id)
    profile_utils.maybe_stop_profile()


def launch_example(num_processes, user_main, num_devices=8):
    """
    A launcher for examples running across multiple processes on a single node.
    Returns True iff all processes exited successfully.

    Example code my_example.py:

        def my_example(num_processes, process_id):
            # Do some distributed JAX stuff.
            ...

        if __name__ == '__main__':
            import sys
            num_processes = int(sys.argv[1])
            launch_utils.launch_example(num_processes, my_example)

    Usage:

        # Run without profiling
        python3 my_example.py 4
        # Run with jax.profiler + annotations
        PROFILE=jax python3 my_example.py 4
        # Run with nsys profiling + annotations
        PROFILE=nsys nsys profile --output my_example.nsys-rep --cpuctxsw=none --trace=cublas,cuda,cudnn,cusolver,nvtx,osrt,python-gil --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop --cuda-graph-trace=node --python-sampling=true python3 my_example.py 4
    """
    assert num_processes > 0
    # Spawn subprocesses to avoid timeouts when profiling using nsys.
    ctx = multiprocessing.get_context('spawn')
    ps = [
        ctx.Process(
            target=partial(
                _entrypoint,
                num_processes,
                process_id,
                user_main,
                num_devices,
            ),
            name=f'example_proc{process_id}',
        )
        for process_id in range(num_processes)
    ]
    for p in ps:
        p.start()
    for p in ps:
        p.join()
    return all(p.exitcode == 0 for p in ps)
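The device partitioning in `init_multi_process_local` is a contiguous split of one node's devices across processes. A quick standalone check of that arithmetic (the 8-device default is an assumption carried over from `launch_example`):

    def local_ids(num_processes, process_id, num_devices=8):
        per = num_devices // num_processes
        return [process_id * per + i for i in range(per)]

    assert local_ids(4, 2) == [4, 5]           # process 2 of 4 owns GPUs 4 and 5
    assert local_ids(2, 1) == [4, 5, 6, 7]     # process 1 of 2 owns the second half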
jax/experimental/_private_mm/mini_dime.py (new file, 285 lines)
@@ -0,0 +1,285 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Explicit NCCL communicators (via CuPy) integrated with JAX. Based on the
communication library used in JaxPP (https://arxiv.org/abs/2412.14374).
Requires `pip install cupy-cuda12x`.
"""

import enum
import pickle
from collections import OrderedDict
from functools import cached_property

try:
    import cupy  # type: ignore[import-not-found]
    from cupy.cuda import nccl  # type: ignore[import-not-found]

    # CuPy NCCL utils from https://github.com/cupy/cupy/blob/118ade4a146d1cc68519f7f661f2c145f0b942c9/cupyx/distributed/_nccl_comm.py#L46-L55
    _nccl_dtypes = {
        "b": nccl.NCCL_INT8,
        "B": nccl.NCCL_UINT8,
        "i": nccl.NCCL_INT32,
        "I": nccl.NCCL_UINT32,
        "l": nccl.NCCL_INT64,
        "L": nccl.NCCL_UINT64,
        "q": nccl.NCCL_INT64,
        "Q": nccl.NCCL_UINT64,
        "e": nccl.NCCL_FLOAT16,
        "f": nccl.NCCL_FLOAT32,
        "d": nccl.NCCL_FLOAT64,
        # Size of array will be doubled
        "F": nccl.NCCL_FLOAT32,
        "D": nccl.NCCL_FLOAT64,
    }
except ImportError:
    cupy = None
    nccl = None

import jax
import jax.numpy as jnp
import jaxlib.xla_extension as xe
from jax._src import array
from jax._src.op_shardings import are_op_shardings_equal


def _get_nccl_dtype_and_count(arr, count=None):
    dtype = arr.dtype.char
    if dtype not in _nccl_dtypes:
        raise TypeError(f"Unknown dtype {arr.dtype} for NCCL")
    nccl_dtype = _nccl_dtypes[dtype]
    if count is None:
        count = arr.size
    if dtype in "FD":
        return nccl_dtype, 2 * count
    return nccl_dtype, count


def get_distributed_client() -> xe.DistributedRuntimeClient:
    from jax._src.distributed import global_state

    assert isinstance(global_state.client, xe.DistributedRuntimeClient)
    return global_state.client


class UniqueDevices(tuple[jax.Device, ...]):
    def __new__(cls, *args):
        return super().__new__(cls, sorted(set(args), key=lambda d: d.id))

    @cached_property
    def ranks(self):
        return OrderedDict((d, idx) for idx, d in enumerate(self))

    @property
    def leader(self):
        return self[0]

    @cached_property
    def key(self) -> str:
        return ",".join(str(d.id) for d in self)


local_comms: dict = {}


def get_or_create_comm(devs: UniqueDevices):
    TIMEOUT = 5_000

    comm = local_comms.get(devs)
    my_process_index = jax.process_index()
    if comm is None:
        if devs.leader.process_index == my_process_index:
            nccl_id = nccl.get_unique_id()
            get_distributed_client().key_value_set_bytes(
                devs.key, pickle.dumps(nccl_id)
            )
        else:
            nccl_id = get_distributed_client().blocking_key_value_get_bytes(
                devs.key, TIMEOUT
            )
            nccl_id = pickle.loads(nccl_id)

        nccl.groupStart()
        for d in devs:
            if d.process_index == my_process_index:
                with cupy.cuda.Device(d.local_hardware_id):
                    comm = nccl.NcclCommunicator(len(devs), nccl_id, devs.ranks[d])
        nccl.groupEnd()

        local_comms[devs] = comm
    return comm


local_streams: dict = {}


class OpT(enum.Enum):
    SEND = 0
    RECV = 1


def get_or_create_stream(op: OpT, local_device: jax.Device):
    # XXX: I think this can be one stream per local_device for this specific example.
    # It depends on the use case.
    stream = local_streams.get((op, local_device))
    if stream is None:
        with cupy.cuda.Device(local_device.local_hardware_id):
            stream = cupy.cuda.Stream(non_blocking=True)
        local_streams[(op, local_device)] = stream
    return stream


def shardings_are_compatible(
    self: jax.sharding.Sharding, other: jax.sharding.Sharding, ndim: int
):
    # NOTE: Variant of `jax.sharding.Sharding.is_equivalent_to` that skips _internal_device_list check
    return (
        are_op_shardings_equal(
            self._to_xla_hlo_sharding(ndim), other._to_xla_hlo_sharding(ndim)
        )
        # and self._internal_device_list == other._internal_device_list  # type: ignore
        and self.memory_kind == other.memory_kind
    )


## API


def send_or_recv(
    x: jax.Array,
    tgt_sharding: jax.sharding.Sharding,
    src_sharding: jax.sharding.Sharding | None = None,
):
    """
    When `src_sharding is None` this function corresponds to a send: `x` holds
    the data to send and `tgt_sharding` describes the remote destination (it
    must be compatible with `x.sharding` up to device placement).
    When `src_sharding is not None` this function corresponds to a receive:
    `x` is the preallocated receive buffer (so `x.sharding` must equal
    `tgt_sharding`) and will be consumed, i.e. it's unsafe to use `x` after
    `send_or_recv(x, src_sharding=...)`.

    `x` can be a "global" array spanning multiple processes/hosts.
    In that case, this process will send/receive only its corresponding
    addressable_shards.
    """

    if src_sharding is None:
        is_send = True
        other_sharding = tgt_sharding
    else:
        is_send = False
        other_sharding = src_sharding

    if not is_send:
        # XXX: x.sharding and tgt_sharding must be equal since this is a recv.
        # This seems redundant to me. Not sure what the final version from Skye
        # will look like.
        assert x.sharding == tgt_sharding

    # TODO: implement reshard for 4 devs -> 2 devs or 2->4 reshards
    assert shardings_are_compatible(x.sharding, other_sharding, x.ndim), \
        f'incompatible shardings: {x.sharding=} vs {other_sharding=}'

    # Create communicators lazily as needed. This can be a separate "setup function"
    for pair in zip(
        x.sharding._device_assignment,
        other_sharding._device_assignment,
        strict=True,
    ):
        if pair[0].process_index == jax.process_index():
            _ = get_or_create_comm(UniqueDevices(*pair))

    shards_by_device = {shard.device: shard for shard in x.addressable_shards}

    cpy_arrays_and_streams = []
    # FIXME: maybe narrow `nccl_group_{start,end}` scope by first accumulating
    # arguments in a list and then performing the operation
    nccl.groupStart()
    for x_device, other_device in zip(
        x.sharding._device_assignment,
        other_sharding._device_assignment,
        strict=True,
    ):
        if x_device.process_index == jax.process_index():
            shard = shards_by_device[x_device]
            stream = get_or_create_stream(OpT.SEND if is_send else OpT.RECV, x_device)
            # FIXME: cupy doesn't support bf16. Use capsule/ctype APIs
            cpy_arr = cupy.from_dlpack(
                jax.dlpack.to_dlpack(shard.data, stream=stream.ptr)
            )
            cpy_arrays_and_streams.append((cpy_arr, stream))

            nccl_dtype, count = _get_nccl_dtype_and_count(cpy_arr)

            key = UniqueDevices(x_device, other_device)
            comm = get_or_create_comm(key)

            with cupy.cuda.Device(x_device.local_hardware_id):
                op = comm.send if is_send else comm.recv
                op(
                    cpy_arr.data.ptr,
                    count,
                    nccl_dtype,
                    key.ranks[other_device],
                    stream.ptr,
                )

    nccl.groupEnd()
    # NOTE: since communicators are blocking, after the group_end operation
    # above, all the send/recvs have been enqueued into the stream. Therefore,
    # we can record events on the stream.

    # XXX: I don't like different return types below, however I am not sure
    # what's a better alternative given we want a "symmetric"
    # `send_or_recv` API
    if is_send:

        def wait():
            for _, stream in cpy_arrays_and_streams:
                stream.synchronize()
            # NOTE: Keep the objects below alive just in case they are not
            # deleted/overwritten by XLA while in use
            return (x, cpy_arrays_and_streams)

        return wait
    else:

        def enqueue_wait():
            jax_single_arrays = []
            for x_device, (cpy_arr, stream) in zip(
                x.sharding._device_assignment, cpy_arrays_and_streams, strict=True
            ):
                with cupy.cuda.Device(x_device.local_hardware_id):
                    event = stream.record()
                    ready_events_stream = (
                        x_device.get_stream_for_external_ready_events()
                    )
                    cupy.cuda.ExternalStream(ready_events_stream).wait_event(event)
                    jax_sda = jnp.array(
                        jax._src.lib.xla_client._xla.dlpack_managed_tensor_to_buffer(
                            cpy_arr.toDlpack(),
                            x_device,
                            ready_events_stream,
                        ),
                        copy=True,  # XXX: Just to be safe
                    )
                    jax_single_arrays.append(jax_sda)
            return array.ArrayImpl(
                x.aval,
                x.sharding,
                jax_single_arrays,
                committed=True,
                # NOTE: _skip_checks can be set to True, however since this happens
                # asynchronously there's no perf harm to keeping it False.
                _skip_checks=False,
            )

        return enqueue_wait
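`send_or_recv` returns asymmetrically: the sender gets back a `wait()` that synchronizes the CuPy streams, while the receiver gets back an `enqueue_wait()` that materializes a new `jax.Array` from the received buffers. A hedged sketch of a matched pair across two single-device processes; the shardings, shapes, and distributed setup here are assumptions, and `mm.device_put` (below) wraps this same pattern:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental._private_mm import mini_dime

    # Assumed setup: 2 processes, 1 GPU each, jax.distributed.initialize() already done.
    src_sharding = NamedSharding(Mesh(jax.devices()[:1], ('x',)), P())  # owned by process 0
    tgt_sharding = NamedSharding(Mesh(jax.devices()[1:], ('x',)), P())  # owned by process 1

    if jax.process_index() == 0:
        x = jax.device_put(jnp.arange(16.0), src_sharding)
        wait = mini_dime.send_or_recv(x, tgt_sharding)       # enqueue the NCCL send
        wait()                                               # synchronize the send stream
    else:
        buf = jax.device_put(jnp.zeros(16), tgt_sharding)    # preallocated receive buffer
        enqueue_wait = mini_dime.send_or_recv(buf, tgt_sharding, src_sharding)
        y = enqueue_wait()                                   # jax.Array backed by received data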
jax/experimental/_private_mm/mm.py (new file, 300 lines)
@@ -0,0 +1,300 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Poor-man's MPMD for JAX."""

from dataclasses import dataclass
from functools import cached_property, lru_cache, partial, wraps
from typing import Callable

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Sharding, SingleDeviceSharding

from jax._src.tree_util import broadcast_prefix, prefix_errors, tree_leaves_with_path

from jax.experimental._private_mm import mini_dime


@dataclass
class MpmdArray:
    """A generalization of jax.Array that also supports fully remote arrays."""
    aval: jax.core.ShapedArray
    sharding: Sharding
    _complete: Callable[[], jax.Array | tuple] | None
    _result: jax.Array | tuple | None = None

    def __repr__(self):
        remote_str = ', fully-remote' if self.is_fully_remote else ''
        return (
            f'MpmdArray({self.aval}, sharding={self.sharding}, '
            f'devices={self.sharding.mesh.devices}{remote_str})'
        )

    def block_until_ready(self):
        if self._complete is None:
            # Already awaited.
            assert self._result is not None
            return
        result = self._complete()
        if isinstance(result, jax.Array):
            # Recv result, store array.
            self._result = result
        else:
            # No-op result or send result. Drop objects kept alive, but register
            # completion.
            self._result = ()
        # Drop the closure.
        self._complete = None
        return self

    @cached_property
    def is_fully_remote(self):
        return is_fully_remote_sharding(self.sharding)

    @property
    def jax_array(self):
        if self.is_fully_remote:
            raise ValueError('cannot convert fully-remote MpmdArray to jax.Array')
        self.block_until_ready()
        assert isinstance(self._result, jax.Array), (
            'expected non-fully-remote MpmdArray to hold some local data, but got: '
            f'{self._result} (mesh devices: {self.sharding.mesh.devices})'
        )
        return self._result

    @property
    def shape(self):
        return self.aval.shape

    @property
    def dtype(self):
        return self.aval.dtype


JaxOrMpmdArray = jax.Array | MpmdArray


def is_local_device(device) -> bool:
    return device.process_index == jax.process_index()


def is_fully_remote_sharding(sharding: Sharding) -> bool:
    # TODO: Handle shardings other than NamedSharding?
    assert isinstance(sharding, NamedSharding)
    return not any(map(is_local_device, sharding.mesh.devices.flat))


def is_fully_local_sharding(sharding: Sharding) -> bool:
    # TODO: Handle shardings other than NamedSharding?
    assert isinstance(sharding, NamedSharding)
    return all(map(is_local_device, sharding.mesh.devices.flat))


def is_fully_remote_array(arr: JaxOrMpmdArray) -> bool:
    return isinstance(arr, MpmdArray) and arr.is_fully_remote


def as_jax_array(arr: JaxOrMpmdArray) -> jax.Array:
    if isinstance(arr, MpmdArray):
        return arr.jax_array
    assert isinstance(arr, jax.Array)
    return arr


def fix_sharding(sharding: Sharding) -> Sharding:
    # FIXME: During jax.device_put(..., sharding) jaxlib/XLA fills in a memory
    # kind if none was explicitly given. We don't always call into
    # jax.device_put here, but we want to mirror this behavior so that even
    # processes that don't call jax.device_put end up with the exact same
    # metadata. (The bandaid below is likely incomplete.)
    if sharding.memory_kind is None:
        sharding = sharding.with_memory_kind('device')
    return sharding


@lru_cache
def recv_buf_factory(shape, dtype, tgt_sharding):
    @partial(jax.jit, out_shardings=tgt_sharding)
    def recv_buf_init():
        return jnp.zeros(shape, dtype)
    return recv_buf_init


# TODO: Generalize mm.device_put to mix jax.device_put, send and recv as
# needed. For the moment, we only allow cases that neatly fall into one of the
# above three cases, i.e. the present process either issues a jax.device_put,
# a NCCL send or a NCCL recv. This means that every submesh (e.g. a stage) needs
# to be managed by a single process for now.
def device_put(arr: JaxOrMpmdArray, device: Sharding) -> MpmdArray:
    assert isinstance(device, Sharding)
    tgt_sharding = fix_sharding(device)
    src_sharding = fix_sharding(arr.sharding)

    def complete_with(complete):
        return MpmdArray(
            aval=arr.aval,
            sharding=tgt_sharding,
            _complete=complete,
        )

    if is_fully_remote_array(arr):
        if is_fully_remote_sharding(tgt_sharding):
            # FullyRemote->FullyRemote: Nothing to be done.
            return complete_with(lambda: ())
        else:
            # FullyRemote->NonFullyRemote: Recv.
            # NOTE: We run the same jitted fun on each participating device,
            # rather than jax.device_put(jnp.zeros(...), tgt_sharding). The
            # latter produces jnp.zeros first on one local device and then P2P-
            # copies to the others, which anecdotally appears to be slower, but
            # also litters the profile, so we avoid it.
            recv_buf = recv_buf_factory(
                arr.aval.shape,
                arr.aval.dtype,
                tgt_sharding,
            )()
            return complete_with(
                mini_dime.send_or_recv(
                    recv_buf,
                    tgt_sharding,
                    src_sharding,
                )
            )

    # arr has some locally-addressable shards.
    jax_array = as_jax_array(arr)
    if jax_array.committed:
        if is_fully_remote_sharding(tgt_sharding):
            # NonFullyRemote->FullyRemote: Send.
            # FIXME: Should force completion at some point.
            return complete_with(
                mini_dime.send_or_recv(
                    jax_array,
                    tgt_sharding,
                )
            )
        elif (
            is_fully_local_sharding(src_sharding) and
            is_fully_local_sharding(tgt_sharding)
        ):
            # NonFullyRemote->NonFullyRemote: jax.device_put
            new_jax_array = jax.device_put(jax_array, tgt_sharding)
            return complete_with(lambda: new_jax_array)
        else:
            # NOTE: We exclude cases of NonFullyRemote -> NonFullyRemote
            # which would require a mix of jax.device_put, Send and Recv.
            raise NotImplementedError('unsupported transfer')
    else:
        # Uncommitted array.
        assert isinstance(jax_array.sharding, SingleDeviceSharding)
        if is_fully_remote_sharding(tgt_sharding):
            # Uncommitted->FullyRemote: Nothing to be done
            return complete_with(lambda: ())
        else:
            # Uncommitted->NonFullyRemote: jax.device_put
            # NOTE: Uncommitted arrays arise when the user hasn't yet specified
            # a device or sharding, so the current (single-device) sharding is
            # somewhat arbitrary.
            # An important assumption here is that, though said device will vary
            # from process to process, we expect all of the processes to have
            # the same values.
            #
            # Now we'd like to do something like
            #   new_jax_array = jax.device_put(jax_array, tgt_sharding)
            # where we'd expect jax.device_put to simply transfer from
            # the current local single device to all the other relevant local
            # devices.
            #
            # This unfortunately doesn't work, because jax.device_put will check
            # the above assumption of same-values-everywhere by introducing a
            # broadcast from process 0 to all others. But in an MPMD program
            # only a subset of processes will participate in any given
            # device_put, so this might lead to hangs!
            #
            # We could likely work around this by doing appropriate device_puts
            # with single-device shardings and subsequently using
            # jax.make_array_from_single_device_arrays to build a global array.
            if not is_fully_local_sharding(tgt_sharding):
                raise NotImplementedError('unsupported transfer')
            new_jax_array = jax.device_put(jax_array, tgt_sharding)
            return complete_with(lambda: new_jax_array)


def jit(*args, **kwargs):
    if (out_shardings := kwargs.get('out_shardings')) is None:
        raise ValueError('missing out_shardings')
    fun = jax.jit(*args, **kwargs)

    @wraps(fun)
    def wrapped(*in_vals):
        first_fully_remote_input = next(
            (
                (path, in_val)
                for path, in_val in tree_leaves_with_path(in_vals)
                if is_fully_remote_array(in_val)
            ),
            None,
        )

        # This computation does not concern us, return fully-remote arrays.
        if first_fully_remote_input is not None:
            out_shape_dtypes = jax.eval_shape(fun, *in_vals)
            # Allow out_shardings to be a prefix tree
            try:
                out_shardings_flat = broadcast_prefix(
                    out_shardings,
                    out_shape_dtypes,
                    is_leaf=lambda x: x is None,  # FIXME: Correct?
                )
            except ValueError:
                e, *_ = prefix_errors(out_shardings, out_shape_dtypes)
                raise e('mm.jit out_shardings') from None
            out_shardings_full = jax.tree.unflatten(
                jax.tree.structure(out_shape_dtypes),
                out_shardings_flat,
            )
            # Make an MpmdArray for every out value
            def make_fully_remote_output(shape_dtype, sharding):
                if not is_fully_remote_sharding(sharding):
                    path, in_val = first_fully_remote_input
                    raise ValueError(
                        'mm.jit produces a non-fully-remote output, but '
                        f'was invoked on fully-remote input: {in_val} @ {path}')
                return MpmdArray(
                    aval=jax.core.ShapedArray(
                        shape_dtype.shape,
                        shape_dtype.dtype,
                    ),
                    sharding=sharding,
                    _complete=lambda: (),
                )
            return jax.tree.map(
                make_fully_remote_output,
                out_shape_dtypes,
                out_shardings_full,
            )

        # This computation concerns us, run the jax.jit-ed function.
        in_vals = jax.tree.map(as_jax_array, in_vals)
        out_vals = fun(*in_vals)
        return jax.tree.map(
            lambda jax_array: MpmdArray(
                jax_array.aval,
                jax_array.sharding,
                lambda: jax_array,
            ),
            out_vals,
        )
    return wrapped
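`mm.device_put` above dispatches per process to one of three actions (plain `jax.device_put`, NCCL send, or NCCL recv), and every result is an `MpmdArray` whose completion is deferred behind `_complete`. A single-process sketch that stays on the fully-local paths (assumes at least two local devices; nothing here is specific to multi-process runs):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental import _private_mm as mm

    s1 = NamedSharding(Mesh(jax.devices()[:1], ('x',)), P())
    s2 = NamedSharding(Mesh(jax.devices()[1:2], ('x',)), P())

    a = mm.device_put(jnp.ones((4, 4)), s1)   # uncommitted -> fully-local: jax.device_put
    b = mm.device_put(a, s2)                  # committed local -> fully-local: jax.device_put
    print(b.jax_array.sharding)               # completion happens lazily, on first access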
jax/experimental/_private_mm/profile_utils.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for profiling, abstracting over jax.profiler and nsys."""

from contextlib import contextmanager
import os

import jax


def get_profiling_mode() -> str | None:
    mode = os.environ.get('PROFILE')
    if mode is not None:
        mode = mode.lower()
        assert mode in ('jax', 'nsys')
        return mode
    return None


if get_profiling_mode() == 'nsys':
    from ctypes import cdll
    libcudart = cdll.LoadLibrary('libcudart.so')
    import nvtx  # type: ignore[import-not-found]


def maybe_start_profile(path):
    profiling_mode = get_profiling_mode()
    if profiling_mode is None:
        pass
    elif profiling_mode == 'jax':
        jax.profiler.start_trace(path)
    elif profiling_mode == 'nsys':
        libcudart.cudaProfilerStart()
    else:
        assert False


def maybe_stop_profile():
    profiling_mode = get_profiling_mode()
    if profiling_mode is None:
        pass
    elif profiling_mode == 'jax':
        try:
            jax.profiler.stop_trace()
        except RuntimeError as e:
            if e.args[0] != 'No profile started':
                raise
    elif profiling_mode == 'nsys':
        libcudart.cudaProfilerStop()
    else:
        assert False


@contextmanager
def annotate(label, color=None):
    profiling_mode = get_profiling_mode()
    if profiling_mode is None:
        yield
    elif profiling_mode == 'jax':
        with jax.profiler.TraceAnnotation(label):
            yield
    elif profiling_mode == 'nsys':
        with nvtx.annotate(label, color=color or 'red'):
            yield
    else:
        assert False
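Everything in `profile_utils` keys off the `PROFILE` environment variable, so the same annotations end up either in a `jax.profiler` trace or as NVTX ranges under nsys. A minimal usage sketch; the trace path is an arbitrary choice:

    import jax.numpy as jnp
    from jax.experimental._private_mm import profile_utils

    # Run as e.g.:  PROFILE=jax python3 this_script.py
    profile_utils.maybe_start_profile('/tmp/mm_trace')   # no-op unless PROFILE is set
    with profile_utils.annotate('toy-step', color='cyan'):
        (jnp.ones((1024, 1024)) @ jnp.ones((1024, 1024))).block_until_ready()
    profile_utils.maybe_stop_profile()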