* Make broadcast_one_to_all work with jax.jit rather than pmap and add a pytree test to broadcast_one_to_all

* Propagate `static_broadcasted_tuple` and `donate_tuple` to `_pmapped` so that when `local_axis_size` is computed in `_prepare_pmap`, it takes the static arguments into account.

PiperOrigin-RevId: 527076648
Authored by Yash Katariya on 2023-04-25 14:31:49 -07:00; committed by jax authors
parent 5c73b3153a
commit 06cdcfa6eb
2 changed files with 25 additions and 22 deletions
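
For context on the first bullet of the commit message, here is a minimal usage sketch of `broadcast_one_to_all` on a pytree. It assumes a multi-process JAX runtime (for example one started via `jax.distributed.initialize()`); the concrete names and values are illustrative only and are not part of this change.

# Usage sketch (assumes a multi-process JAX setup; values are illustrative).
import jax
import numpy as np
from jax.experimental import multihost_utils

# A host-local pytree; only process 0's entries should survive the broadcast.
local_tree = {
    'step': np.int32(100 + jax.process_index()),
    'weights': np.full((4,), float(jax.process_index()), np.float32),
}

# After the call, every process holds process 0's pytree.
synced = multihost_utils.broadcast_one_to_all(
    local_tree, is_source=jax.process_index() == 0)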


@@ -20,6 +20,7 @@ import zlib
 from typing import Any
 import jax
+import jax.numpy as jnp
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 from jax._src import core
 from jax._src import dispatch
@@ -35,10 +36,8 @@ from jax._src import config as config_internal
 import numpy as np

 # This needs to be top-level for the jax compilation cache.
-@functools.partial(jax.pmap, axis_name='hosts')
 def _psum(x: Any) -> Any:
-  return jax.lax.psum(x, 'hosts')
+  return jax.tree_map(functools.partial(jnp.sum, axis=0), x)


 def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
@@ -58,22 +57,26 @@ def broadcast_one_to_all(in_tree: Any, is_source: Optional[bool] = None) -> Any:
   if is_source is None:
     is_source = jax.process_index() == 0

-  def pre_pmap(x):
+  devices = np.array(jax.devices()).reshape(jax.process_count(),
+                                            jax.local_device_count())
+  global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
+  pspec = P('processes')
+
+  def pre_jit(x):
     if is_source:
-      return np.concatenate([
-          x[None, ...],
-          np.repeat([np.zeros_like(x)],
-                    jax.local_device_count() - 1, 0)
-      ])
+      inp = x
     else:
-      return np.repeat([np.zeros_like(x)], jax.local_device_count(), 0)
+      inp = np.zeros_like(x)
+    inp = np.expand_dims(inp, axis=0)
+    return host_local_array_to_global_array(inp, global_mesh, pspec)

-  def post_pmap(x):
-    return jax.device_get(x)[0]
+  def post_jit(x):
+    return np.asarray(x.addressable_data(0))

-  in_tree = jax.tree_util.tree_map(pre_pmap, in_tree)
-  in_tree = jax.device_get(_psum(in_tree))
-  return jax.tree_util.tree_map(post_pmap, in_tree)
+  in_tree = jax.tree_map(pre_jit, in_tree)
+  out_tree = pjit(_psum, out_shardings=jax.sharding.NamedSharding(
+      global_mesh, P()))(in_tree)
+  return jax.tree_map(post_jit, out_tree)


 def sync_global_devices(name: str):

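For readers unfamiliar with the host-local-to-global pattern that the new `broadcast_one_to_all` leans on, here is an illustrative sketch (not part of the diff). It assumes a multi-process run; the (1, 8) shape is made up for the example.

# Sketch of the host-local -> global pattern used by the new implementation
# (assumes a multi-process run; the (1, 8) shape is illustrative).
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.multihost_utils import host_local_array_to_global_array

devices = np.array(jax.devices()).reshape(
    jax.process_count(), jax.local_device_count())
mesh = Mesh(devices, ('processes', 'local_devices'))

# Each process contributes one slice along the leading 'processes' axis.
local = np.expand_dims(np.arange(8, dtype=np.float32), axis=0)  # shape (1, 8)
global_arr = host_local_array_to_global_array(local, mesh, P('processes'))

# global_arr has global shape (jax.process_count(), 8); summing it over axis 0
# inside a jit with fully replicated output sharding yields the broadcast value.
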

@@ -1177,21 +1177,21 @@ pe.dce_rules[shard_map_p] = _shard_map_dce
 def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
          static_broadcasted_argnums=(), devices=None, backend=None,
          axis_size=None, donate_argnums=(), global_arg_shapes=None):
-  if axis_size is not None:  # TODO what even is this?
-    raise NotImplementedError
   devices = tuple(devices) if devices is not None else devices
-  axis_name, _, _ = _shared_code_pmap(
+  axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
       f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes)
   return jax.jit(
       HashablePartial(_pmapped, (f, axis_name, in_axes, out_axes, devices,
-                                 backend)),
+                                 backend, axis_size, static_broadcasted_tuple,
+                                 donate_tuple)),
       static_argnums=static_broadcasted_argnums,
       donate_argnums=donate_argnums)

 def _pmapped(metadata, *args, **kwargs):
-  f, axis_name, in_axes, out_axes, devices, backend = metadata
-  p = _prepare_pmap(f, in_axes, out_axes, (), (), devices, backend, None,
-                    args, kwargs)
+  (f, axis_name, in_axes, out_axes, devices, backend, axis_size,
+   static_broadcasted_tuple, donate_tuple) = metadata
+  p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
+                    devices, backend, axis_size, args, kwargs)
   in_specs = tuple(map(partial(_axis_to_spec, axis_name), p.in_axes_flat))
   out_specs = lambda: map(partial(_axis_to_spec, axis_name), p.out_axes_thunk())
   fun = _handle_reshapes(p.flat_fun, p.in_axes_flat, p.out_axes_thunk)
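
To illustrate why the second bullet of the commit message matters: the mapped axis size has to be inferred only from the non-static arguments, since static arguments carry no mapped axis. The helper below is a hypothetical toy for illustration, not the actual `_prepare_pmap` logic.

# Hypothetical toy, not library code: infer the mapped axis size while
# skipping static (unmapped) arguments.
import numpy as np

def infer_local_axis_size(args, static_argnums=()):
  sizes = {np.shape(a)[0] for i, a in enumerate(args)
           if i not in static_argnums}
  if len(sizes) != 1:
    raise ValueError(f'inconsistent mapped axis sizes: {sizes}')
  return sizes.pop()

# Example: argument 1 is a static Python int with shape (); skipping it is
# what makes the size inference well defined.
assert infer_local_axis_size(
    (np.zeros((8, 2)), 3, np.ones((8,))), static_argnums=(1,)) == 8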