Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
* Make broadcast_one_to_all work with jax.jit rather than pmap, and add a pytree test for broadcast_one_to_all (see the usage sketch below).
* Propagate static_broadcasted_tuple and donate_tuple to `_pmapped` so that when local_axis_size is computed in _prepare_pmap, the static arguments are taken into account.

PiperOrigin-RevId: 527076648
This commit is contained in:
parent 5c73b3153a
commit 06cdcfa6eb
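
For context, a minimal usage sketch of the jit-based broadcast_one_to_all applied to a pytree, in the spirit of the new test; the dict keys, shapes, and values below are illustrative assumptions, not taken from the test itself. In a multi-process setting it assumes jax.distributed.initialize() has already been called; on a single process it degenerates to an identity.

import jax
import numpy as np
from jax.experimental.multihost_utils import broadcast_one_to_all

# Each process starts with process-dependent data; only process 0's values
# should survive the broadcast.
local_tree = {
    'step': np.full((1,), jax.process_index(), dtype=np.int32),
    'weights': np.full((2, 2), float(jax.process_index()), dtype=np.float32),
}

# is_source defaults to jax.process_index() == 0.
synced = broadcast_one_to_all(local_tree)

# Every process now holds process 0's pytree.
assert int(synced['step'][0]) == 0
assert np.all(np.asarray(synced['weights']) == 0.0)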
@@ -20,6 +20,7 @@ import zlib
 
 from typing import Any
 import jax
+import jax.numpy as jnp
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 from jax._src import core
 from jax._src import dispatch
@@ -35,10 +36,8 @@ from jax._src import config as config_internal
 import numpy as np
 
 
-# This needs to be top-level for the jax compilation cache.
-@functools.partial(jax.pmap, axis_name='hosts')
 def _psum(x: Any) -> Any:
-  return jax.lax.psum(x, 'hosts')
+  return jax.tree_map(functools.partial(jnp.sum, axis=0), x)
 
 
 def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
@@ -58,22 +57,26 @@ def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
   if is_source is None:
     is_source = jax.process_index() == 0
 
-  def pre_pmap(x):
+  devices = np.array(jax.devices()).reshape(jax.process_count(),
+                                            jax.local_device_count())
+  global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
+  pspec = P('processes')
+
+  def pre_jit(x):
     if is_source:
-      return np.concatenate([
-          x[None, ...],
-          np.repeat([np.zeros_like(x)],
-                    jax.local_device_count() - 1, 0)
-      ])
+      inp = x
     else:
-      return np.repeat([np.zeros_like(x)], jax.local_device_count(), 0)
+      inp = np.zeros_like(x)
+    inp = np.expand_dims(inp, axis=0)
+    return host_local_array_to_global_array(inp, global_mesh, pspec)
 
-  def post_pmap(x):
-    return jax.device_get(x)[0]
+  def post_jit(x):
+    return np.asarray(x.addressable_data(0))
 
-  in_tree = jax.tree_util.tree_map(pre_pmap, in_tree)
-  in_tree = jax.device_get(_psum(in_tree))
-  return jax.tree_util.tree_map(post_pmap, in_tree)
+  in_tree = jax.tree_map(pre_jit, in_tree)
+  out_tree = pjit(_psum, out_shardings=jax.sharding.NamedSharding(
+      global_mesh, P()))(in_tree)
+  return jax.tree_map(post_jit, out_tree)
 
 
 def sync_global_devices(name: str):
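
To make the new code path above concrete: each leaf is staged as one row of a global array (the source process contributes its data, every other process contributes zeros), and a global sum over the 'processes' mesh axis with a replicated output sharding leaves the source's row on every device. Below is a standalone sketch of that pattern against the public jax.sharding and multihost_utils APIs; the variable names are mine, and it is only meaningful on a multi-process runtime, though it also runs unchanged on a single process.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.multihost_utils import host_local_array_to_global_array
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(jax.process_count(),
                                          jax.local_device_count())
mesh = Mesh(devices, ('processes', 'local_devices'))

x = np.arange(4.0)                       # host-local value on every process
is_source = jax.process_index() == 0
# One leading row per process: real data on the source, zeros elsewhere.
local_slice = np.expand_dims(x if is_source else np.zeros_like(x), axis=0)
global_x = host_local_array_to_global_array(local_slice, mesh, P('processes'))

# Summing over the leading (process) axis, with a replicated output sharding,
# broadcasts the source's row to every device.
broadcast = jax.jit(lambda a: jnp.sum(a, axis=0),
                    out_shardings=NamedSharding(mesh, P()))
print(np.asarray(broadcast(global_x).addressable_data(0)))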
@@ -1177,21 +1177,21 @@ pe.dce_rules[shard_map_p] = _shard_map_dce
 def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
          static_broadcasted_argnums=(), devices=None, backend=None,
          axis_size=None, donate_argnums=(), global_arg_shapes=None):
-  if axis_size is not None:  # TODO what even is this?
-    raise NotImplementedError
   devices = tuple(devices) if devices is not None else devices
-  axis_name, _, _ = _shared_code_pmap(
+  axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
       f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes)
   return jax.jit(
       HashablePartial(_pmapped, (f, axis_name, in_axes, out_axes, devices,
-                                 backend)),
+                                 backend, axis_size, static_broadcasted_tuple,
+                                 donate_tuple)),
       static_argnums=static_broadcasted_argnums,
       donate_argnums=donate_argnums)
 
 def _pmapped(metadata, *args, **kwargs):
-  f, axis_name, in_axes, out_axes, devices, backend = metadata
-  p = _prepare_pmap(f, in_axes, out_axes, (), (), devices, backend, None,
-                    args, kwargs)
+  (f, axis_name, in_axes, out_axes, devices, backend, axis_size,
+   static_broadcasted_tuple, donate_tuple) = metadata
+  p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
+                    devices, backend, axis_size, args, kwargs)
   in_specs = tuple(map(partial(_axis_to_spec, axis_name), p.in_axes_flat))
   out_specs = lambda: map(partial(_axis_to_spec, axis_name), p.out_axes_thunk())
   fun = _handle_reshapes(p.flat_fun, p.in_axes_flat, p.out_axes_thunk)
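
On the second change: the wrapper now packs axis_size, static_broadcasted_tuple, and donate_tuple into the hashable metadata handed to `_pmapped`, so `_prepare_pmap` sees the same compile-time configuration that jax.jit receives via static_argnums and donate_argnums. Below is a simplified, hypothetical sketch of that wrapping pattern in isolation: `HashableConfigFn` is a stand-in for JAX's internal HashablePartial (not a public API), and `_offset_scaled_sum` / `make_kernel` are invented names.

import jax
import jax.numpy as jnp


class HashableConfigFn:
  """Callable that hashes and compares by its (function, metadata) pair, so
  equal wrappers can share jit's compilation cache."""

  def __init__(self, fn, metadata):
    self.fn = fn
    self.metadata = metadata

  def __call__(self, *args, **kwargs):
    return self.fn(self.metadata, *args, **kwargs)

  def __hash__(self):
    return hash((self.fn, self.metadata))

  def __eq__(self, other):
    return (isinstance(other, HashableConfigFn)
            and self.fn == other.fn and self.metadata == other.metadata)


def _offset_scaled_sum(metadata, scale, x):
  offset, = metadata                  # baked in when the wrapper was built
  return offset + scale * jnp.sum(x)  # `scale` is a compile-time constant


def make_kernel(offset):
  # static_argnums=0 plays the role of static_broadcasted_argnums in the
  # pmap wrapper above: `scale` never becomes a tracer, so downstream code
  # (like _prepare_pmap in the diff) can treat it as plain Python data.
  return jax.jit(HashableConfigFn(_offset_scaled_sum, (offset,)),
                 static_argnums=0)


f = make_kernel(offset=10.0)
print(f(3, jnp.arange(4.0)))  # 10.0 + 3 * (0 + 1 + 2 + 3) = 28.0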