mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 06:46:07 +00:00
133 lines
4.8 KiB
Python
133 lines
4.8 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utilities for synchronizing and communication across multiple hosts."""
|
|
|
|
import functools
|
|
from typing import Optional
|
|
import zlib
|
|
|
|
import jax
|
|
from jax.tree_util import PyTreeDef
|
|
from jax.experimental import maps
|
|
from jax.experimental.pjit import pjit, FROM_GDA
|
|
from jax.interpreters.pxla import PartitionSpec as P
|
|
from jax.experimental.global_device_array import GlobalDeviceArray
|
|
import numpy as np
|
|
|
|
|
|
# This needs to be top-level for the jax compilation cache.
|
|
@functools.partial(jax.pmap, axis_name='hosts')
def _psum(x: PyTreeDef) -> PyTreeDef:
  """Sums every leaf of `x` over the pmapped 'hosts' axis.

  Args:
    x: pytree of arrays. `jax.pmap` maps over the leading axis of each leaf.

  Returns:
    A pytree with the same structure as `x` whose leaves hold the result of
    `jax.lax.psum` over the 'hosts' axis.
  """
  summed = jax.lax.psum(x, axis_name='hosts')
  return summed
|
|
|
|
|
|
def broadcast_one_to_all(in_tree: PyTreeDef,
                         is_source: Optional[bool] = None) -> PyTreeDef:
  """Broadcast data from a source host (host 0 by default) to all other hosts.

  The broadcast is implemented as a sum: the source host contributes its data
  on one local device and every other device (on every host) contributes
  zeros, so the `psum` result equals the source data everywhere.

  Args:
    in_tree: pytree of arrays - each array *must* have the same shape across the
      hosts. Python scalar leaves are accepted and converted with `np.asarray`.
    is_source: optional bool denoting whether the caller is the source. Only
      'source host' will contribute the data for the broadcast. If None, then
      host 0 is used.

  Returns:
    A pytree matching in_tree where the leaves now all contain the data from the
    source host.

  Raises:
    ValueError: if a leaf of in_tree is a GlobalDeviceArray.
  """
  if is_source is None:
    is_source = jax.process_index() == 0

  num_local_devices = jax.local_device_count()

  def pre_pmap(x):
    if isinstance(x, GlobalDeviceArray):
      raise ValueError('GDAs cannot be broadcasted from source host to other '
                       'hosts.')
    # Python scalars do not support `x[None, ...]`; normalise to an ndarray.
    x = np.asarray(x)
    # One slot per local device, all zeros except slot 0 on the source host.
    stacked = np.zeros((num_local_devices,) + x.shape, dtype=x.dtype)
    if is_source:
      stacked[0] = x
    return stacked

  stacked_tree = jax.tree_util.tree_map(pre_pmap, in_tree)
  # device_get brings every leaf back to the host as a numpy array.
  summed_tree = jax.device_get(_psum(stacked_tree))
  # After the psum every device slot holds the same value; keep the first one.
  return jax.tree_util.tree_map(lambda x: x[0], summed_tree)
|
|
|
|
|
|
def sync_global_devices(name: str):
  """Creates a barrier across all hosts/devices.

  A checksum of `name` is compared across hosts via `assert_equal`, which
  performs a cross-host broadcast and therefore acts as the barrier.

  Args:
    name: label identifying this barrier. Every host is expected to pass the
      same name.

  Raises:
    AssertionError: (from `assert_equal`) if this host's checksum differs from
      the one broadcast by host 0.
  """
  # zlib.crc32 returns an unsigned value in [0, 2**32 - 1]. np.int32 cannot
  # represent the upper half of that range (NumPy >= 2 raises OverflowError),
  # so the checksum is stored as uint32.
  h = np.uint32(zlib.crc32(name.encode()))
  assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
|
|
|
|
|
|
def process_allgather(in_tree: PyTreeDef, tiled: bool = False) -> PyTreeDef:
  """Gather data from across processes.

  Args:
    in_tree: pytree of arrays - each array _must_ have the same shape across the
      hosts.
    tiled: Whether to stack or concat the output. Defaults to False i.e. stack
      into a new positional axis at index 0.
      This does not affect GDA inputs as the GDA output will always be
      concatenated.
      Scalar inputs will always be stacked.

  Returns:
    Pytrees of arrays where the data is gathered from all hosts.
      * If the input is a GDA, then the data is fully replicated.
      * If the input is non-GDA, then the output shape is dependent on the
        `tiled` argument. If it's False, then the output will be stacked else
        concatenated.
      * If the input is non-GDA and scalar, then the output will be stacked.
  """

  def _pjit(inp):
    if isinstance(inp, GlobalDeviceArray):
      if inp.is_fully_replicated:
        # Already replicated: return the first local shard as a numpy array.
        return inp.local_data(0).to_py()
      global_mesh = inp.mesh
      in_axis_resources = FROM_GDA
    else:
      # DA/SDA/np.array will be sharded based on global_mesh.local_mesh.
      # Shape of local_mesh will always be (1, local_device_count())
      devices = np.array(jax.devices()).reshape(jax.process_count(),
                                                jax.local_device_count())
      global_mesh = maps.Mesh(devices, ('processes', 'local_devices'))
      in_axis_resources = P('processes')
      # np.ndim also handles Python scalars, which have no `.ndim` attribute.
      if np.ndim(inp) == 0 or not tiled:
        # Add a leading axis so that each process contributes one row.
        inp = np.expand_dims(inp, axis=0)

    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
      # Identity pjit with unsharded output (out_axis_resources=None).
      out = pjit(lambda x: x, in_axis_resources=in_axis_resources,
                 out_axis_resources=None)(inp)
    return out.local_data(0).to_py()

  with jax._src.config.parallel_functions_output_gda(True):
    return jax.tree_util.tree_map(_pjit, in_tree)
|
|
|
|
|
|
def assert_equal(in_tree, fail_message: str = ''):
  """Verifies that all the hosts have the same tree of values.

  The tree from host 0 is broadcast to every host with `broadcast_one_to_all`
  and compared leaf by leaf against the local tree.

  Args:
    in_tree: pytree of values to compare across hosts.
    fail_message: text prepended to the error message on mismatch.

  Raises:
    AssertionError: if any local leaf differs from the broadcast leaf.
  """
  expected = broadcast_one_to_all(in_tree)
  leaves_match = jax.tree_util.tree_map(
      lambda *x: np.all(np.equal(*x)), in_tree, expected)
  if not jax.tree_util.tree_all(leaves_match):
    raise AssertionError(
        f'{fail_message} Expected: {expected}; got: {in_tree}.')
|