Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Remove global_arg_shapes from pmap: it was only used with sharded_jit, and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
This commit is contained in:
parent a964ae7fac
commit fbc05ee5ac
@@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list
 * CUDA 11.4 support has been dropped. JAX GPU wheels only support
   CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
   from source.
+* `global_arg_shapes` argument of pmap only worked with sharded_jit and has
+  been removed from pmap. Please migrate to pjit and remove global_arg_shapes
+  from pmap.

 ## jaxlib 0.4.8

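As a sketch of the migration that the changelog entry above suggests (the mesh axis name, toy function, and array sizes here are illustrative assumptions, not part of this change), the role of global_arg_shapes is taken over by building a global jax.Array with an explicit sharding and letting jit/pjit partition the computation:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-in for a function previously wrapped as
# pmap(sharded_jit(f), global_arg_shapes=...).
def f(a):
  return a * 2.0

mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
sharding = NamedSharding(mesh, P("devices"))

# The global shape now lives on the array itself instead of in global_arg_shapes.
x = jax.device_put(jnp.arange(float(jax.device_count())), sharding)

y = jax.jit(f)(x)   # jit shares pjit's implementation and follows the input sharding
print(y.sharding)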
@@ -460,7 +460,6 @@ captured using the ``xla_pmap`` primitive. Consider this example
             in (k,) }
       devices=None
       donated_invars=(False, False)
-      global_arg_shapes=(None,)
       global_axis_size=1
       in_axes=(None, 0)
       is_explicit_global_axis_size=False
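A minimal sketch of how an ``xla_pmap`` equation like the one in that documentation snippet can be reproduced; after this change its printed parameters no longer include ``global_arg_shapes`` (the toy function and shapes below are assumptions):

import jax
import jax.numpy as jnp

# Print a jaxpr containing an xla_pmap equation and its parameters
# (devices, donated_invars, global_axis_size, in_axes, ...).
x = jnp.zeros((jax.local_device_count(), 3))
print(jax.make_jaxpr(jax.pmap(lambda a: a.sum()))(x))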
@@ -1437,12 +1437,6 @@ def pmap(
       For more details on buffer donation see the
       `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.

-    global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
-      the partitioned values span multiple processes. The global cross-process
-      per-replica shape of each argument, i.e. does not include the leading
-      pmapped dimension. Can be None for replicated arguments. This API is
-      likely to change in the future.
-
   Returns:
     A parallelized version of ``fun`` with arguments that correspond to those of
     ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
@@ -1565,6 +1559,12 @@ def pmap(
   >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
   [ 13. 13.]
   """
+  if global_arg_shapes is not None:
+    raise ValueError(
+        "global_arg_shapes only worked with sharded_jit which has long been"
+        " removed from JAX. Please migrate to pjit and remove global_arg_shapes"
+        " from pmap.")
+
   if FLAGS.experimental_cpp_pmap:
     func = _cpp_pmap
   else:
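A short sketch of the error path added above, assuming a JAX version that still accepts the keyword (later releases may reject it earlier): passing global_arg_shapes now fails at pmap call time with a pointer to pjit:

import jax

try:
  jax.pmap(lambda x: x + 1, global_arg_shapes=((8,),))
except ValueError as e:
  print(e)  # "global_arg_shapes only worked with sharded_jit ... migrate to pjit ..."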
@@ -1579,8 +1579,7 @@ def pmap(
       devices=devices,
       backend=backend,
       axis_size=axis_size,
-      donate_argnums=donate_argnums,
-      global_arg_shapes=global_arg_shapes)
+      donate_argnums=donate_argnums)


 class PmapCallInfo(NamedTuple):
@@ -1591,7 +1590,6 @@ class PmapCallInfo(NamedTuple):
   donated_invars: Sequence[bool]
   in_axes_flat: Sequence[Optional[int]]
   local_axis_size: int
-  global_arg_shapes_flat: Sequence[Optional[Tuple[int, ...]]]
   out_axes_thunk: HashableFunction
   devices: Optional[Sequence[xc.Device]]
   global_axis_size: int
@@ -1628,7 +1626,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
   return global_axis_size

 def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-                  donate_tuple, global_arg_shapes, in_devices, backend_name,
+                  donate_tuple, in_devices, backend_name,
                   axis_size, args, kwargs):
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")
@@ -1651,15 +1649,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
       dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
     else:
       dyn_in_axes = in_axes
-      dyn_global_arg_shapes = global_arg_shapes
-
-    if isinstance(global_arg_shapes, tuple):
-      dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums)
-    else:
-      dyn_global_arg_shapes = global_arg_shapes
   else:
     dyn_args, dyn_in_axes = args, in_axes
-    dyn_global_arg_shapes = global_arg_shapes
   args, in_tree = tree_flatten((dyn_args, kwargs))

   if donate_tuple and not config.jax_debug_nans:
@@ -1667,9 +1658,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   else:
     donated_invars = (False,) * len(args)
   in_axes_flat = tuple(flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0)))
-  global_arg_shapes_flat = tuple(flatten_axes(
-      "pmap global_arg_shapes", in_tree, (dyn_global_arg_shapes, None),
-      kws=True))
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

   f, res_paths = result_paths(f)
@@ -1709,7 +1697,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
       donated_invars=donated_invars,
       in_axes_flat=in_axes_flat,
       local_axis_size=local_axis_size,
-      global_arg_shapes_flat=global_arg_shapes_flat,
       out_axes_thunk=out_axes_thunk,
       devices=None if in_devices is None else tuple(in_devices),
       global_axis_size=global_axis_size,
@@ -1727,12 +1714,11 @@ def _get_f_mapped(
     backend: Optional[str],
     axis_size: Optional[int],
     donate_tuple: Tuple[int, ...],
-    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]],
 ):
   def pmap_f(*args, **kwargs):
     p = _prepare_pmap(
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
-        global_arg_shapes, devices, backend, axis_size, args, kwargs)
+        devices, backend, axis_size, args, kwargs)
     for arg in p.flat_args:
       dispatch.check_arg(arg)
     out = pxla.xla_pmap(
@@ -1741,7 +1727,6 @@ def _get_f_mapped(
         devices=p.devices,
         in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk,
         name=p.flat_fun.__name__, donated_invars=p.donated_invars,
-        global_arg_shapes=p.global_arg_shapes_flat,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size)
     return p.out_tree, out

@@ -1780,7 +1765,6 @@ def _python_pmap(
     backend: Optional[str] = None,
     axis_size: Optional[int] = None,
     donate_argnums: Union[int, Iterable[int]] = (),
-    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
 ) -> stages.Wrapped:
   """The Python only implementation."""
   axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
@@ -1799,7 +1783,6 @@ def _python_pmap(
         devices=devices,
         backend=backend,
         axis_size=axis_size,
-        global_arg_shapes=global_arg_shapes,
         donate_tuple=donate_tuple)

     out_tree, out_flat = f_pmapped_(*args, **kwargs)
@@ -1807,7 +1790,7 @@ def _python_pmap(

   pmap_f.lower = _pmap_lower(
       fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
-      backend, axis_size, global_arg_shapes, donate_tuple)
+      backend, axis_size, donate_tuple)

   return cast(stages.Wrapped, pmap_f)

@@ -1842,7 +1825,6 @@ def _cpp_pmap(
     backend: Optional[str] = None,
     axis_size: Optional[int] = None,
     donate_argnums: Union[int, Iterable[int]] = (),
-    global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
 ) -> Any:
   axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
       fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
@@ -1852,7 +1834,7 @@ def _cpp_pmap(
   @api_boundary
   def cache_miss(*args, **kwargs):
     p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-                      donate_tuple, global_arg_shapes, devices, backend,
+                      donate_tuple, devices, backend,
                       axis_size, args, kwargs)
     for arg in p.flat_args:
       dispatch.check_arg(arg)
@@ -1867,7 +1849,6 @@ def _cpp_pmap(
         out_axes_thunk=p.out_axes_thunk,
         name=p.flat_fun.__name__,
         donated_invars=p.donated_invars,
-        global_arg_shapes=p.global_arg_shapes_flat,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
     )

@@ -1939,13 +1920,13 @@ def _cpp_pmap(

   pmap_f.lower = _pmap_lower(
       fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
-      backend, axis_size, global_arg_shapes, donate_tuple)
+      backend, axis_size, donate_tuple)

   return pmap_f


 def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
-                devices, backend, axis_size, global_arg_shapes, donate_tuple):  # noqa: F811
+                devices, backend, axis_size, donate_tuple):  # noqa: F811
   """Make a ``lower`` method for pmapped functions."""
   # If the function we returned from ``pmap`` were a class instance,
   # this might naturally be a method, with ``fun`` as a ``self`` and
@@ -1966,7 +1947,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
     """
     p = _prepare_pmap(
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
-        global_arg_shapes, devices, backend, axis_size, args, kwargs)
+        devices, backend, axis_size, args, kwargs)
     abstract_args = list(map(shaped_abstractify, p.flat_args))
     computation = pxla.lower_parallel_callable(
         p.flat_fun, backend, axis_name,
@@ -1976,7 +1957,6 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
         in_axes=p.in_axes_flat,
         out_axes_thunk=p.out_axes_thunk,
         donated_invars=p.donated_invars,
-        global_arg_shapes=p.global_arg_shapes_flat,
         is_explicit_global_axis_size=p.is_explicit_global_axis_size,
         avals=abstract_args,
         lowering_platform=_experimental_lowering_platform)
@@ -745,25 +745,22 @@ def xla_pmap_impl_lazy(
     in_axes: Sequence[Optional[int]],
     out_axes_thunk: Callable[[], Sequence[Optional[int]]],
     donated_invars: Sequence[bool],
-    global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
     is_explicit_global_axis_size: bool,
 ) -> Callable:
   if (config.jax_disable_jit and config.jax_eager_pmap and
-      not is_explicit_global_axis_size and not any(d for d in donated_invars)
-      and not all(g is not None for g in global_arg_shapes)):
+      not is_explicit_global_axis_size and not any(d for d in donated_invars)):
     def _emap_apply_fn(*args):
       return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
                         axis_size=axis_size, global_axis_size=global_axis_size,
                         devices=devices, name=name, in_axes=in_axes,
                         out_axes_thunk=out_axes_thunk,
                         donated_invars=donated_invars,
-                        global_arg_shapes=global_arg_shapes,
                         is_explicit_global_axis_size=is_explicit_global_axis_size)
     return _emap_apply_fn
   abstract_args = unsafe_map(xla.abstractify, args)
   compiled_fun, fingerprint = parallel_callable(
       fun, backend, axis_name, axis_size, global_axis_size, devices, name,
-      in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
+      in_axes, out_axes_thunk, donated_invars,
       is_explicit_global_axis_size, *abstract_args)

   # Don't re-abstractify args unless logging is enabled for performance.
@@ -793,15 +790,12 @@ def _emap_impl(fun: lu.WrappedFun, *args,
                in_axes: Sequence[Optional[int]],
                out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                donated_invars: Sequence[bool],
-               global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
                is_explicit_global_axis_size: bool,
                ):
   from jax._src import array
   # TODO(sharadmv,mattjj): implement these cases
   if any(d for d in donated_invars):
     raise NotImplementedError("Buffer donation not supported in eager pmap.")
-  if any(g is not None for g in global_arg_shapes):
-    raise NotImplementedError("Global arg shapes not supported in eager pmap.")
   if is_explicit_global_axis_size:
     raise NotImplementedError("Non-default global_axis_size not supported in "
                               "eager pmap.")
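A sketch of reaching the eager path gated in xla_pmap_impl_lazy above; the config names are taken from that condition, and the shapes are assumptions. With jit disabled and eager pmap enabled, pmap dispatches through _emap_impl instead of compiling, provided no buffers are donated:

import jax
import jax.numpy as jnp

jax.config.update("jax_disable_jit", True)
jax.config.update("jax_eager_pmap", True)

x = jnp.zeros((jax.local_device_count(), 2))
print(jax.pmap(lambda a: a + 1.0)(x))   # evaluated eagerly, replica by replica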
@@ -1029,12 +1023,11 @@ def parallel_callable(fun: lu.WrappedFun,
                       in_axes: Sequence[Optional[int]],
                       out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                       donated_invars: Sequence[bool],
-                      global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
                       is_explicit_global_axis_size: bool,
                       *avals):
   pmap_computation = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
-      in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
+      in_axes, out_axes_thunk, donated_invars,
       is_explicit_global_axis_size, avals, lowering_platform=None)
   pmap_executable = pmap_computation.compile()
   return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
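The public-API counterpart of the lower-then-compile flow in parallel_callable is the pmap_f.lower method wired up elsewhere in this change; a hedged sketch (the function and shapes are assumptions):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
lowered = jax.pmap(lambda a: a * 2.0).lower(jnp.zeros((n, 3)))   # lowering step
compiled = lowered.compile()                                     # compile step
print(compiled(jnp.ones((n, 3))))                                # run the compiled executable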
@@ -1091,26 +1084,17 @@ def find_replicas(jaxpr, axis_size, global_axis_size):

 def stage_parallel_callable(
     pci: ParallelCallableInfo,
-    fun: lu.WrappedFun,
-    global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
+    fun: lu.WrappedFun):
   sharded_avals = tuple(
       shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
       for axis, aval in safe_zip(pci.in_axes, pci.avals))
-  if any(s is not None for s in global_arg_shapes):
-    # TODO(skye): we could take this branch unconditionally if we handled
-    # grad of global_arg_shapes correctly.
-    global_sharded_avals = [
-        aval.update(shape=shape) if shape is not None else aval
-        for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
-  else:
-    global_sharded_avals = sharded_avals  # type: ignore

   with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None):  # type: ignore
     with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                                    "for pmap in {elapsed_time} sec",
                                    event=dispatch.JAXPR_TRACE_EVENT):
       jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
-          fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
+          fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
   jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

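A small sketch of the per-replica ("sharded") avals that stage_parallel_callable computes above: the mapped leading axis is stripped before tracing, so the traced function sees the per-replica shape, and the map axis reappears on the output (the shapes below are assumptions):

import jax
import jax.numpy as jnp

def show(a):
  assert a.shape == (4,)   # per-replica shape seen while tracing under pmap
  return a.sum()

n = jax.local_device_count()
print(jax.pmap(show)(jnp.zeros((n, 4))).shape)   # (n,)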
@@ -1133,7 +1117,7 @@ def stage_parallel_callable(
   num_global_shards = replicas.num_global_replicas * parts.num_partitions

   shards = ShardInfo(
-      sharded_avals, out_sharded_avals, global_sharded_avals,
+      sharded_avals, out_sharded_avals, sharded_avals,
       num_local_shards, num_global_shards)

   return jaxpr, consts, replicas, parts, shards
@@ -1158,7 +1142,6 @@ def lower_parallel_callable(
     in_axes: Iterable[Optional[int]],
     out_axes_thunk: Callable[[], Sequence[Optional[int]]],
     donated_invars: Sequence[bool],
-    global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
@@ -1197,8 +1180,7 @@ def lower_parallel_callable(
   pci = ParallelCallableInfo(
       name, backend, axis_name, axis_size, global_axis_size, devices,
       in_axes, out_axes_thunk, avals)
-  jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
-      pci, fun, global_arg_shapes)
+  jaxpr, consts, replicas, parts, shards = stage_parallel_callable(pci, fun)
   if logger.isEnabledFor(logging.DEBUG):
     logger.debug("sharded_avals: %s", shards.sharded_avals)
     logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
@@ -1976,7 +1958,6 @@ def _pmap_dce_rule(used_outputs, eqn):
                              eqn.params['global_axis_size'], None):
     new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
   _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
-  # TODO(yashkatariya,mattjj): Handle global_arg_shapes here too.
   _, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
   _, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
   new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
@@ -2095,8 +2076,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
 def _pmap_lowering(ctx, *in_nodes, axis_name,
                    axis_size, global_axis_size, devices, name,
                    call_jaxpr, backend=None, in_axes, out_axes,
-                   donated_invars, global_arg_shapes,
-                   is_explicit_global_axis_size):
+                   donated_invars, is_explicit_global_axis_size):
   del donated_invars  # Unused.
   xla.check_backend_matches(backend, ctx.module_context.platform)
   # We in-line here rather than generating a Call HLO as in the xla_call
@@ -4272,14 +4272,13 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
       _identity_fn, None, (), (), sharded_dim, sharded_dim)
   p = api._prepare_pmap(
       _identity_fn, sharded_dim, sharded_dim, static_broadcasted_tuple,
-      donate_tuple, None, None, None, None, args, kwargs)
+      donate_tuple, None, None, None, args, kwargs)
   out_flat = pxla.xla_pmap_impl(
       p.flat_fun, *p.flat_args, backend=None, axis_name=axis_name,
       axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
       devices=p.devices, in_axes=p.in_axes_flat,
       out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__,
       donated_invars=p.donated_invars,
-      global_arg_shapes=p.global_arg_shapes_flat,
       is_explicit_global_axis_size=p.is_explicit_global_axis_size,
   )
   return tree_util.tree_unflatten(p.out_tree(), out_flat)
@@ -2999,7 +2999,6 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
             in (c, f, g) }
         devices=None
         donated_invars=(False, False, False)
-        global_arg_shapes=(None,)
        global_axis_size=None
         in_axes=(0, 0, 0)
         name=<lambda>