# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Explicit NCCL communicators (via CuPy) integrated with JAX. Based on the
communication library used in JaxPP (https://arxiv.org/abs/2412.14374).
Requires `pip install cupy-cuda12x`.
"""

import enum
import pickle
from collections import OrderedDict
from functools import cached_property

try:
  import cupy  # type: ignore[import-not-found]
  from cupy.cuda import nccl  # type: ignore[import-not-found]

  # CuPy NCCL utils from https://github.com/cupy/cupy/blob/118ade4a146d1cc68519f7f661f2c145f0b942c9/cupyx/distributed/_nccl_comm.py#L46-L55
  _nccl_dtypes = {
      "b": nccl.NCCL_INT8,
      "B": nccl.NCCL_UINT8,
      "i": nccl.NCCL_INT32,
      "I": nccl.NCCL_UINT32,
      "l": nccl.NCCL_INT64,
      "L": nccl.NCCL_UINT64,
      "q": nccl.NCCL_INT64,
      "Q": nccl.NCCL_UINT64,
      "e": nccl.NCCL_FLOAT16,
      "f": nccl.NCCL_FLOAT32,
      "d": nccl.NCCL_FLOAT64,
      # Size of array will be doubled
      "F": nccl.NCCL_FLOAT32,
      "D": nccl.NCCL_FLOAT64,
  }
except ImportError:
  cupy = None
  nccl = None

import jax
import jax.numpy as jnp
import jaxlib.xla_extension as xe
from jax._src import array
from jax._src.op_shardings import are_op_shardings_equal


def _get_nccl_dtype_and_count(arr, count=None):
  dtype = arr.dtype.char
  if dtype not in _nccl_dtypes:
    raise TypeError(f"Unknown dtype {arr.dtype} for NCCL")
  nccl_dtype = _nccl_dtypes[dtype]
  if count is None:
    count = arr.size
  if dtype in "FD":
    return nccl_dtype, 2 * count
  return nccl_dtype, count


def get_distributed_client() -> xe.DistributedRuntimeClient:
  from jax._src.distributed import global_state

  assert isinstance(global_state.client, xe.DistributedRuntimeClient)
  return global_state.client


class UniqueDevices(tuple[jax.Device, ...]):
  def __new__(cls, *args):
    return super().__new__(cls, sorted(set(args), key=lambda d: d.id))

  @cached_property
  def ranks(self):
    return OrderedDict((d, idx) for idx, d in enumerate(self))

  @property
  def leader(self):
    return self[0]

  @cached_property
  def key(self) -> str:
    return ",".join(str(d.id) for d in self)
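

# For example (hypothetical device ids): UniqueDevices(d3, d1) deduplicates and
# sorts by id to (d1, d3), so ranks == {d1: 0, d3: 1}, leader == d1, and
# key == "1,3". `key` is what `get_or_create_comm` below uses as the
# distributed KV-store key when exchanging the NCCL unique id for a device pair.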


local_comms: dict = {}


def get_or_create_comm(devs: UniqueDevices):
  TIMEOUT = 5_000

  comm = local_comms.get(devs)
  my_process_index = jax.process_index()
  if comm is None:
    if devs.leader.process_index == my_process_index:
      nccl_id = nccl.get_unique_id()
      get_distributed_client().key_value_set_bytes(
          devs.key, pickle.dumps(nccl_id)
      )
    else:
      nccl_id = get_distributed_client().blocking_key_value_get_bytes(
          devs.key, TIMEOUT
      )
      nccl_id = pickle.loads(nccl_id)

    nccl.groupStart()
    for d in devs:
      if d.process_index == my_process_index:
        with cupy.cuda.Device(d.local_hardware_id):
          comm = nccl.NcclCommunicator(len(devs), nccl_id, devs.ranks[d])
    nccl.groupEnd()

    local_comms[devs] = comm
  return comm


local_streams: dict = {}


class OpT(enum.Enum):
  SEND = 0
  RECV = 1


def get_or_create_stream(op: OpT, local_device: jax.Device):
  # XXX: I think this can be one stream per local_device for this specific example.
  # It depends on the use case
  stream = local_streams.get((op, local_device))
  if stream is None:
    with cupy.cuda.Device(local_device.local_hardware_id):
      stream = cupy.cuda.Stream(non_blocking=True)
    local_streams[(op, local_device)] = stream
  return stream


def shardings_are_compatible(
    self: jax.sharding.Sharding, other: jax.sharding.Sharding, ndim: int
):
  # NOTE: Variant of `jax.sharding.Sharding.is_equivalent_to` that skips the _internal_device_list check
  return (
      are_op_shardings_equal(
          self._to_xla_hlo_sharding(ndim), other._to_xla_hlo_sharding(ndim)
      )
      # and self._internal_device_list == other._internal_device_list  # type: ignore
      and self.memory_kind == other.memory_kind
  )


## API


def send_or_recv(
    x: jax.Array,
    tgt_sharding: jax.sharding.Sharding,
    src_sharding: jax.sharding.Sharding | None = None,
):
  """
  When `src_sharding is None` this function corresponds to a send, and
  `x.sharding` must be compatible with `tgt_sharding` (same layout, possibly
  on different devices).
  When `src_sharding is not None` this function corresponds to a receive,
  and `x` will be consumed, i.e. it is unsafe to use `x` after
  `send_or_recv(x, src_sharding=...)`.

  `x` can be a "global" array spanning multiple processes/hosts.
  In that case, this process will send/receive only its corresponding
  addressable_shards.
  """

  if src_sharding is None:
    is_send = True
    other_sharding = tgt_sharding
  else:
    is_send = False
    other_sharding = src_sharding

  if not is_send:
    # XXX: x.sharding and tgt_sharding must be equal since this is a recv.
    # This seems redundant to me. Not sure what the final version from Skye
    # will look like.
    assert x.sharding == tgt_sharding

  # TODO: implement reshard for 4 devs -> 2 devs or 2->4 reshards
  assert shardings_are_compatible(x.sharding, other_sharding, x.ndim), \
      f'incompatible shardings: {x.sharding=} vs {other_sharding=}'

  # Create communicators lazily as needed. This can be a separate "setup function"
  for pair in zip(
      x.sharding._device_assignment,
      other_sharding._device_assignment,
      strict=True,
  ):
    if pair[0].process_index == jax.process_index():
      _ = get_or_create_comm(UniqueDevices(*pair))

  shards_by_device = {shard.device: shard for shard in x.addressable_shards}

  cpy_arrays_and_streams = []
  # FIXME: maybe narrow `nccl_group_{start,end}` scope by first accumulating
  # arguments in a list and then performing the operation
  nccl.groupStart()
  for x_device, other_device in zip(
      x.sharding._device_assignment,
      other_sharding._device_assignment,
      strict=True,
  ):
    if x_device.process_index == jax.process_index():
      shard = shards_by_device[x_device]
      stream = get_or_create_stream(OpT.SEND if is_send else OpT.RECV, x_device)
      # FIXME: cupy doesn't support bf16. Use capsule/ctype APIs
      cpy_arr = cupy.from_dlpack(
          jax.dlpack.to_dlpack(shard.data, stream=stream.ptr)
      )
      cpy_arrays_and_streams.append((cpy_arr, stream))

      nccl_dtype, count = _get_nccl_dtype_and_count(cpy_arr)

      key = UniqueDevices(x_device, other_device)
      comm = get_or_create_comm(key)

      with cupy.cuda.Device(x_device.local_hardware_id):
        op = comm.send if is_send else comm.recv
        op(
            cpy_arr.data.ptr,
            count,
            nccl_dtype,
            key.ranks[other_device],
            stream.ptr,
        )

  nccl.groupEnd()
  # NOTE: since communicators are blocking, after the group_end operation
  # above, all the send/recvs have been enqueued into the stream. Therefore,
  # we can record events on the stream

  # XXX: I don't like different return types below, however I am not sure
  # what's a better alternative given we want a "symmetric"
  # `send_or_recv` API
  if is_send:

    def wait():
      for _, stream in cpy_arrays_and_streams:
        stream.synchronize()
      # NOTE: Keep the objects below alive just in case they are not
      # deleted/overwritten by XLA while in use
      return (x, cpy_arrays_and_streams)

    return wait
  else:

    def enqueue_wait():
      jax_single_arrays = []
      for x_device, (cpy_arr, stream) in zip(
          x.sharding._device_assignment, cpy_arrays_and_streams, strict=True
      ):
        with cupy.cuda.Device(x_device.local_hardware_id):
          event = stream.record()
          ready_events_stream = (
              x_device.get_stream_for_external_ready_events()
          )
          cupy.cuda.ExternalStream(ready_events_stream).wait_event(event)
          jax_sda = jnp.array(
              jax._src.lib.xla_client._xla.dlpack_managed_tensor_to_buffer(
                  cpy_arr.toDlpack(),
                  x_device,
                  ready_events_stream,
              ),
              copy=True,  # XXX: Just to be safe
          )
          jax_single_arrays.append(jax_sda)
      return array.ArrayImpl(
          x.aval,
          x.sharding,
          jax_single_arrays,
          committed=True,
          # NOTE: _skip_checks can be set to True however since this happens
          # asynchronously there's no perf harm to keep it False.
          _skip_checks=False,
      )

    return enqueue_wait
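

## Example
#
# Usage sketch (illustration only, guarded so nothing runs on import). It
# assumes a two-process job with one GPU per process, launched so that
# `jax.distributed.initialize()` can pick up the coordinator and process ids
# from the environment; the shapes, dtypes, and device layout below are
# arbitrary choices made for the example.
if __name__ == "__main__":
  jax.distributed.initialize()
  assert jax.process_count() == 2 and len(jax.devices()) == 2

  send_dev, recv_dev = jax.devices()
  src_sharding = jax.sharding.SingleDeviceSharding(send_dev)
  dst_sharding = jax.sharding.SingleDeviceSharding(recv_dev)

  if jax.process_index() == send_dev.process_index:
    # Send side: `x` lives on this process; `tgt_sharding` names the remote
    # destination. The returned `wait` synchronizes the send streams.
    x = jax.device_put(jnp.arange(1024, dtype=jnp.float32), src_sharding)
    wait = send_or_recv(x, tgt_sharding=dst_sharding)
    wait()
  else:
    # Receive side: `y` is a placeholder already laid out with the target
    # sharding; passing `src_sharding` selects the receive path, and the
    # returned callable produces the received jax.Array.
    y = jax.device_put(jnp.zeros(1024, dtype=jnp.float32), dst_sharding)
    receive = send_or_recv(y, tgt_sharding=dst_sharding, src_sharding=src_sharding)
    received = receive()
    print(received[:8])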