Remove global_arg_shapes from pmap, since it was only used by sharded_jit, which was removed from JAX a long time ago

PiperOrigin-RevId: 520356179
Author: Yash Katariya, 2023-03-29 09:22:34 -07:00 (committed by jax authors)
Parent: a964ae7fac
Commit: fbc05ee5ac
6 changed files with 26 additions and 66 deletions


@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list
* CUDA 11.4 support has been dropped. JAX GPU wheels only support
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
from source.
* The `global_arg_shapes` argument of `pmap` only worked with `sharded_jit`,
which has been removed from JAX, so `global_arg_shapes` has been removed from
`pmap` as well. Please migrate to `pjit` and drop `global_arg_shapes` from
`pmap` calls (a migration sketch follows below).
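For migration, a minimal sketch of the pjit-style replacement follows (not part of this commit). The helper `f`, the mesh axis name `data`, and the array size are illustrative assumptions; `jax.jit` is used because it is unified with `pjit` in current JAX, and the global layout travels with a `NamedSharding` rather than a `global_arg_shapes` argument.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def f(x):
  return x * 2  # illustrative computation

# Describe the global layout once, via a mesh and a sharding.
mesh = Mesh(jax.devices(), axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

# Place the global array according to the sharding; no per-argument global
# shapes are passed to the transform itself.
n = jax.device_count()
x = jax.device_put(jnp.arange(4.0 * n), sharding)
y = jax.jit(f)(x)  # the compiled computation sees the global shape (4 * n,)
print(y.sharding)
```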
## jaxlib 0.4.8


@ -460,7 +460,6 @@ captured using the ``xla_pmap`` primitive. Consider this example
in (k,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
global_axis_size=1
in_axes=(None, 0)
is_explicit_global_axis_size=False


@ -1437,12 +1437,6 @@ def pmap(
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
the partitioned values span multiple processes. The global cross-process
per-replica shape of each argument, i.e. does not include the leading
pmapped dimension. Can be None for replicated arguments. This API is
likely to change in the future.
Returns:
A parallelized version of ``fun`` with arguments that correspond to those of
``fun`` but with extra array axes at positions indicated by ``in_axes`` and
@ -1565,6 +1559,12 @@ def pmap(
>>> print(f2(jnp.array([2., 3.]))) # doctest: +SKIP
[ 13. 13.]
"""
if global_arg_shapes is not None:
raise ValueError(
"global_arg_shapes only worked with sharded_jit which has long been"
" removed from JAX. Please migrate to pjit and remove global_arg_shapes"
" from pmap.")
if FLAGS.experimental_cpp_pmap:
func = _cpp_pmap
else:
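For illustration only, a sketch of the resulting behaviour at this revision. It assumes the `global_arg_shapes` keyword still appears in `pmap`'s public signature here and only its value is rejected; later releases that drop the parameter entirely would turn this into a `TypeError` instead.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((jax.local_device_count(), 4))

try:
  # Any non-None value now trips the ValueError added above.
  jax.pmap(lambda v: v + 1, global_arg_shapes=((4,),))(x)
except ValueError as e:
  print(e)  # message points users at pjit

# The supported spelling simply drops the argument.
print(jax.pmap(lambda v: v + 1)(x).shape)
```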
@ -1579,8 +1579,7 @@ def pmap(
devices=devices,
backend=backend,
axis_size=axis_size,
donate_argnums=donate_argnums,
global_arg_shapes=global_arg_shapes)
donate_argnums=donate_argnums)
class PmapCallInfo(NamedTuple):
@ -1591,7 +1590,6 @@ class PmapCallInfo(NamedTuple):
donated_invars: Sequence[bool]
in_axes_flat: Sequence[Optional[int]]
local_axis_size: int
global_arg_shapes_flat: Sequence[Optional[Tuple[int, ...]]]
out_axes_thunk: HashableFunction
devices: Optional[Sequence[xc.Device]]
global_axis_size: int
@ -1628,7 +1626,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
return global_axis_size
def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, in_devices, backend_name,
donate_tuple, in_devices, backend_name,
axis_size, args, kwargs):
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
@ -1651,15 +1649,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
else:
dyn_in_axes = in_axes
dyn_global_arg_shapes = global_arg_shapes
if isinstance(global_arg_shapes, tuple):
dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums)
else:
dyn_global_arg_shapes = global_arg_shapes
else:
dyn_args, dyn_in_axes = args, in_axes
dyn_global_arg_shapes = global_arg_shapes
args, in_tree = tree_flatten((dyn_args, kwargs))
if donate_tuple and not config.jax_debug_nans:
@ -1667,9 +1658,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
else:
donated_invars = (False,) * len(args)
in_axes_flat = tuple(flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0)))
global_arg_shapes_flat = tuple(flatten_axes(
"pmap global_arg_shapes", in_tree, (dyn_global_arg_shapes, None),
kws=True))
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
f, res_paths = result_paths(f)
@ -1709,7 +1697,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donated_invars=donated_invars,
in_axes_flat=in_axes_flat,
local_axis_size=local_axis_size,
global_arg_shapes_flat=global_arg_shapes_flat,
out_axes_thunk=out_axes_thunk,
devices=None if in_devices is None else tuple(in_devices),
global_axis_size=global_axis_size,
@ -1727,12 +1714,11 @@ def _get_f_mapped(
backend: Optional[str],
axis_size: Optional[int],
donate_tuple: Tuple[int, ...],
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]],
):
def pmap_f(*args, **kwargs):
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
global_arg_shapes, devices, backend, axis_size, args, kwargs)
devices, backend, axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
out = pxla.xla_pmap(
@ -1741,7 +1727,6 @@ def _get_f_mapped(
devices=p.devices,
in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk,
name=p.flat_fun.__name__, donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size)
return p.out_tree, out
@ -1780,7 +1765,6 @@ def _python_pmap(
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
@ -1799,7 +1783,6 @@ def _python_pmap(
devices=devices,
backend=backend,
axis_size=axis_size,
global_arg_shapes=global_arg_shapes,
donate_tuple=donate_tuple)
out_tree, out_flat = f_pmapped_(*args, **kwargs)
@ -1807,7 +1790,7 @@ def _python_pmap(
pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
backend, axis_size, global_arg_shapes, donate_tuple)
backend, axis_size, donate_tuple)
return cast(stages.Wrapped, pmap_f)
@ -1842,7 +1825,6 @@ def _cpp_pmap(
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
@ -1852,7 +1834,7 @@ def _cpp_pmap(
@api_boundary
def cache_miss(*args, **kwargs):
p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, global_arg_shapes, devices, backend,
donate_tuple, devices, backend,
axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
@ -1867,7 +1849,6 @@ def _cpp_pmap(
out_axes_thunk=p.out_axes_thunk,
name=p.flat_fun.__name__,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
)
@ -1939,13 +1920,13 @@ def _cpp_pmap(
pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
backend, axis_size, global_arg_shapes, donate_tuple)
backend, axis_size, donate_tuple)
return pmap_f
def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
devices, backend, axis_size, global_arg_shapes, donate_tuple): # noqa: F811
devices, backend, axis_size, donate_tuple): # noqa: F811
"""Make a ``lower`` method for pmapped functions."""
# If the function we returned from ``pmap`` were a class instance,
# this might naturally be a method, with ``fun`` as a ``self`` and
@ -1966,7 +1947,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
"""
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
global_arg_shapes, devices, backend, axis_size, args, kwargs)
devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
computation = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
@ -1976,7 +1957,6 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_platform=_experimental_lowering_platform)


@ -745,25 +745,22 @@ def xla_pmap_impl_lazy(
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
) -> Callable:
if (config.jax_disable_jit and config.jax_eager_pmap and
not is_explicit_global_axis_size and not any(d for d in donated_invars)
and not all(g is not None for g in global_arg_shapes)):
not is_explicit_global_axis_size and not any(d for d in donated_invars)):
def _emap_apply_fn(*args):
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
axis_size=axis_size, global_axis_size=global_axis_size,
devices=devices, name=name, in_axes=in_axes,
out_axes_thunk=out_axes_thunk,
donated_invars=donated_invars,
global_arg_shapes=global_arg_shapes,
is_explicit_global_axis_size=is_explicit_global_axis_size)
return _emap_apply_fn
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun, fingerprint = parallel_callable(
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, *abstract_args)
# Don't re-abstractify args unless logging is enabled for performance.
@ -793,15 +790,12 @@ def _emap_impl(fun: lu.WrappedFun, *args,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
):
from jax._src import array
# TODO(sharadmv,mattjj): implement these cases
if any(d for d in donated_invars):
raise NotImplementedError("Buffer donation not supported in eager pmap.")
if any(g is not None for g in global_arg_shapes):
raise NotImplementedError("Global arg shapes not supported in eager pmap.")
if is_explicit_global_axis_size:
raise NotImplementedError("Non-default global_axis_size not supported in "
"eager pmap.")
@ -1029,12 +1023,11 @@ def parallel_callable(fun: lu.WrappedFun,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
*avals):
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, avals, lowering_platform=None)
pmap_executable = pmap_computation.compile()
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@ -1091,26 +1084,17 @@ def find_replicas(jaxpr, axis_size, global_axis_size):
def stage_parallel_callable(
pci: ParallelCallableInfo,
fun: lu.WrappedFun,
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
fun: lu.WrappedFun):
sharded_avals = tuple(
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(pci.in_axes, pci.avals))
if any(s is not None for s in global_arg_shapes):
# TODO(skye): we could take this branch unconditionally if we handled
# grad of global_arg_shapes correctly.
global_sharded_avals = [
aval.update(shape=shape) if shape is not None else aval
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
else:
global_sharded_avals = sharded_avals # type: ignore
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for pmap in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
@ -1133,7 +1117,7 @@ def stage_parallel_callable(
num_global_shards = replicas.num_global_replicas * parts.num_partitions
shards = ShardInfo(
sharded_avals, out_sharded_avals, global_sharded_avals,
sharded_avals, out_sharded_avals, sharded_avals,
num_local_shards, num_global_shards)
return jaxpr, consts, replicas, parts, shards
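To make the simplified `stage_parallel_callable` concrete: the sharded avals are the per-device views of the arguments, with the mapped axis removed, and with `global_arg_shapes` gone the global avals are now simply those same sharded avals. A small illustrative check (the shapes are assumptions, not from the diff):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def f(x):
  # Each device traces f on its own slice, so the (sharded) aval here has
  # shape (4,); the leading mapped axis of the (n, 4) input is gone.
  assert x.shape == (4,)
  return x.sum()

out = jax.pmap(f)(jnp.ones((n, 4)))
print(out.shape)  # (n,) -- one result per device
```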
@ -1158,7 +1142,6 @@ def lower_parallel_callable(
in_axes: Iterable[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
*,
@ -1197,8 +1180,7 @@ def lower_parallel_callable(
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
pci, fun, global_arg_shapes)
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(pci, fun)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("sharded_avals: %s", shards.sharded_avals)
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
@ -1976,7 +1958,6 @@ def _pmap_dce_rule(used_outputs, eqn):
eqn.params['global_axis_size'], None):
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
# TODO(yashkatariya,mattjj): Handle global_arg_shapes here too.
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
_, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
@ -2095,8 +2076,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
def _pmap_lowering(ctx, *in_nodes, axis_name,
axis_size, global_axis_size, devices, name,
call_jaxpr, backend=None, in_axes, out_axes,
donated_invars, global_arg_shapes,
is_explicit_global_axis_size):
donated_invars, is_explicit_global_axis_size):
del donated_invars # Unused.
xla.check_backend_matches(backend, ctx.module_context.platform)
# We in-line here rather than generating a Call HLO as in the xla_call


@ -4272,14 +4272,13 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
_identity_fn, None, (), (), sharded_dim, sharded_dim)
p = api._prepare_pmap(
_identity_fn, sharded_dim, sharded_dim, static_broadcasted_tuple,
donate_tuple, None, None, None, None, args, kwargs)
donate_tuple, None, None, None, args, kwargs)
out_flat = pxla.xla_pmap_impl(
p.flat_fun, *p.flat_args, backend=None, axis_name=axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices, in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
)
return tree_util.tree_unflatten(p.out_tree(), out_flat)
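A usage sketch for the `_prepare_pmap` call above (illustrative; whether a given copy routes through `_copy_impl_pmap_sharding` is an assumption about internals, not something this diff changes): copying an array produced by `pmap` re-runs an identity `pmap` so the result keeps its per-device layout.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jax.pmap(lambda v: v * 2)(jnp.ones((n, 3)))  # x carries a PmapSharding

y = jnp.copy(x)
print(type(x.sharding).__name__, type(y.sharding).__name__)
```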


@ -2999,7 +2999,6 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
in (c, f, g) }
devices=None
donated_invars=(False, False, False)
global_arg_shapes=(None,)
global_axis_size=None
in_axes=(0, 0, 0)
name=<lambda>