Remove global_arg_shapes from pmap: it was only ever used with sharded_jit, which was removed from JAX a long time ago
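
For users still passing `global_arg_shapes`, the replacement is `pjit`. Below is a minimal migration sketch, not part of this commit; the function `f`, the mesh axis name `'data'`, and the input sizes are illustrative assumptions.

```python
# Hypothetical migration sketch: instead of pmap(sharded_jit(f), ...) with
# global_arg_shapes, express the partitioning with pjit over an explicit mesh.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

def f(x):
  return jnp.sin(x) * 2.0

# A 1-D mesh over all devices; the 'data' axis plays the role of the
# pmapped dimension.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

# pjit consumes global shapes directly, so there is no global_arg_shapes:
# the PartitionSpec states how the global array is split across 'data'.
f_pjit = pjit(f, in_shardings=P('data'), out_shardings=P('data'))

with mesh:
  x = jnp.arange(2.0 * jax.device_count())  # length divisible by the mesh size
  y = f_pjit(x)
```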

PiperOrigin-RevId: 520356179
Yash Katariya 2023-03-29 09:22:34 -07:00 committed by jax authors
parent a964ae7fac
commit fbc05ee5ac
6 changed files with 26 additions and 66 deletions


@@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list
* CUDA 11.4 support has been dropped. JAX GPU wheels only support
  CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
  from source.
+* The `global_arg_shapes` argument of `pmap` only worked with `sharded_jit` and
+  has been removed from `pmap`. Please migrate to `pjit` and remove
+  `global_arg_shapes` from `pmap` calls.
## jaxlib 0.4.8


@@ -460,7 +460,6 @@ captured using the ``xla_pmap`` primitive. Consider this example
in (k,) }
devices=None
donated_invars=(False, False)
-global_arg_shapes=(None,)
global_axis_size=1
in_axes=(None, 0)
is_explicit_global_axis_size=False


@@ -1437,12 +1437,6 @@ def pmap(
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
-global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
-the partitioned values span multiple processes. The global cross-process
-per-replica shape of each argument, i.e. does not include the leading
-pmapped dimension. Can be None for replicated arguments. This API is
-likely to change in the future.
Returns:
A parallelized version of ``fun`` with arguments that correspond to those of
``fun`` but with extra array axes at positions indicated by ``in_axes`` and
@@ -1565,6 +1559,12 @@ def pmap(
>>> print(f2(jnp.array([2., 3.]))) # doctest: +SKIP
[ 13. 13.]
"""
+if global_arg_shapes is not None:
+raise ValueError(
+"global_arg_shapes only worked with sharded_jit which has long been"
+" removed from JAX. Please migrate to pjit and remove global_arg_shapes"
+" from pmap.")
if FLAGS.experimental_cpp_pmap:
func = _cpp_pmap
else:
@@ -1579,8 +1579,7 @@
devices=devices,
backend=backend,
axis_size=axis_size,
-donate_argnums=donate_argnums,
-global_arg_shapes=global_arg_shapes)
+donate_argnums=donate_argnums)
class PmapCallInfo(NamedTuple):
@@ -1591,7 +1590,6 @@ class PmapCallInfo(NamedTuple):
donated_invars: Sequence[bool]
in_axes_flat: Sequence[Optional[int]]
local_axis_size: int
-global_arg_shapes_flat: Sequence[Optional[Tuple[int, ...]]]
out_axes_thunk: HashableFunction
devices: Optional[Sequence[xc.Device]]
global_axis_size: int
@@ -1628,7 +1626,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
return global_axis_size
def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-donate_tuple, global_arg_shapes, in_devices, backend_name,
+donate_tuple, in_devices, backend_name,
axis_size, args, kwargs):
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")
@@ -1651,15 +1649,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums)
else:
dyn_in_axes = in_axes
-dyn_global_arg_shapes = global_arg_shapes
-if isinstance(global_arg_shapes, tuple):
-dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums)
-else:
-dyn_global_arg_shapes = global_arg_shapes
else:
dyn_args, dyn_in_axes = args, in_axes
-dyn_global_arg_shapes = global_arg_shapes
args, in_tree = tree_flatten((dyn_args, kwargs))
if donate_tuple and not config.jax_debug_nans:
@@ -1667,9 +1658,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
else:
donated_invars = (False,) * len(args)
in_axes_flat = tuple(flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0)))
-global_arg_shapes_flat = tuple(flatten_axes(
-"pmap global_arg_shapes", in_tree, (dyn_global_arg_shapes, None),
-kws=True))
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
f, res_paths = result_paths(f)
@@ -1709,7 +1697,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
donated_invars=donated_invars,
in_axes_flat=in_axes_flat,
local_axis_size=local_axis_size,
-global_arg_shapes_flat=global_arg_shapes_flat,
out_axes_thunk=out_axes_thunk,
devices=None if in_devices is None else tuple(in_devices),
global_axis_size=global_axis_size,
@@ -1727,12 +1714,11 @@ def _get_f_mapped(
backend: Optional[str],
axis_size: Optional[int],
donate_tuple: Tuple[int, ...],
-global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]],
):
def pmap_f(*args, **kwargs):
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
-global_arg_shapes, devices, backend, axis_size, args, kwargs)
+devices, backend, axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
out = pxla.xla_pmap(
@@ -1741,7 +1727,6 @@ def _get_f_mapped(
devices=p.devices,
in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk,
name=p.flat_fun.__name__, donated_invars=p.donated_invars,
-global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size)
return p.out_tree, out
@@ -1780,7 +1765,6 @@ def _python_pmap(
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
-global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> stages.Wrapped:
"""The Python only implementation."""
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
@@ -1799,7 +1783,6 @@ def _python_pmap(
devices=devices,
backend=backend,
axis_size=axis_size,
-global_arg_shapes=global_arg_shapes,
donate_tuple=donate_tuple)
out_tree, out_flat = f_pmapped_(*args, **kwargs)
@@ -1807,7 +1790,7 @@ def _python_pmap(
pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
-backend, axis_size, global_arg_shapes, donate_tuple)
+backend, axis_size, donate_tuple)
return cast(stages.Wrapped, pmap_f)
@@ -1842,7 +1825,6 @@ def _cpp_pmap(
backend: Optional[str] = None,
axis_size: Optional[int] = None,
donate_argnums: Union[int, Iterable[int]] = (),
-global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> Any:
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes,
@@ -1852,7 +1834,7 @@ def _cpp_pmap(
@api_boundary
def cache_miss(*args, **kwargs):
p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-donate_tuple, global_arg_shapes, devices, backend,
+donate_tuple, devices, backend,
axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
@@ -1867,7 +1849,6 @@ def _cpp_pmap(
out_axes_thunk=p.out_axes_thunk,
name=p.flat_fun.__name__,
donated_invars=p.donated_invars,
-global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
)
@@ -1939,13 +1920,13 @@ def _cpp_pmap(
pmap_f.lower = _pmap_lower(
fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices,
-backend, axis_size, global_arg_shapes, donate_tuple)
+backend, axis_size, donate_tuple)
return pmap_f
def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
-devices, backend, axis_size, global_arg_shapes, donate_tuple): # noqa: F811
+devices, backend, axis_size, donate_tuple): # noqa: F811
"""Make a ``lower`` method for pmapped functions."""
# If the function we returned from ``pmap`` were a class instance,
# this might naturally be a method, with ``fun`` as a ``self`` and
@@ -1966,7 +1947,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
"""
p = _prepare_pmap(
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
-global_arg_shapes, devices, backend, axis_size, args, kwargs)
+devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
computation = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
@@ -1976,7 +1957,6 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk,
donated_invars=p.donated_invars,
-global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
avals=abstract_args,
lowering_platform=_experimental_lowering_platform)


@@ -745,25 +745,22 @@ def xla_pmap_impl_lazy(
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
-global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
) -> Callable:
if (config.jax_disable_jit and config.jax_eager_pmap and
-not is_explicit_global_axis_size and not any(d for d in donated_invars)
-and not all(g is not None for g in global_arg_shapes)):
+not is_explicit_global_axis_size and not any(d for d in donated_invars)):
def _emap_apply_fn(*args):
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
axis_size=axis_size, global_axis_size=global_axis_size,
devices=devices, name=name, in_axes=in_axes,
out_axes_thunk=out_axes_thunk,
donated_invars=donated_invars,
-global_arg_shapes=global_arg_shapes,
is_explicit_global_axis_size=is_explicit_global_axis_size)
return _emap_apply_fn
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun, fingerprint = parallel_callable(
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
-in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
+in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, *abstract_args)
# Don't re-abstractify args unless logging is enabled for performance.
@@ -793,15 +790,12 @@ def _emap_impl(fun: lu.WrappedFun, *args,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
-global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
):
from jax._src import array
# TODO(sharadmv,mattjj): implement these cases
if any(d for d in donated_invars):
raise NotImplementedError("Buffer donation not supported in eager pmap.")
-if any(g is not None for g in global_arg_shapes):
-raise NotImplementedError("Global arg shapes not supported in eager pmap.")
if is_explicit_global_axis_size:
raise NotImplementedError("Non-default global_axis_size not supported in "
"eager pmap.")
@@ -1029,12 +1023,11 @@ def parallel_callable(fun: lu.WrappedFun,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
-global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
*avals):
pmap_computation = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
-in_axes, out_axes_thunk, donated_invars, global_arg_shapes,
+in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, avals, lowering_platform=None)
pmap_executable = pmap_computation.compile()
return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint])
@@ -1091,26 +1084,17 @@ def find_replicas(jaxpr, axis_size, global_axis_size):
def stage_parallel_callable(
pci: ParallelCallableInfo,
-fun: lu.WrappedFun,
-global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
+fun: lu.WrappedFun):
sharded_avals = tuple(
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(pci.in_axes, pci.avals))
-if any(s is not None for s in global_arg_shapes):
-# TODO(skye): we could take this branch unconditionally if we handled
-# grad of global_arg_shapes correctly.
-global_sharded_avals = [
-aval.update(shape=shape) if shape is not None else aval
-for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
-else:
-global_sharded_avals = sharded_avals # type: ignore
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for pmap in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
-fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
+fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
@@ -1133,7 +1117,7 @@ def stage_parallel_callable(
num_global_shards = replicas.num_global_replicas * parts.num_partitions
shards = ShardInfo(
-sharded_avals, out_sharded_avals, global_sharded_avals,
+sharded_avals, out_sharded_avals, sharded_avals,
num_local_shards, num_global_shards)
return jaxpr, consts, replicas, parts, shards
@@ -1158,7 +1142,6 @@ def lower_parallel_callable(
in_axes: Iterable[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
-global_arg_shapes: Sequence[Optional[Tuple[int, ...]]],
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
*,
@@ -1197,8 +1180,7 @@ def lower_parallel_callable(
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
-jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
-pci, fun, global_arg_shapes)
+jaxpr, consts, replicas, parts, shards = stage_parallel_callable(pci, fun)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("sharded_avals: %s", shards.sharded_avals)
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
@@ -1976,7 +1958,6 @@ def _pmap_dce_rule(used_outputs, eqn):
eqn.params['global_axis_size'], None):
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
-# TODO(yashkatariya,mattjj): Handle global_arg_shapes here too.
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
_, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
@@ -2095,8 +2076,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
def _pmap_lowering(ctx, *in_nodes, axis_name,
axis_size, global_axis_size, devices, name,
call_jaxpr, backend=None, in_axes, out_axes,
-donated_invars, global_arg_shapes,
-is_explicit_global_axis_size):
+donated_invars, is_explicit_global_axis_size):
del donated_invars # Unused.
xla.check_backend_matches(backend, ctx.module_context.platform)
# We in-line here rather than generating a Call HLO as in the xla_call


@@ -4272,14 +4272,13 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs):
_identity_fn, None, (), (), sharded_dim, sharded_dim)
p = api._prepare_pmap(
_identity_fn, sharded_dim, sharded_dim, static_broadcasted_tuple,
-donate_tuple, None, None, None, None, args, kwargs)
+donate_tuple, None, None, None, args, kwargs)
out_flat = pxla.xla_pmap_impl(
p.flat_fun, *p.flat_args, backend=None, axis_name=axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices, in_axes=p.in_axes_flat,
out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__,
donated_invars=p.donated_invars,
-global_arg_shapes=p.global_arg_shapes_flat,
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
)
return tree_util.tree_unflatten(p.out_tree(), out_flat)


@@ -2999,7 +2999,6 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
in (c, f, g) }
devices=None
donated_invars=(False, False, False)
-global_arg_shapes=(None,)
global_axis_size=None
in_axes=(0, 0, 0)
name=<lambda>