Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Add support for multi-host partitioning when using pmap(sharded_jit).
This extends the pmap logic in a way similar to https://github.com/google/jax/pull/4746. The new arguments to sharded_jit specifying the local partitioning can be reused by pmap, but with one wrinkle: the pmap implementation needs to trace its jaxpr to "see" the sharded_jit and get these values, but it needs to know the global aval shapes in order to correctly trace through the sharded_jit. For now, we simply add this information as a new "global_arg_shapes" argument to pmap. Ideally we'll replace this with a more elegant solution, e.g. global-view device arrays.
commit 4e56cf965a (parent ef5218f646)
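For orientation, a minimal usage sketch of how the pieces are intended to fit together (not part of this commit): sharded_jit declares the global and per-process partitioning, using the local_in_parts/local_out_parts arguments assumed here from https://github.com/google/jax/pull/4746, and the new global_arg_shapes argument tells pmap the global per-replica shapes so it can trace through the sharded_jit. The process/device layout, partition specs, shapes, and the explicit axis_size below are illustrative assumptions, not values taken from this change.

    # Usage sketch only: assumes 2 processes with 8 local devices each, and
    # per-replica values split into 4 partitions globally (2 per process).
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P

    def f(x):
      return jnp.sin(x) * 2.0

    # Global partitioning of each per-replica value, plus this process's share.
    sharded_f = sharded_jit(f,
                            in_parts=P(1, 4), out_parts=P(1, 4),
                            local_in_parts=P(1, 2), local_out_parts=P(1, 2))

    # global_arg_shapes carries the *global* per-replica shape, without the
    # leading pmapped dimension; axis_size gives the global replica count
    # (assumed here, since each replica spans both processes).
    pmapped = jax.pmap(sharded_f, axis_size=4,
                       global_arg_shapes=((8, 1024),))

    # Each process feeds its local replicas and its local slice of the data.
    n_local_replicas = jax.local_device_count() // 2   # 2 local partitions/replica
    x_local = np.ones((n_local_replicas, 8, 1024 // 2), np.float32)
    y_local = pmapped(x_local)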
@@ -453,6 +453,7 @@ captured using the ``xla_pmap`` primitive. Consider this example
                                  in (g,) }
                     devices=None
                     donated_invars=(False, False)
+                    global_arg_shapes=(None,)
                     global_axis_size=None
                     in_axes=(None, 0)
                     name=inner ] b a
jax/api.py (23 changed lines)
@@ -1239,7 +1239,9 @@ def pmap(fun: Callable[..., T],
          static_broadcasted_argnums: Union[int, Iterable[int]] = (),
          devices=None, backend: Optional[str] = None,
          axis_size: Optional[int] = None,
-         donate_argnums: Union[int, Iterable[int]] = ()) -> Callable[..., T]:
+         donate_argnums: Union[int, Iterable[int]] = (),
+         global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
+         ) -> Callable[..., T]:
   """Parallel map with support for collective operations.

   The purpose of :py:func:`pmap` is to express single-program multiple-data (SPMD)
@@ -1315,6 +1317,11 @@ def pmap(fun: Callable[..., T],
       for example recycling one of your input buffers to store a result. You
       should not re-use buffers that you donate to a computation, JAX will raise
       an error if you try to.
+    global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
+      the partitioned values span multiple processes. The global cross-process
+      per-replica shape of each argument, i.e. does not include the leading
+      pmapped dimension. Can be None for replicated arguments. This API is
+      likely to change in the future.

   Returns:
     A parallelized version of ``fun`` with arguments that correspond to those of
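To make the documented convention concrete, here is a hypothetical call (sharded_fn and the shapes are placeholders, not from this change): the first argument is partitioned across processes and gets its global per-replica shape, while the second is replicated and gets None.

    # Hypothetical illustration of the global_arg_shapes convention described
    # above; entries never include the leading pmapped dimension.
    parallel_fn = jax.pmap(sharded_fn,
                           in_axes=(0, None),                    # 2nd arg not mapped
                           global_arg_shapes=((16, 256), None))  # None = replicated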
@@ -1463,18 +1470,29 @@ def pmap(fun: Callable[..., T],
       dyn_argnums = [i for i in range(len(args))
                      if i not in static_broadcasted_tuple]
       f, dyn_args = argnums_partial(f, dyn_argnums, args)
+
       if isinstance(in_axes, tuple):
         dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
       else:
         dyn_in_axes = in_axes
+        dyn_global_arg_shapes = global_arg_shapes
+
+      if isinstance(global_arg_shapes, tuple):
+        dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums)
+      else:
+        dyn_global_arg_shapes = global_arg_shapes
     else:
       dyn_args, dyn_in_axes = args, in_axes
+      dyn_global_arg_shapes = global_arg_shapes
     args, in_tree = tree_flatten((dyn_args, kwargs))
+
     if donate_tuple:
       donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)
     else:
       donated_invars = (False,) * len(args)
     in_axes_flat = flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0))
+    global_arg_shapes_flat = flatten_axes("pmap global_arg_shapes", in_tree,
+                                          (dyn_global_arg_shapes, None))
     local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap")
     for arg in args: _check_arg(arg)
     flat_fun, out_tree = flatten_fun(f, in_tree)
@@ -1483,7 +1501,8 @@ def pmap(fun: Callable[..., T],
                        axis_size=local_axis_size, global_axis_size=axis_size,
                        devices=None if devices is None else tuple(devices),
                        in_axes=tuple(in_axes_flat),
-                       name=flat_fun.__name__, donated_invars=tuple(donated_invars))
+                       name=flat_fun.__name__, donated_invars=tuple(donated_invars),
+                       global_arg_shapes=tuple(global_arg_shapes_flat))
     return tree_unflatten(out_tree(), out)

   return f_pmapped
jax/interpreters/pxla.py

@@ -504,11 +504,13 @@ xla.canonicalize_dtype_handlers[ShardedDeviceArray] = identity
 ### the xla_pmap primitive and its rules are comparable to xla_call in xla.py

 def xla_pmap_impl(fun: lu.WrappedFun, *args, backend, axis_name, axis_size,
-                  global_axis_size, devices, name, in_axes, donated_invars):
+                  global_axis_size, devices, name, in_axes, donated_invars,
+                  global_arg_shapes):
   abstract_args = unsafe_map(xla.abstractify, args)
   compiled_fun = parallel_callable(fun, backend, axis_name, axis_size,
                                    global_axis_size, devices, name, in_axes,
-                                   donated_invars, *abstract_args)
+                                   donated_invars, global_arg_shapes,
+                                   *abstract_args)
   return compiled_fun(*args)

 @lu.cache
@@ -521,6 +523,7 @@ def parallel_callable(fun: lu.WrappedFun,
                       name: str,
                       in_axes: Iterable[Optional[int]],
                       donated_invars: Iterable[bool],
+                      global_arg_shapes,
                       *avals):
   if devices is not None and len(devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")
@@ -565,8 +568,19 @@ def parallel_callable(fun: lu.WrappedFun,
   if config.omnistaging_enabled:
     sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
                           for axis, aval in zip(in_axes, avals))
+    if any(s is not None for s in global_arg_shapes):
+      # TODO(skye): we could take this branch unconditionally if we handled
+      # grad of global_arg_shapes correctly.
+      global_sharded_avals = [
+          ShapedArray(shape, aval.dtype) if shape is not None else aval
+          for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
+    else:
+      global_sharded_avals = sharded_avals  # type: ignore
+    logging.vlog(2, "sharded_avals: %s", sharded_avals)
+    logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
+
     with core.extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
-      jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, sharded_avals)
+      jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
     jaxpr = xla.apply_outfeed_rewriter(jaxpr)
   else:
     @lu.wrap_init
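To illustrate what this buys, with assumed shapes rather than anything from the diff: the avals used for tracing now carry the global per-replica shape, while the per-process sharding machinery keeps using the local one.

    # Illustrative shapes only. With a second-axis partitioning spanning two
    # processes, the local per-replica aval and the tracing aval differ:
    import numpy as np
    from jax.core import ShapedArray

    sharded_aval = ShapedArray((8, 512), np.float32)   # local per-replica shape
    global_shape = (8, 1024)                           # entry from global_arg_shapes
    global_sharded_aval = ShapedArray(global_shape, sharded_aval.dtype)

    # trace_to_jaxpr_final now runs on global_sharded_aval, so the sharded_jit
    # inside the pmapped function sees the full (8, 1024) operand it partitions,
    # while per-process input handling still works with the (8, 512) shards.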
@@ -585,6 +599,7 @@ def parallel_callable(fun: lu.WrappedFun,
     jaxpr = xla.apply_outfeed_rewriter(jaxpr)

     out_pvs, out_consts = unzip2(out_pvals)
+    global_sharded_avals = sharded_avals  # type: ignore

   # TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
   # for now raise an error
@@ -618,9 +633,30 @@ def parallel_callable(fun: lu.WrappedFun,
   jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
   num_local_replicas = axis_size * jaxpr_replicas
   num_global_replicas = global_axis_size * jaxpr_replicas
-  arg_parts, out_parts, num_partitions = _find_partitions(jaxpr)
-
-  num_local_shards = num_local_replicas * num_partitions
+  (arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts,
+   local_num_partitions) = _find_partitions(jaxpr)
+
+  if local_num_partitions is None:
+    local_num_partitions = num_partitions
+
+  if local_arg_parts is None:
+    local_arg_parts = arg_parts
+  if local_out_parts is None:
+    local_out_parts = out_parts
+
+  logging.vlog(2, "num_replicas: %d num_local_replicas: %d",
+               num_global_replicas, num_local_replicas)
+  logging.vlog(2, "num_partitions: %d local_num_partitions: %d",
+               num_partitions, local_num_partitions)
+  logging.vlog(2, "arg_parts: %s", arg_parts)
+  logging.vlog(2, "local_arg_parts: %s", local_arg_parts)
+  logging.vlog(2, "out_parts: %s", out_parts)
+  logging.vlog(2, "local_out_parts: %s", local_out_parts)
+  logging.vlog(2, "devices: %s", devices)
+  logging.vlog(2, "local_devices: %s", local_devices)
+
+  num_local_shards = num_local_replicas * local_num_partitions
   num_global_shards = num_global_replicas * num_partitions

   if (xb.host_count() > 1 and must_run_on_all_devices and
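A worked example of the replica/partition/shard bookkeeping above, using assumed numbers (2 processes with 8 local devices each, one jaxpr replica, 4 global partitions per replica of which 2 live on each process, and the global replica count supplied via pmap's axis_size):

    # Assumed numbers, not taken from the diff.
    jaxpr_replicas       = 1
    axis_size            = 4   # local replicas (leading dim of per-process input)
    global_axis_size     = 4   # each replica spans both processes
    num_partitions       = 4   # global partitions per replica (from _find_partitions)
    local_num_partitions = 2   # this process's partitions per replica

    num_local_replicas  = axis_size * jaxpr_replicas                  # 4
    num_global_replicas = global_axis_size * jaxpr_replicas           # 4
    num_local_shards    = num_local_replicas * local_num_partitions   # 8  = local devices
    num_global_shards   = num_global_replicas * num_partitions        # 16 = all devices
    assert num_global_shards > num_local_shards  # multi-host device assignment kicks in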
@@ -650,16 +686,16 @@ def parallel_callable(fun: lu.WrappedFun,
   logging.log(log_priority,
               f"Compiling {fun.__name__} for {num_global_shards} devices with "
               f"args {avals}. (num_replicas={num_global_replicas} "
-              f"num_partitions={num_partitions}")
+              f"num_partitions={num_partitions})")

   axis_env = xla.AxisEnv(num_global_replicas, (axis_name,), (global_axis_size,))

-  tuple_args = len(sharded_avals) > 100  # pass long arg lists as tuple for TPU
+  tuple_args = len(global_sharded_avals) > 100  # pass long arg lists as tuple for TPU

   c = xb.make_computation_builder("pmap_{}".format(fun.__name__))
   xla_consts = map(partial(xb.constant, c), consts)
   replicated_args = [axis is None for axis in in_axes]
-  xla_args, donated_invars = xla._xla_callable_args(c, sharded_avals, tuple_args,
+  xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
                                                     replicated_args, arg_parts,
                                                     donated_invars=donated_invars)
   with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
@@ -688,7 +724,7 @@ def parallel_callable(fun: lu.WrappedFun,
   # violating pmap's semantics where data is sharded across replicas in
   # row-major order. Instead, manually create a device assignment that ensures
   # each host is responsible for a continguous set of replicas.
-  if num_global_replicas > num_local_replicas:
+  if num_global_shards > num_local_shards:
     # TODO(skye): use a locality-aware assignment that satisfies the above
     # constraint.
     devices = [d for host_id in xb.host_ids()
@@ -728,24 +764,25 @@ def parallel_callable(fun: lu.WrappedFun,
   compile_options.parameter_is_tupled_arguments = tuple_args
   compiled = xla.backend_compile(backend, built, compile_options)

   arg_parts_ = arg_parts or [None] * len(avals)
+  local_arg_parts_ = local_arg_parts or [None] * len(avals)
   input_sharding_specs = [
-      _pmap_sharding_spec(num_local_replicas, axis_size, num_partitions, parts,
-                          aval, in_axis)
+      _pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
+                          parts, aval, in_axis)
       if aval is not core.abstract_unit else None
-      for aval, parts, in_axis in zip(sharded_avals, arg_parts_, in_axes)]
+      for aval, parts, in_axis in zip(sharded_avals, local_arg_parts_, in_axes)]
   input_indices = [spec_to_indices(aval.shape, spec)
                    if spec is not None else None
                    for aval, spec in zip(avals, input_sharding_specs)]
   handle_args = partial(shard_args, compiled.local_devices(), input_indices)
   if config.omnistaging_enabled:
     handle_outs = avals_to_results_handler(  # type: ignore
-        axis_size, num_local_replicas, num_partitions, out_parts, out_avals)
+        axis_size, num_local_replicas, local_num_partitions, out_parts,
+        local_out_parts, out_avals)
   else:
     handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas,  # type: ignore
-                                            num_partitions, out_parts,
-                                            out_pvals, compiled.local_devices(),
-                                            backend)
+                                            local_num_partitions,
+                                            local_out_parts, out_pvals,
+                                            compiled.local_devices(), backend)

   return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
@@ -757,8 +794,13 @@ PartitionsOrReplicated = Optional[Tuple[int, ...]]
 def _find_partitions(jaxpr) -> Tuple[
     Optional[Tuple[PartitionsOrReplicated, ...]],
     Optional[Tuple[PartitionsOrReplicated, ...]],
-    int]:
-  """Returns (in_partitions, out_partitions, num_partitions)."""
+    int,
+    Optional[Tuple[PartitionsOrReplicated, ...]],
+    Optional[Tuple[PartitionsOrReplicated, ...]],
+    Optional[int]]:
+  """Returns (in_partitions, out_partitions, num_partitions, local_in_parts,
+  local_out_parts, local_num_partitions).
+  """
   for eqn in jaxpr.eqns:
     if eqn.primitive.name == "sharded_call":
       if len(jaxpr.eqns) > 1:
@@ -766,10 +808,13 @@ def _find_partitions(jaxpr) -> Tuple[
             "pmap of sharded_jit + non-sharded operations not yet implemented.")
       num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"],
                                                 eqn.params["nparts"])
-      return (eqn.params["in_parts"], eqn.params["out_parts_thunk"](),
-              num_partitions)
-  return None, None, 1
+      return (eqn.params["in_parts"],
+              eqn.params["out_parts_thunk"](),
+              num_partitions,
+              eqn.params["local_in_parts"],
+              eqn.params["local_out_parts_thunk"](),
+              eqn.params["local_nparts"])
+  return None, None, 1, None, None, None

 def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]):
   """Returns the total number of partitions to use.
@@ -863,21 +908,28 @@ def _safe_div(x, y):
 class ResultToPopulate: pass
 result_to_populate = ResultToPopulate()

-def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
+def avals_to_results_handler(size, nrep, npart, out_parts, local_out_parts,
+                             out_avals):
   nouts = len(out_avals)
   if out_parts is None:
     out_parts = (None,) * len(out_avals)
+  if local_out_parts is None:
+    local_out_parts = (None,) * len(out_avals)
+
+  local_out_avals = [get_local_aval(aval, parts, lparts)
+                     for aval, parts, lparts
+                     in safe_zip(out_avals, out_parts, local_out_parts)]
+
   # TODO(mattjj,skyewm): can probably clean up this logic
   out_axis = 0
   out_specs = [_pmap_sharding_spec(nrep, size, npart, parts, aval, out_axis)
                if aval is not core.abstract_unit else None
-               for parts, aval in zip(out_parts, out_avals)]
+               for parts, aval in zip(local_out_parts, local_out_avals)]
   out_indices = [spec_to_indices(core.unmapped_aval(size, out_axis, aval).shape, spec)
                  if aval is not core.abstract_unit else None
-                 for aval, spec in zip(out_avals, out_specs)]  # pytype: disable=attribute-error
+                 for aval, spec in zip(local_out_avals, out_specs)]  # pytype: disable=attribute-error
   handlers = [aval_to_result_handler(spec, idcs, core.unmapped_aval(size, out_axis, aval))
-              for spec, idcs, aval in zip(out_specs, out_indices, out_avals)]
+              for spec, idcs, aval in zip(out_specs, out_indices, local_out_avals)]

   def handler(out_bufs):
     assert nrep * npart == len(out_bufs)
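The local output avals come from get_local_aval; the real helper is defined elsewhere in the codebase, but under the model above it presumably amounts to something like the following sketch (an assumption for illustration, not the actual implementation):

    # Rough sketch only: shrink each dimension from its global partition count
    # to this process's partition count; replicated outputs stay unchanged.
    from jax.core import ShapedArray

    def get_local_aval_sketch(global_aval, global_parts, local_parts):
      if global_parts is None:
        return global_aval
      local_shape = tuple(dim // gp * lp for dim, gp, lp in
                          zip(global_aval.shape, global_parts, local_parts))
      return ShapedArray(local_shape, global_aval.dtype)

    # e.g. global (8, 1024) with parts (1, 4) and local parts (1, 2) -> (8, 512)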
@@ -1001,7 +1053,7 @@ def _pmap_translation_rule(c, axis_env,
                            in_nodes, name_stack, axis_name, axis_size,
                            global_axis_size, devices, name,
                            call_jaxpr, *, backend=None, in_axes,
-                           donated_invars):
+                           donated_invars, global_arg_shapes):
   del donated_invars  # Unused.
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.