Add support for multi-host partitioning when using pmap(sharded_jit).

This extends the pmap logic in a way similar to
https://github.com/google/jax/pull/4746. The new arguments to
sharded_jit specifying the local partitioning can be reused by pmap,
with one wrinkle: the pmap implementation needs to trace its jaxpr to
"see" the sharded_jit and read these values, but it also needs to know
the global aval shapes in order to trace through the sharded_jit
correctly. For now, we simply pass this information as a new
"global_arg_shapes" argument to pmap. Ideally we'll replace this with
a more elegant solution, e.g. global-view device arrays.
Author: Skye Wanderman-Milne
Date:   2020-11-13 13:30:29 -08:00
Commit: 4e56cf965a (parent: ef5218f646)
3 changed files with 102 additions and 30 deletions
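To make the intended usage concrete, here is a hedged sketch of how the new argument fits together with sharded_jit's local-partitioning arguments. The import path, the keyword names (taken from PR #4746), and the 2-process-by-4-device scenario are assumptions for illustration, not part of this commit:

    # Sketch only: assumes 2 processes x 4 local devices (8 devices total) and
    # a single pmapped replica whose argument is partitioned 8 ways across
    # both processes. Import path and keyword names follow PR #4746 and may
    # differ in your tree.
    import numpy as np
    import jax
    from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P

    def f(x):
      return x * 2.

    sharded_f = sharded_jit(
        f,
        in_parts=P(8, 1), out_parts=P(8, 1),              # global: 8 partitions
        local_in_parts=P(4, 1), local_out_parts=P(4, 1),  # this process: 4
        local_num_partitions=4)

    # Each process passes only the data it holds: a leading pmapped axis of
    # size 1 (one replica) plus its half of the global (16, 1024) value.
    local_x = np.zeros((1, 8, 1024), np.float32)

    # global_arg_shapes is the *global* per-replica shape, without the leading
    # pmapped axis; axis_size=1 spells out the global replica count, which can
    # no longer be inferred from the local data (an assumption of this sketch).
    y = jax.pmap(sharded_f, axis_size=1,
                 global_arg_shapes=((16, 1024),))(local_x)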


@@ -453,6 +453,7 @@ captured using the ``xla_pmap`` primitive. Consider this example
in (g,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
global_axis_size=None
in_axes=(None, 0)
name=inner ] b a
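The documentation hunk above simply reflects the new parameter in pretty-printed jaxprs. A quick, single-process way to see the xla_pmap parameters yourself (the exact output format varies by JAX version):

    import jax
    import jax.numpy as jnp

    # With this change, the printed xla_pmap equation should also list
    # global_arg_shapes=(None,) when no global shapes were given.
    print(jax.make_jaxpr(jax.pmap(lambda x: x * 2))(jnp.ones((1, 4))))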


@@ -1239,7 +1239,9 @@ def pmap(fun: Callable[..., T],
static_broadcasted_argnums: Union[int, Iterable[int]] = (),
devices=None, backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = ()) -> Callable[..., T]:
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
) -> Callable[..., T]:
"""Parallel map with support for collective operations.
The purpose of :py:func:`pmap` is to express single-program multiple-data (SPMD)
@@ -1315,6 +1317,11 @@ def pmap(fun: Callable[..., T],
for example recycling one of your input buffers to store a result. You
should not re-use buffers that you donate to a computation, JAX will raise
an error if you try to.
global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
the partitioned values span multiple processes. The global cross-process
per-replica shape of each argument, i.e. does not include the leading
pmapped dimension. Can be None for replicated arguments. This API is
likely to change in the future.
Returns:
A parallelized version of ``fun`` with arguments that correspond to those of
@@ -1463,18 +1470,29 @@ def pmap(fun: Callable[..., T],
dyn_argnums = [i for i in range(len(args))
if i not in static_broadcasted_tuple]
f, dyn_args = argnums_partial(f, dyn_argnums, args)
if isinstance(in_axes, tuple):
dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
else:
dyn_in_axes = in_axes
dyn_global_arg_shapes = global_arg_shapes
if isinstance(global_arg_shapes, tuple):
dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums)
else:
dyn_global_arg_shapes = global_arg_shapes
else:
dyn_args, dyn_in_axes = args, in_axes
dyn_global_arg_shapes = global_arg_shapes
args, in_tree = tree_flatten((dyn_args, kwargs))
if donate_tuple:
donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
else:
donated_invars = (False,) * len(args)
in_axes_flat = flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0))
global_arg_shapes_flat = flatten_axes("pmap global_arg_shapes", in_tree,
(dyn_global_arg_shapes, None))
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap")
for arg in args: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
@@ -1483,7 +1501,8 @@ def pmap(fun: Callable[..., T],
axis_size=local_axis_size, global_axis_size=axis_size,
devices=None if devices is None else tuple(devices),
in_axes=tuple(in_axes_flat),
name=flat_fun.__name__, donated_invars=tuple(donated_invars))
name=flat_fun.__name__, donated_invars=tuple(donated_invars),
global_arg_shapes=tuple(global_arg_shapes_flat))
return tree_unflatten(out_tree(), out)
return f_pmapped
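In the wrapper above, global_arg_shapes is filtered alongside in_axes for static_broadcasted_argnums and then flattened with flatten_axes, which takes one entry per top-level pmap argument and broadcasts it over that argument's pytree leaves (None being the default). A conceptual illustration of that broadcasting, not the real flatten_axes:

    import jax

    args = ({'w': 1, 'b': 2}, 3)             # two pmap arguments
    global_arg_shapes = ((16, 1024), None)   # one entry per argument

    flat_shapes = []
    for arg, shape in zip(args, global_arg_shapes):
        flat_shapes.extend([shape] * len(jax.tree_util.tree_leaves(arg)))
    print(flat_shapes)   # [(16, 1024), (16, 1024), None]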


@@ -504,11 +504,13 @@ xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity
### the xla_pmap primitive and its rules are comparable to xla_call in xla.py
def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size,
global_axis_size, devices, name, in_axes, donated_invars):
global_axis_size, devices, name, in_axes, donated_invars,
global_arg_shapes):
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
global_axis_size, devices, name, in_axes,
donated_invars, *abstract_args)
donated_invars, global_arg_shapes,
*abstract_args)
return compiled_fun(*args)
@lu.cache
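The @lu.cache decorator above memoizes parallel_callable on its arguments, so everything threaded through it, including global_arg_shapes, must be hashable; this is why the caller passes tuple(global_arg_shapes_flat) rather than a list. A minimal illustration of the constraint:

    shapes = ((16, 1024), None)      # nested tuples: hashable, cache-friendly
    hash(shapes)
    try:
        hash([(16, 1024), None])     # a list here would break the cache key
    except TypeError as err:
        print("unhashable:", err)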
@@ -521,6 +523,7 @@ def parallel_callable(fun: lu.WrappedFun,
name: str,
in_axes: Iterable[Optional[int]],
donated_invars: Iterable[bool],
global_arg_shapes,
*avals):
if devices is not None and len(devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
@@ -565,8 +568,19 @@ def parallel_callable(fun: lu.WrappedFun,
if config.omnistaging_enabled:
sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
for axis, aval in zip(in_axes, avals))
if any(s is not None for s in global_arg_shapes):
# TODO(skye): we could take this branch unconditionally if we handled
# grad of global_arg_shapes correctly.
global_sharded_avals = [
ShapedArray(shape, aval.dtype) if shape is not None else aval
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
else:
global_sharded_avals = sharded_avals # type: ignore
logging.vlog(2, "sharded_avals: %s", sharded_avals)
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, sharded_avals)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
else:
@lu.wrap_init
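To spell out the aval promotion in the omnistaging branch above with the running example's numbers (all assumed for illustration): the locally sharded aval is widened to the global shape before tracing, so the sharded_jit inside sees the full value it is meant to partition.

    import numpy as np
    from jax.core import ShapedArray

    # The per-process input has shape (1, 8, 1024); removing the pmapped axis
    # leaves the local per-replica view.
    local_sharded_aval = ShapedArray((8, 1024), np.float32)
    global_arg_shape = (16, 1024)            # spans both processes

    global_sharded_aval = (
        ShapedArray(global_arg_shape, local_sharded_aval.dtype)
        if global_arg_shape is not None else local_sharded_aval)
    # trace_to_jaxpr_final is then run with the (16, 1024) aval.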
@@ -585,6 +599,7 @@ def parallel_callable(fun: lu.WrappedFun,
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
out_pvs, out_consts = unzip2(out_pvals)
global_sharded_avals = sharded_avals # type: ignore
# TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
# for now raise an error
@@ -618,9 +633,30 @@ def parallel_callable(fun: lu.WrappedFun,
jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
num_local_replicas = axis_size * jaxpr_replicas
num_global_replicas = global_axis_size * jaxpr_replicas
arg_parts, out_parts, num_partitions = _find_partitions(jaxpr)
num_local_shards = num_local_replicas * num_partitions
(arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts,
local_num_partitions) = _find_partitions(jaxpr)
if local_num_partitions is None:
local_num_partitions = num_partitions
if local_arg_parts is None:
local_arg_parts = arg_parts
if local_out_parts is None:
local_out_parts = out_parts
logging.vlog(2, "num_replicas: %d num_local_replicas: %d",
num_global_replicas, num_local_replicas)
logging.vlog(2, "num_partitions: %d local_num_partitions: %d",
num_partitions, local_num_partitions)
logging.vlog(2, "arg_parts: %s", arg_parts)
logging.vlog(2, "local_arg_parts: %s", local_arg_parts)
logging.vlog(2, "out_parts: %s", out_parts)
logging.vlog(2, "local_out_parts: %s", local_out_parts)
logging.vlog(2, "devices: %s", devices)
logging.vlog(2, "local_devices: %s", local_devices)
num_local_shards = num_local_replicas * local_num_partitions
num_global_shards = num_global_replicas * num_partitions
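The local and global shard counts above differ exactly when partitions span processes. With the running example's assumed numbers:

    jaxpr_replicas = 1
    local_axis_size = global_axis_size = 1          # one replica in total
    num_partitions, local_num_partitions = 8, 4

    num_local_replicas = local_axis_size * jaxpr_replicas        # 1
    num_global_replicas = global_axis_size * jaxpr_replicas      # 1
    num_local_shards = num_local_replicas * local_num_partitions   # 4, one per local device
    num_global_shards = num_global_replicas * num_partitions       # 8, one per device overall
    assert num_global_shards > num_local_shards   # triggers the multi-host device assignment below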
if (xb.host_count() > 1 and must_run_on_all_devices and
@@ -650,16 +686,16 @@ def parallel_callable(fun: lu.WrappedFun,
logging.log(log_priority,
f"Compiling {fun.__name__} for {num_global_shards} devices with "
f"args {avals}. (num_replicas={num_global_replicas} "
f"num_partitions={num_partitions}")
f"num_partitions={num_partitions})")
axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,))
tuple_args = len(sharded_avals) > 100 # pass long arg lists as tuple for TPU
tuple_args = len(global_sharded_avals) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
xla_consts = map(partial(xb.constant, c), consts)
replicated_args = [axis is None for axis in in_axes]
xla_args, donated_invars = xla._xla_callable_args(c, sharded_avals, tuple_args,
xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
replicated_args, arg_parts,
donated_invars=donated_invars)
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
@@ -688,7 +724,7 @@ def parallel_callable(fun: lu.WrappedFun,
# violating pmap's semantics where data is sharded across replicas in
# row-major order. Instead, manually create a device assignment that ensures
# each host is responsible for a contiguous set of replicas.
if num_global_replicas > num_local_replicas:
if num_global_shards > num_local_shards:
# TODO(skye): use a locality-aware assignment that satisfies the above
# constraint.
devices = [d for host_id in xb.host_ids()
@@ -728,24 +764,25 @@ def parallel_callable(fun: lu.WrappedFun,
compile_options.parameter_is_tupled_arguments = tuple_args
compiled = xla.backend_compile(backend, built, compile_options)
arg_parts_ = arg_parts or [None] * len(avals)
local_arg_parts_ = local_arg_parts or [None] * len(avals)
input_sharding_specs = [
_pmap_sharding_spec(num_local_replicas, axis_size, num_partitions, parts,
aval, in_axis)
_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
parts, aval, in_axis)
if aval is not core.abstract_unit else None
for aval, parts, in_axis in zip(sharded_avals, arg_parts_, in_axes)]
for aval, parts, in_axis in zip(sharded_avals, local_arg_parts_, in_axes)]
input_indices = [spec_to_indices(aval.shape, spec)
if spec is not None else None
for aval, spec in zip(avals, input_sharding_specs)]
handle_args = partial(shard_args, compiled.local_devices(), input_indices)
if config.omnistaging_enabled:
handle_outs = avals_to_results_handler( # type: ignore
axis_size, num_local_replicas, num_partitions, out_parts, out_avals)
axis_size, num_local_replicas, local_num_partitions, out_parts,
local_out_parts, out_avals)
else:
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, # type: ignore
num_partitions, out_parts,
out_pvals, compiled.local_devices(),
backend)
local_num_partitions,
local_out_parts, out_pvals,
compiled.local_devices(), backend)
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
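Because the input sharding specs above are built from local_num_partitions and local_arg_parts, each process slices only the data it actually holds, and those slices already have the per-device shapes the compiled executable expects. A conceptual numpy check with the running example's assumed numbers:

    import numpy as np

    local_x = np.zeros((1, 8, 1024), np.float32)     # this process's argument
    per_replica = local_x[0]                         # drop the pmapped axis
    local_shards = np.split(per_replica, 4, axis=0)  # 4 local partitions
    assert all(s.shape == (2, 1024) for s in local_shards)
    # (2, 1024) also equals the global (16, 1024) split into 8 partitions,
    # so local and global views agree on the per-device shard shape.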
@@ -757,8 +794,13 @@ PartitionsOrReplicated = Optional[Tuple[int, ...]]
def _find_partitions(jaxpr) -> Tuple[
Optional[Tuple[PartitionsOrReplicated, ...]],
Optional[Tuple[PartitionsOrReplicated, ...]],
int]:
"""Returns (in_partitions, out_partitions, num_partitions)."""
int,
Optional[Tuple[PartitionsOrReplicated, ...]],
Optional[Tuple[PartitionsOrReplicated, ...]],
Optional[int]]:
"""Returns (in_partitions, out_partitions, num_partitions, local_in_parts,
local_out_parts, local_num_partitions).
"""
for eqn in jaxpr.eqns:
if eqn.primitive.name == "sharded_call":
if len(jaxpr.eqns) > 1:
@@ -766,10 +808,13 @@ def _find_partitions(jaxpr) -> Tuple[
"pmap of sharded_jit + non-sharded operations not yet implemented.")
num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"],
eqn.params["nparts"])
return (eqn.params["in_parts"], eqn.params["out_parts_thunk"](),
num_partitions)
return None, None, 1
return (eqn.params["in_parts"],
eqn.params["out_parts_thunk"](),
num_partitions,
eqn.params["local_in_parts"],
eqn.params["local_out_parts_thunk"](),
eqn.params["local_nparts"])
return None, None, 1, None, None, None
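When the traced jaxpr contains no sharded_call, _find_partitions now returns the widened six-element default. A tiny check of that fallback, assuming the helper keeps its private name and lives alongside parallel_callable in jax.interpreters.pxla (illustration only):

    import jax
    from jax.interpreters import pxla

    jaxpr = jax.make_jaxpr(lambda x: x + 1)(1.0).jaxpr
    print(pxla._find_partitions(jaxpr))   # expected: (None, None, 1, None, None, None)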
def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]):
"""Returns the total number of partitions to use.
@@ -863,21 +908,28 @@ def _safe_div(x, y):
class ResultToPopulate: pass
result_to_populate = ResultToPopulate()
def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
def avals_to_results_handler(size, nrep, npart, out_parts, local_out_parts,
out_avals):
nouts = len(out_avals)
if out_parts is None:
out_parts = (None,) * len(out_avals)
if local_out_parts is None:
local_out_parts = (None,) * len(out_avals)
local_out_avals = [get_local_aval(aval, parts, lparts)
for aval, parts, lparts
in safe_zip(out_avals, out_parts, local_out_parts)]
# TODO(mattjj,skyewm): can probably clean up this logic
out_axis = 0
out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, out_axis)
if aval is not core.abstract_unit else None
for parts, aval in zip(out_parts, out_avals)]
for parts, aval in zip(local_out_parts, local_out_avals)]
out_indices = [spec_to_indices(core.unmapped_aval(size, out_axis, aval).shape, spec)
if aval is not core.abstract_unit else None
for aval, spec in zip(out_avals, out_specs)] # pytype: disable=attribute-error
for aval, spec in zip(local_out_avals, out_specs)] # pytype: disable=attribute-error
handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, out_axis, aval))
for spec, idcs, aval in zip(out_specs, out_indices, out_avals)]
for spec, idcs, aval in zip(out_specs, out_indices, local_out_avals)]
def handler(out_bufs):
assert nrep * npart == len(out_bufs)
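The get_local_aval calls above convert each global output aval into the local shape this process will actually receive from its devices. A conceptual model of that conversion, written as an assumed helper rather than the real implementation:

    import numpy as np
    from jax.core import ShapedArray

    def local_aval_sketch(global_aval, parts, local_parts):
        # Replicated outputs are unchanged; partitioned dims shrink from the
        # global extent to the extent held by this process's partitions.
        if parts is None:
            return global_aval
        shape = tuple(d // p * lp
                      for d, p, lp in zip(global_aval.shape, parts, local_parts))
        return ShapedArray(shape, global_aval.dtype)

    out_aval = ShapedArray((16, 1024), np.float32)
    print(local_aval_sketch(out_aval, (8, 1), (4, 1)))   # local view: float32[8,1024]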
@@ -1001,7 +1053,7 @@ def _pmap_translation_rule(c, axis_env,
in_nodes, name_stack, axis_name, axis_size,
global_axis_size, devices, name,
call_jaxpr, *, backend=None, in_axes,
donated_invars):
donated_invars, global_arg_shapes):
del donated_invars # Unused.
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.